#include "utils/math_utils.hpp"
#include <sycl/sycl.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// provides dpctl::tensor::usm_ndarray
#include <dpctl4pybind11.hpp>

#include "ext/common.hpp"

using dpctl::tensor::usm_ndarray;

using ext::common::Align;
using ext::common::CeilDiv;

namespace statistics::sliding_window1d
{
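// _RegistryDataStorage keeps a small per-work-item register file of Size
// elements and exposes sub-group collectives (broadcast/shift) over it, so a
// sub-group can be treated as a 2D registry of size_y() x size_x() values:
// Size lanes per work-item times the sub-group width.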
template <typename T, uint32_t Size>
class _RegistryDataStorage
{
public:
    using ncT = typename std::remove_const_t<T>;
    using SizeT = decltype(Size);
    static constexpr SizeT _size = Size;

    _RegistryDataStorage(const sycl::nd_item<1> &item)
        : sbgroup(item.get_sub_group())
    {
    }

    template <typename yT>
    T &operator[](const yT &idx)
    {
        static_assert(std::is_integral_v<yT>,
                      "idx must be of an integral type");
        return data[idx];
    }

    template <typename yT>
    const T &operator[](const yT &idx) const
    {
        static_assert(std::is_integral_v<yT>,
                      "idx must be of an integral type");
        return data[idx];
    }

    T &value()
    {
        static_assert(Size == 1,
                      "Size is not equal to 1. Use value(idx) instead");
        return data[0];
    }

    const T &value() const
    {
        static_assert(Size == 1,
                      "Size is not equal to 1. Use value(idx) instead");
        return data[0];
    }
    template <typename yT, typename xT>
    T broadcast(const yT &y, const xT &x) const
    {
        static_assert(std::is_integral_v<std::remove_reference_t<yT>>,
                      "y must be of an integral type");
        static_assert(std::is_integral_v<std::remove_reference_t<xT>>,
                      "x must be of an integral type");

        return sycl::select_from_group(sbgroup, data[y], x);
    }

    template <typename iT>
    T broadcast(const iT &idx) const
    {
        if constexpr (Size == 1) {
            return broadcast(0, idx);
        }
        else {
            return broadcast(idx / size_x(), idx % size_x());
        }
    }
    template <typename yT, typename xT>
    T shift_left(const yT &y, const xT &x) const
    {
        static_assert(std::is_integral_v<yT>, "y must be of an integral type");
        static_assert(std::is_integral_v<xT>, "x must be of an integral type");

        return sycl::shift_group_left(sbgroup, data[y], x);
    }

    template <typename yT, typename xT>
    T shift_right(const yT &y, const xT &x) const
    {
        static_assert(std::is_integral_v<yT>, "y must be of an integral type");
        static_assert(std::is_integral_v<xT>, "x must be of an integral type");

        return sycl::shift_group_right(sbgroup, data[y], x);
    }
    constexpr SizeT size_y() const
    {
        return _size;
    }

    SizeT size_x() const
    {
        return sbgroup.get_max_local_range()[0];
    }

    SizeT total_size() const
    {
        return size_x() * size_y();
    }

    SizeT x() const
    {
        return sbgroup.get_local_linear_id();
    }

protected:
    const sycl::sub_group sbgroup;
    ncT data[Size];
};
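// RegistryData adds masked fill/load/store operations on top of
// _RegistryDataStorage. Consecutive lanes of one work-item correspond to
// elements that are one sub-group width (size_x()) apart in memory, so loads
// and stores walk their pointer in sub-group-sized strides.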
template <typename T, uint32_t Size = 1>
class RegistryData : public _RegistryDataStorage<T, Size>
{
public:
    using SizeT = typename _RegistryDataStorage<T, Size>::SizeT;

    using _RegistryDataStorage<T, Size>::_RegistryDataStorage;

    template <typename LaneIdT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, SizeT>>>
    void fill_lane(const LaneIdT &lane_id, const T &value, Condition &&mask)
    {
        static_assert(std::is_integral_v<LaneIdT>,
                      "lane_id must be of an integral type");
        if (mask(this->x())) {
            this->data[lane_id] = value;
        }
    }
    template <typename LaneIdT>
    void fill_lane(const LaneIdT &lane_id, const T &value, const bool &mask)
    {
        fill_lane(lane_id, value, [mask](auto &&) { return mask; });
    }

    template <typename LaneIdT>
    void fill_lane(const LaneIdT &lane_id, const T &value)
    {
        fill_lane(lane_id, value, true);
    }

    template <typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, SizeT, SizeT>>>
    void fill(const T &value, Condition &&mask)
    {
        for (SizeT i = 0; i < Size; ++i) {
            fill_lane(i, value, mask(i, this->x()));
        }
    }

    void fill(const T &value)
    {
        fill(value, [](auto &&, auto &&) { return true; });
    }
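    // The load_lane/load family reads from global memory into the registry.
    // A lane is read only if the mask predicate (called with the source
    // pointer) holds; otherwise default_v is stored. load_lane returns the
    // pointer advanced by one sub-group width, load advances it by y_stride
    // per lane.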
    template <typename LaneIdT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *load_lane(const LaneIdT &lane_id,
                 const T *const data,
                 Condition &&mask,
                 const T &default_v)
    {
        static_assert(std::is_integral_v<LaneIdT>,
                      "lane_id must be of an integral type");
        this->data[lane_id] = mask(data) ? data[0] : default_v;

        return data + this->size_x();
    }

    template <typename LaneIdT>
    T *load_lane(const LaneIdT &lane_id,
                 const T *const data,
                 const bool &mask,
                 const T &default_v)
    {
        return load_lane(
            lane_id, data, [mask](auto &&) { return mask; }, default_v);
    }

    template <typename LaneIdT>
    T *load_lane(const LaneIdT &lane_id, const T *const data)
    {
        constexpr T default_v = 0;
        return load_lane(lane_id, data, true, default_v);
    }
    template <typename yStrideT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *load(const T *const data,
            const yStrideT &y_stride,
            Condition &&mask,
            const T &default_v)
    {
        auto *it = data;
        for (SizeT i = 0; i < Size; ++i) {
            load_lane(i, it, mask, default_v);
            it += y_stride;
        }

        return it;
    }

    template <typename yStrideT>
    T *load(const T *const data,
            const yStrideT &y_stride,
            const bool &mask,
            const T &default_v)
    {
        return load(
            data, y_stride, [mask](auto &&) { return mask; }, default_v);
    }

    template <typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *load(const T *const data, Condition &&mask, const T &default_v)
    {
        return load(data, this->size_x(), mask, default_v);
    }

    T *load(const T *const data, const bool &mask, const T &default_v)
    {
        return load(
            data, [mask](auto &&) { return mask; }, default_v);
    }

    T *load(const T *const data)
    {
        constexpr T default_v = 0;
        return load(data, true, default_v);
    }
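    // The store_lane/store family mirrors load: a lane is written out only if
    // the mask predicate holds for the destination pointer, and the returned
    // pointer is advanced by one sub-group width (or by y_stride per lane).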
    template <typename LaneIdT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *store_lane(const LaneIdT &lane_id, T *const data, Condition &&mask)
    {
        static_assert(std::is_integral_v<LaneIdT>,
                      "lane_id must be of an integral type");

        if (mask(data)) {
            data[0] = this->data[lane_id];
        }

        return data + this->size_x();
    }

    template <typename LaneIdT>
    T *store_lane(const LaneIdT &lane_id, T *const data, const bool &mask)
    {
        return store_lane(lane_id, data, [mask](auto &&) { return mask; });
    }

    template <typename LaneIdT>
    T *store_lane(const LaneIdT &lane_id, T *const data)
    {
        return store_lane(lane_id, data, true);
    }
    template <typename yStrideT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *store(T *const data, const yStrideT &y_stride, Condition &&condition)
    {
        auto *it = data;
        for (SizeT i = 0; i < Size; ++i) {
            store_lane(i, it, condition);
            it += y_stride;
        }

        return it;
    }

    template <typename yStrideT>
    T *store(T *const data, const yStrideT &y_stride, const bool &mask)
    {
        return store(data, y_stride, [mask](auto &&) { return mask; });
    }

    template <typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *store(T *const data, Condition &&condition)
    {
        return store(data, this->size_x(), condition);
    }

    T *store(T *const data, const bool &mask)
    {
        return store(data, [mask](auto &&) { return mask; });
    }

    T *store(T *const data)
    {
        return store(data, true);
    }
};
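// RegistryWindow extends RegistryData with advance_left(), which shifts the
// flattened registry left by `shift` elements across the whole sub-group:
// values that cross a row boundary are pulled from the next row, and the
// vacated tail of the last row is filled with fill_value.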
template <typename T, uint32_t Size>
class RegistryWindow : public RegistryData<T, Size>
{
public:
    using SizeT = typename RegistryData<T, Size>::SizeT;

    using RegistryData<T, Size>::RegistryData;

    template <typename shT>
    void advance_left(const shT &shift, const T &fill_value)
    {
        static_assert(std::is_integral_v<shT>,
                      "shift must be of an integral type");

        uint32_t shift_r = this->size_x() - shift;
        for (SizeT i = 0; i < Size; ++i) {
            this->data[i] = this->shift_left(i, shift);
            T border =
                i < Size - 1 ? this->shift_right(i + 1, shift_r) : fill_value;
            if (this->x() >= shift_r) {
                this->data[i] = border;
            }
        }
    }

    void advance_left(const T &fill_value)
    {
        advance_left(1, fill_value);
    }

    void advance_left()
    {
        constexpr T fill_value = 0;
        advance_left(fill_value);
    }
};
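// Span is a minimal non-owning view over a contiguous device buffer: a raw
// pointer plus an element count, with begin()/end() helpers usable inside
// kernels.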
template <typename T, typename SizeT = size_t>
class Span
{
public:
    using value_type = T;
    using size_type = SizeT;

    Span(T *const data, const SizeT size) : data_(data), size_(size) {}

    T *begin() const
    {
        return data();
    }

    T *end() const
    {
        return data() + size();
    }

    T *data() const
    {
        return data_;
    }

    SizeT size() const
    {
        return size_;
    }

protected:
    T *data_;
    SizeT size_;
};
template <typename T, typename SizeT = size_t>
Span<T, SizeT> make_span(T *const data, const SizeT size)
{
    return Span<T, SizeT>(data, size);
}
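// PaddedSpan is a Span whose first `pad` elements before begin() are also
// addressable: padded_begin() == begin() - pad(). The kernels index from
// padded_begin() and mask reads that fall outside [begin(), end()).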
template <typename T, typename SizeT = size_t>
class PaddedSpan : public Span<T, SizeT>
{
public:
    using value_type = T;
    using size_type = SizeT;

    PaddedSpan(T *const data, const SizeT size, const SizeT pad)
        : Span<T, SizeT>(data, size), pad_(pad)
    {
    }

    T *padded_begin() const
    {
        return this->begin() - pad();
    }

    SizeT pad() const
    {
        return pad_;
    }

protected:
    SizeT pad_;
};
template <typename T, typename SizeT = size_t>
PaddedSpan<T, SizeT>
    make_padded_span(T *const data, const SizeT size, const SizeT offset)
{
    return PaddedSpan<T, SizeT>(data + offset, size, offset);
}
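// process_block performs the inner sliding-window step for one chunk of v:
// for every element of the chunk it broadcasts v[i] across the sub-group,
// combines it with the corresponding window lanes of a_data via op, folds the
// result into results[r] via red, and then advances the `a` window left by
// one element.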
template <typename Results,
          typename AData,
          typename VData,
          typename Op,
          typename Red>
void process_block(Results &results,
                   uint32_t r_size,
                   AData &a_data,
                   VData &v_data,
                   uint32_t block_size,
                   Op op,
                   Red red)
{
    for (uint32_t i = 0; i < block_size; ++i) {
        auto v_val = v_data.broadcast(i);
        for (uint32_t r = 0; r < r_size; ++r) {
            results[r] = red(results[r], op(a_data[r], v_val));
        }
        a_data.advance_left();
    }
}
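// Maps an nd_item to the index of the first output element owned by its
// work-item when every work-item produces wpi results interleaved across the
// sub-group: the sub-group's base index is scaled by wpi and the lane id is
// added back, so consecutive lanes write consecutive elements.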
template <typename SizeT>
SizeT get_global_linear_id(const uint32_t wpi, const sycl::nd_item<1> &item)
{
    auto sbgroup = item.get_sub_group();
    const auto sg_loc_id = sbgroup.get_local_linear_id();

    const SizeT sg_base_id = wpi * (item.get_global_linear_id() - sg_loc_id);
    const SizeT id = sg_base_id + sg_loc_id;

    return id;
}
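// Number of results this work-item actually has to produce: at most wpi, and
// fewer for the trailing sub-group when `size` is not a multiple of the
// sub-group's footprint (sub_sat keeps the remaining count from underflowing).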
template <typename SizeT>
uint32_t get_results_num(const uint32_t wpi,
                         const SizeT size,
                         const SizeT global_id,
                         const sycl::nd_item<1> &item)
{
    auto sbgroup = item.get_sub_group();

    const auto sbg_size = sbgroup.get_max_local_range()[0];
    const auto size_ = sycl::sub_sat(size, global_id);
    return std::min(SizeT(wpi), CeilDiv(size_, sbg_size));
}
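// Main sliding-window kernel: each work-item accumulates WorkPI results while
// the sub-group streams v through the registry in sub-group-sized chunks,
// reloading one extra row of the `a` window between chunks. Used when the
// window does not fit into a single registry load.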
template <uint32_t WorkPI,
          typename T,
          typename SizeT,
          typename Op,
          typename Red>
class sliding_window1d_kernel;

template <uint32_t WorkPI,
          typename T,
          typename SizeT,
          typename Op,
          typename Red>
void submit_sliding_window1d(const PaddedSpan<const T, SizeT> &a,
                             const Span<const T, SizeT> &v,
                             const Op &op,
                             const Red &red,
                             Span<T, SizeT> &out,
                             sycl::nd_range<1> nd_range,
                             sycl::handler &cgh)
{
    cgh.parallel_for<sliding_window1d_kernel<WorkPI, T, SizeT, Op, Red>>(
        nd_range, [=](sycl::nd_item<1> item) {
            auto glid = get_global_linear_id<SizeT>(WorkPI, item);

            auto results = RegistryData<T, WorkPI>(item);
            results.fill(0);

            auto results_num = get_results_num(WorkPI, out.size(), glid, item);

            const auto *a_begin = a.begin();
            const auto *a_end = a.end();

            auto sbgroup = item.get_sub_group();

            const auto chunks_count =
                CeilDiv(v.size(), sbgroup.get_max_local_range()[0]);

            const auto *a_ptr = &a.padded_begin()[glid];

            auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
                return ptr >= a_begin && ptr < a_end;
            };

            auto a_data = RegistryWindow<const T, WorkPI + 1>(item);
            a_ptr = a_data.load(a_ptr, _a_load_cond, 0);

            const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
            auto v_size = v.size();

            auto v_data = RegistryData<const T>(item);

            for (uint32_t b = 0; b < chunks_count; ++b) {
                v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);

                uint32_t chunk_size_ =
                    std::min(v_size, SizeT(v_data.total_size()));
                process_block(results, results_num, a_data, v_data,
                              chunk_size_, op, red);

                if (b != chunks_count - 1) {
                    a_ptr = a_data.load_lane(a_data.size_y() - 1, a_ptr,
                                             _a_load_cond, 0);
                    v_size -= v_data.total_size();
                }
            }

            auto *const out_ptr = out.begin();

            uint32_t i = 0;
            auto y_start = glid;
            auto y_stop =
                std::min(y_start + WorkPI * results.size_x(), out.size());
            for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
                out_ptr[y] = results[i++];
            }
        });
}
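// Variant for small windows: v is loaded into the registry once, the whole
// `a` window needed for all WorkPI results is read up front, and a single
// process_block call produces the outputs.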
template <uint32_t WorkPI,
          typename T,
          typename SizeT,
          typename Op,
          typename Red>
class sliding_window1d_small_kernel;

template <uint32_t WorkPI,
          typename T,
          typename SizeT,
          typename Op,
          typename Red>
void submit_sliding_window1d_small_kernel(const PaddedSpan<const T, SizeT> &a,
                                          const Span<const T, SizeT> &v,
                                          const Op &op,
                                          const Red &red,
                                          Span<T, SizeT> &out,
                                          sycl::nd_range<1> nd_range,
                                          sycl::handler &cgh)
{
    cgh.parallel_for<
        sliding_window1d_small_kernel<WorkPI, T, SizeT, Op, Red>>(
        nd_range, [=](sycl::nd_item<1> item) {
            auto glid = get_global_linear_id<SizeT>(WorkPI, item);

            auto results = RegistryData<T, WorkPI>(item);
            results.fill(0);

            auto sbgroup = item.get_sub_group();
            auto sg_size = sbgroup.get_max_local_range()[0];

            const uint32_t to_read = WorkPI * sg_size + v.size();
            const auto *a_begin = a.begin();

            const auto *a_ptr = &a.padded_begin()[glid];
            const auto *a_end = std::min(a_ptr + to_read, a.end());

            auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
                return ptr >= a_begin && ptr < a_end;
            };

            auto a_data = RegistryWindow<const T, WorkPI + 1>(item);
            a_data.load(a_ptr, _a_load_cond, 0);

            const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
            auto v_size = v.size();

            auto v_data = RegistryData<const T>(item);
            v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);

            auto results_num = get_results_num(WorkPI, out.size(), glid, item);

            process_block(results, results_num, a_data, v_data, v_size, op,
                          red);

            auto *const out_ptr = out.begin();

            uint32_t i = 0;
            auto y_start = glid;
            auto y_stop =
                std::min(y_start + WorkPI * results.size_x(), out.size());
            for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
                out_ptr[y] = results[i++];
            }
        });
}
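// Host-side validation of the input, window, and output arrays used by the
// sliding-window entry points.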
void validate(const usm_ndarray &a,
              const usm_ndarray &v,
              const usm_ndarray &out,