28#include "utils/math_utils.hpp" 
   29#include <sycl/sycl.hpp> 
   34#include "ext/common.hpp" 
   36using dpctl::tensor::usm_ndarray;
 
   38using ext::common::Align;
 
   39using ext::common::CeilDiv;
 
   41namespace statistics::sliding_window1d
 
   44template <
typename T, u
int32_t Size>
 
   48    using ncT = 
typename std::remove_const_t<T>;
 
   49    using SizeT = 
decltype(Size);
 
   50    static constexpr SizeT _size = Size;
 
   53        : sbgroup(item.get_sub_group())
 
   57    template <
typename yT>
 
   58    T &operator[](
const yT &idx)
 
   60        static_assert(std::is_integral_v<yT>,
 
   61                      "idx must be of an integral type");
 
   65    template <
typename yT>
 
   66    const T &operator[](
const yT &idx)
 const 
   68        static_assert(std::is_integral_v<yT>,
 
   69                      "idx must be of an integral type");
 
   75        static_assert(Size == 1,
 
   76                      "Size is not equal to 1. Use value(idx) instead");
 
   80    const T &value()
 const 
   82        static_assert(Size == 1,
 
   83                      "Size is not equal to 1. Use value(idx) instead");
 
   87    template <
typename yT, 
typename xT>
 
   88    T broadcast(
const yT &y, 
const xT &x)
 const 
   90        static_assert(std::is_integral_v<std::remove_reference_t<yT>>,
 
   91                      "y must be of an integral type");
 
   92        static_assert(std::is_integral_v<std::remove_reference_t<xT>>,
 
   93                      "x must be of an integral type");
 
   95        return sycl::select_from_group(sbgroup, data[y], x);
 
   98    template <
typename iT>
 
   99    T broadcast(
const iT &idx)
 const 
  101        if constexpr (Size == 1) {
 
  102            return broadcast(0, idx);
 
  105            return broadcast(idx / size_x(), idx % size_x());
 
  109    template <
typename yT, 
typename xT>
 
  110    T shift_left(
const yT &y, 
const xT &x)
 const 
  112        static_assert(std::is_integral_v<yT>, 
"y must be of an integral type");
 
  113        static_assert(std::is_integral_v<xT>, 
"x must be of an integral type");
 
  115        return sycl::shift_group_left(sbgroup, data[y], x);
 
  118    template <
typename yT, 
typename xT>
 
  119    T shift_right(
const yT &y, 
const xT &x)
 const 
  121        static_assert(std::is_integral_v<yT>, 
"y must be of an integral type");
 
  122        static_assert(std::is_integral_v<xT>, 
"x must be of an integral type");
 
  124        return sycl::shift_group_right(sbgroup, data[y], x);
 
  127    constexpr SizeT size_y()
 const 
  134        return sbgroup.get_max_local_range()[0];
 
  137    SizeT total_size()
 const 
  139        return size_x() * size_y();
 
  149        return sbgroup.get_local_linear_id();
 
  153    const sycl::sub_group sbgroup;
 
 
  157template <
typename T, u
int32_t Size = 1>
 
  160    using SizeT = 
typename _RegistryDataStorage<T, Size>::SizeT;
 
  164    template <
typename LaneIdT,
 
  166              typename = std::enable_if_t<
 
  167                  std::is_invocable_r_v<bool, Condition, SizeT>>>
 
  168    void fill_lane(
const LaneIdT &lane_id, 
const T &value, Condition &&mask)
 
  170        static_assert(std::is_integral_v<LaneIdT>,
 
  171                      "lane_id must be of an integral type");
 
  172        if (mask(this->x())) {
 
  173            this->data[lane_id] = value;
 
  177    template <
typename LaneIdT>
 
  178    void fill_lane(
const LaneIdT &lane_id, 
const T &value, 
const bool &mask)
 
  180        fill_lane(lane_id, value, [mask](
auto &&) { 
return mask; });
 
  183    template <
typename LaneIdT>
 
  184    void fill_lane(
const LaneIdT &lane_id, 
const T &value)
 
  186        fill_lane(lane_id, value, 
true);
 
  189    template <
typename Condition,
 
  190              typename = std::enable_if_t<
 
  191                  std::is_invocable_r_v<bool, Condition, SizeT, SizeT>>>
 
  192    void fill(
const T &value, Condition &&mask)
 
  194        for (SizeT i = 0; i < Size; ++i) {
 
  195            fill_lane(i, value, mask(i, this->x()));
 
  199    void fill(
const T &value)
 
  201        fill(value, [](
auto &&, 
auto &&) { 
return true; });
 
  204    template <
typename LaneIdT,
 
  206              typename = std::enable_if_t<
 
  207                  std::is_invocable_r_v<bool, Condition, const T *const>>>
 
  208    T *load_lane(
const LaneIdT &lane_id,
 
  213        static_assert(std::is_integral_v<LaneIdT>,
 
  214                      "lane_id must be of an integral type");
 
  215        this->data[lane_id] = mask(data) ? data[0] : default_v;
 
  217        return data + this->size_x();
 
  220    template <
typename LaneIdT>
 
  221    T *load_lane(
const LaneIdT &laned_id,
 
  227            laned_id, data, [mask](
auto &&) { 
return mask; }, default_v);
 
  230    template <
typename LaneIdT>
 
  231    T *load_lane(
const LaneIdT &laned_id, 
const T *
const data)
 
  233        constexpr T default_v = 0;
 
  234        return load_lane(laned_id, data, 
true, default_v);
 
  237    template <
typename yStrideT,
 
  239              typename = std::enable_if_t<
 
  240                  std::is_invocable_r_v<bool, Condition, const T *const>>>
 
  241    T *load(
const T *
const data,
 
  242            const yStrideT &y_stride,
 
  247        for (SizeT i = 0; i < Size; ++i) {
 
  248            load_lane(i, it, mask, default_v);
 
  255    template <
typename yStr
ideT>
 
  256    T *load(
const T *
const data,
 
  257            const yStrideT &y_stride,
 
  262            data, y_stride, [mask](
auto &&) { 
return mask; }, default_v);
 
  265    template <
typename Condition,
 
  266              typename = std::enable_if_t<
 
  267                  std::is_invocable_r_v<bool, Condition, const T *const>>>
 
  268    T *load(
const T *
const data, Condition &&mask, 
const T &default_v)
 
  270        return load(data, this->size_x(), mask, default_v);
 
  273    T *load(
const T *
const data, 
const bool &mask, 
const T &default_v)
 
  276            data, [mask](
auto &&) { 
return mask; }, default_v);
 
  279    T *load(
const T *
const data)
 
  281        constexpr T default_v = 0;
 
  282        return load(data, 
true, default_v);
 
  285    template <
typename LaneIdT,
 
  287              typename = std::enable_if_t<
 
  288                  std::is_invocable_r_v<bool, Condition, const T *const>>>
 
  289    T *store_lane(
const LaneIdT &lane_id, T *
const data, Condition &&mask)
 
  291        static_assert(std::is_integral_v<LaneIdT>,
 
  292                      "lane_id must be of an integral type");
 
  295            data[0] = this->data[lane_id];
 
  298        return data + this->size_x();
 
  301    template <
typename LaneIdT>
 
  302    T *store_lane(
const LaneIdT &lane_id, T *
const data, 
const bool &mask)
 
  304        return store_lane(lane_id, data, [mask](
auto &&) { 
return mask; });
 
  307    template <
typename LaneIdT>
 
  308    T *store_lane(
const LaneIdT &lane_id, T *
const data)
 
  310        return store_lane(lane_id, data, 
true);
 
  313    template <
typename yStrideT,
 
  315              typename = std::enable_if_t<
 
  316                  std::is_invocable_r_v<bool, Condition, const T *const>>>
 
  317    T *store(T *
const data, 
const yStrideT &y_stride, Condition &&condition)
 
  320        for (SizeT i = 0; i < Size; ++i) {
 
  321            store_lane(i, it, condition);
 
  328    template <
typename yStr
ideT>
 
  329    T *store(T *
const data, 
const yStrideT &y_stride, 
const bool &mask)
 
  331        return store(data, y_stride, [mask](
auto &&) { 
return mask; });
 
  334    template <
typename Condition,
 
  335              typename = std::enable_if_t<
 
  336                  std::is_invocable_r_v<bool, Condition, const T *const>>>
 
  337    T *store(T *
const data, Condition &&condition)
 
  339        return store(data, this->size_x(), condition);
 
  342    T *store(T *
const data, 
const bool &mask)
 
  344        return store(data, [mask](
auto &&) { 
return mask; });
 
  347    T *store(T *
const data)
 
  349        return store(data, 
true);
 
 
  353template <
typename T, u
int32_t Size>
 
  356    using SizeT = 
typename RegistryData<T, Size>::SizeT;
 
  360    template <
typename shT>
 
  361    void advance_left(
const shT &shift, 
const T &fill_value)
 
  363        static_assert(std::is_integral_v<shT>,
 
  364                      "shift must be of an integral type");
 
  366        uint32_t shift_r = this->size_x() - shift;
 
  367        for (SizeT i = 0; i < Size; ++i) {
 
  368            this->data[i] = this->shift_left(i, shift);
 
  370                i < Size - 1 ? this->shift_right(i + 1, shift_r) : fill_value;
 
  371            if (this->x() >= shift_r) {
 
  372                this->data[i] = border;
 
  377    void advance_left(
const T &fill_value)
 
  379        advance_left(1, fill_value);
 
  384        constexpr T fill_value = 0;
 
  385        advance_left(fill_value);
 
 
  389template <
typename T, 
typename SizeT = 
size_t>
 
  393    using value_type = T;
 
  394    using size_type = SizeT;
 
  396    Span(T *
const data, 
const SizeT size) : data_(data), size_(size) {}
 
  405        return data() + size();
 
 
  423template <
typename T, 
typename SizeT = 
size_t>
 
  429template <
typename T, 
typename SizeT = 
size_t>
 
  433    using value_type = T;
 
  434    using size_type = SizeT;
 
  436    PaddedSpan(T *
const data, 
const SizeT size, 
const SizeT pad)
 
  441    T *padded_begin()
 const 
  443        return this->begin() - pad();
 
 
  455template <
typename T, 
typename SizeT = 
size_t>
 
  457    make_padded_span(T *
const data, 
const SizeT size, 
const SizeT offset)
 
  462template <
typename Results,
 
  467void process_block(Results &results,
 
  475    for (uint32_t i = 0; i < block_size; ++i) {
 
  476        auto v_val = v_data.broadcast(i);
 
  477        for (uint32_t r = 0; r < r_size; ++r) {
 
  478            results[r] = red(results[r], op(a_data[r], v_val));
 
  480        a_data.advance_left();
 
  484template <
typename SizeT>
 
  485SizeT get_global_linear_id(
const uint32_t wpi, 
const sycl::nd_item<1> &item)
 
  487    auto sbgroup = item.get_sub_group();
 
  488    const auto sg_loc_id = sbgroup.get_local_linear_id();
 
  490    const SizeT sg_base_id = wpi * (item.get_global_linear_id() - sg_loc_id);
 
  491    const SizeT 
id = sg_base_id + sg_loc_id;
 
  496template <
typename SizeT>
 
  497uint32_t get_results_num(
const uint32_t wpi,
 
  499                         const SizeT global_id,
 
  500                         const sycl::nd_item<1> &item)
 
  502    auto sbgroup = item.get_sub_group();
 
  504    const auto sbg_size = sbgroup.get_max_local_range()[0];
 
  505    const auto size_ = sycl::sub_sat(size, global_id);
 
  506    return std::min(SizeT(wpi), CeilDiv(size_, sbg_size));
 
  509template <uint32_t WorkPI,
 
  516template <uint32_t WorkPI,
 
  526                             sycl::nd_range<1> nd_range,
 
  530        nd_range, [=](sycl::nd_item<1> item) {
 
  531            auto glid = get_global_linear_id<SizeT>(WorkPI, item);
 
  536            auto results_num = get_results_num(WorkPI, out.size(), glid, item);
 
  538            const auto *a_begin = a.begin();
 
  539            const auto *a_end = a.end();
 
  541            auto sbgroup = item.get_sub_group();
 
  543            const auto chunks_count =
 
  544                CeilDiv(v.size(), sbgroup.get_max_local_range()[0]);
 
  546            const auto *a_ptr = &a.padded_begin()[glid];
 
  548            auto _a_load_cond = [a_begin, a_end](
auto &&ptr) {
 
  549                return ptr >= a_begin && ptr < a_end;
 
  553            a_ptr = a_data.load(a_ptr, _a_load_cond, 0);
 
  555            const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
 
  556            auto v_size = v.size();
 
  558            for (uint32_t b = 0; b < chunks_count; ++b) {
 
  560                v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
 
  562                uint32_t chunk_size_ =
 
  563                    std::min(v_size, SizeT(v_data.total_size()));
 
  564                process_block(results, results_num, a_data, v_data, chunk_size_,
 
  567                if (b != chunks_count - 1) {
 
  568                    a_ptr = a_data.load_lane(a_data.size_y() - 1, a_ptr,
 
  570                    v_size -= v_data.total_size();
 
  574            auto *
const out_ptr = out.begin();
 
  579                std::min(y_start + WorkPI * results.size_x(), out.size());
 
  581            for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
 
  582                out_ptr[y] = results[i++];
 
  592template <uint32_t WorkPI,
 
  599template <uint32_t WorkPI,
 
  609                                          sycl::nd_range<1> nd_range,
 
  613        nd_range, [=](sycl::nd_item<1> item) {
 
  614            auto glid = get_global_linear_id<SizeT>(WorkPI, item);
 
  619            auto sbgroup = item.get_sub_group();
 
  620            auto sg_size = sbgroup.get_max_local_range()[0];
 
  622            const uint32_t to_read = WorkPI * sg_size + v.size();
 
  623            const auto *a_begin = a.begin();
 
  625            const auto *a_ptr = &a.padded_begin()[glid];
 
  626            const auto *a_end = std::min(a_ptr + to_read, a.end());
 
  628            auto _a_load_cond = [a_begin, a_end](
auto &&ptr) {
 
  629                return ptr >= a_begin && ptr < a_end;
 
  633            a_data.load(a_ptr, _a_load_cond, 0);
 
  635            const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
 
  636            auto v_size = v.size();
 
  639            v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
 
  641            auto results_num = get_results_num(WorkPI, out.size(), glid, item);
 
  643            process_block(results, results_num, a_data, v_data, v_size, op,
 
  646            auto *
const out_ptr = out.begin();
 
  651                std::min(y_start + WorkPI * results.size_x(), out.size());
 
  653            for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
 
  654                out_ptr[y] = results[i++];
 
  664void validate(
const usm_ndarray &a,
 
  665              const usm_ndarray &v,
 
  666              const usm_ndarray &out,