31#include "utils/math_utils.hpp"
32#include <sycl/sycl.hpp>
37#include "ext/common.hpp"
39using dpctl::tensor::usm_ndarray;
41using ext::common::Align;
42using ext::common::CeilDiv;
44namespace statistics::sliding_window1d
47template <
typename T, u
int32_t Size>
51 using ncT =
typename std::remove_const_t<T>;
52 using SizeT =
decltype(Size);
53 static constexpr SizeT _size = Size;
56 : sbgroup(item.get_sub_group())
60 template <
typename yT>
61 T &operator[](
const yT &idx)
63 static_assert(std::is_integral_v<yT>,
64 "idx must be of an integral type");
68 template <
typename yT>
69 const T &operator[](
const yT &idx)
const
71 static_assert(std::is_integral_v<yT>,
72 "idx must be of an integral type");
78 static_assert(Size == 1,
79 "Size is not equal to 1. Use value(idx) instead");
83 const T &value()
const
85 static_assert(Size == 1,
86 "Size is not equal to 1. Use value(idx) instead");
90 template <
typename yT,
typename xT>
91 T broadcast(
const yT &y,
const xT &x)
const
93 static_assert(std::is_integral_v<std::remove_reference_t<yT>>,
94 "y must be of an integral type");
95 static_assert(std::is_integral_v<std::remove_reference_t<xT>>,
96 "x must be of an integral type");
98 return sycl::select_from_group(sbgroup, data[y], x);
101 template <
typename iT>
102 T broadcast(
const iT &idx)
const
104 if constexpr (Size == 1) {
105 return broadcast(0, idx);
108 return broadcast(idx / size_x(), idx % size_x());
112 template <
typename yT,
typename xT>
113 T shift_left(
const yT &y,
const xT &x)
const
115 static_assert(std::is_integral_v<yT>,
"y must be of an integral type");
116 static_assert(std::is_integral_v<xT>,
"x must be of an integral type");
118 return sycl::shift_group_left(sbgroup, data[y], x);
121 template <
typename yT,
typename xT>
122 T shift_right(
const yT &y,
const xT &x)
const
124 static_assert(std::is_integral_v<yT>,
"y must be of an integral type");
125 static_assert(std::is_integral_v<xT>,
"x must be of an integral type");
127 return sycl::shift_group_right(sbgroup, data[y], x);
130 constexpr SizeT size_y()
const
137 return sbgroup.get_max_local_range()[0];
140 SizeT total_size()
const
142 return size_x() * size_y();
152 return sbgroup.get_local_linear_id();
156 const sycl::sub_group sbgroup;
160template <
typename T, u
int32_t Size = 1>
163 using SizeT =
typename _RegistryDataStorage<T, Size>::SizeT;
167 template <
typename LaneIdT,
169 typename = std::enable_if_t<
170 std::is_invocable_r_v<bool, Condition, SizeT>>>
171 void fill_lane(
const LaneIdT &lane_id,
const T &value, Condition &&mask)
173 static_assert(std::is_integral_v<LaneIdT>,
174 "lane_id must be of an integral type");
175 if (mask(this->x())) {
176 this->data[lane_id] = value;
180 template <
typename LaneIdT>
181 void fill_lane(
const LaneIdT &lane_id,
const T &value,
const bool &mask)
183 fill_lane(lane_id, value, [mask](
auto &&) {
return mask; });
186 template <
typename LaneIdT>
187 void fill_lane(
const LaneIdT &lane_id,
const T &value)
189 fill_lane(lane_id, value,
true);
192 template <
typename Condition,
193 typename = std::enable_if_t<
194 std::is_invocable_r_v<bool, Condition, SizeT, SizeT>>>
195 void fill(
const T &value, Condition &&mask)
197 for (SizeT i = 0; i < Size; ++i) {
198 fill_lane(i, value, mask(i, this->x()));
202 void fill(
const T &value)
204 fill(value, [](
auto &&,
auto &&) {
return true; });
207 template <
typename LaneIdT,
209 typename = std::enable_if_t<
210 std::is_invocable_r_v<bool, Condition, const T *const>>>
211 T *load_lane(
const LaneIdT &lane_id,
216 static_assert(std::is_integral_v<LaneIdT>,
217 "lane_id must be of an integral type");
218 this->data[lane_id] = mask(data) ? data[0] : default_v;
220 return data + this->size_x();
223 template <
typename LaneIdT>
224 T *load_lane(
const LaneIdT &laned_id,
230 laned_id, data, [mask](
auto &&) {
return mask; }, default_v);
233 template <
typename LaneIdT>
234 T *load_lane(
const LaneIdT &laned_id,
const T *
const data)
236 constexpr T default_v = 0;
237 return load_lane(laned_id, data,
true, default_v);
240 template <
typename yStrideT,
242 typename = std::enable_if_t<
243 std::is_invocable_r_v<bool, Condition, const T *const>>>
244 T *load(
const T *
const data,
245 const yStrideT &y_stride,
250 for (SizeT i = 0; i < Size; ++i) {
251 load_lane(i, it, mask, default_v);
258 template <
typename yStr
ideT>
259 T *load(
const T *
const data,
260 const yStrideT &y_stride,
265 data, y_stride, [mask](
auto &&) {
return mask; }, default_v);
268 template <
typename Condition,
269 typename = std::enable_if_t<
270 std::is_invocable_r_v<bool, Condition, const T *const>>>
271 T *load(
const T *
const data, Condition &&mask,
const T &default_v)
273 return load(data, this->size_x(), mask, default_v);
276 T *load(
const T *
const data,
const bool &mask,
const T &default_v)
279 data, [mask](
auto &&) {
return mask; }, default_v);
282 T *load(
const T *
const data)
284 constexpr T default_v = 0;
285 return load(data,
true, default_v);
288 template <
typename LaneIdT,
290 typename = std::enable_if_t<
291 std::is_invocable_r_v<bool, Condition, const T *const>>>
292 T *store_lane(
const LaneIdT &lane_id, T *
const data, Condition &&mask)
294 static_assert(std::is_integral_v<LaneIdT>,
295 "lane_id must be of an integral type");
298 data[0] = this->data[lane_id];
301 return data + this->size_x();
304 template <
typename LaneIdT>
305 T *store_lane(
const LaneIdT &lane_id, T *
const data,
const bool &mask)
307 return store_lane(lane_id, data, [mask](
auto &&) {
return mask; });
310 template <
typename LaneIdT>
311 T *store_lane(
const LaneIdT &lane_id, T *
const data)
313 return store_lane(lane_id, data,
true);
316 template <
typename yStrideT,
318 typename = std::enable_if_t<
319 std::is_invocable_r_v<bool, Condition, const T *const>>>
320 T *store(T *
const data,
const yStrideT &y_stride, Condition &&condition)
323 for (SizeT i = 0; i < Size; ++i) {
324 store_lane(i, it, condition);
331 template <
typename yStr
ideT>
332 T *store(T *
const data,
const yStrideT &y_stride,
const bool &mask)
334 return store(data, y_stride, [mask](
auto &&) {
return mask; });
337 template <
typename Condition,
338 typename = std::enable_if_t<
339 std::is_invocable_r_v<bool, Condition, const T *const>>>
340 T *store(T *
const data, Condition &&condition)
342 return store(data, this->size_x(), condition);
345 T *store(T *
const data,
const bool &mask)
347 return store(data, [mask](
auto &&) {
return mask; });
350 T *store(T *
const data)
352 return store(data,
true);
356template <
typename T, u
int32_t Size>
359 using SizeT =
typename RegistryData<T, Size>::SizeT;
363 template <
typename shT>
364 void advance_left(
const shT &shift,
const T &fill_value)
366 static_assert(std::is_integral_v<shT>,
367 "shift must be of an integral type");
369 uint32_t shift_r = this->size_x() - shift;
370 for (SizeT i = 0; i < Size; ++i) {
371 this->data[i] = this->shift_left(i, shift);
373 i < Size - 1 ? this->shift_right(i + 1, shift_r) : fill_value;
374 if (this->x() >= shift_r) {
375 this->data[i] = border;
380 void advance_left(
const T &fill_value)
382 advance_left(1, fill_value);
387 constexpr T fill_value = 0;
388 advance_left(fill_value);
392template <
typename T,
typename SizeT =
size_t>
396 using value_type = T;
397 using size_type = SizeT;
399 Span(T *
const data,
const SizeT size) : data_(data), size_(size) {}
408 return data() + size();
426template <
typename T,
typename SizeT =
size_t>
432template <
typename T,
typename SizeT =
size_t>
436 using value_type = T;
437 using size_type = SizeT;
439 PaddedSpan(T *
const data,
const SizeT size,
const SizeT pad)
444 T *padded_begin()
const
446 return this->begin() - pad();
458template <
typename T,
typename SizeT =
size_t>
460 make_padded_span(T *
const data,
const SizeT size,
const SizeT offset)
465template <
typename Results,
470void process_block(Results &results,
478 for (uint32_t i = 0; i < block_size; ++i) {
479 auto v_val = v_data.broadcast(i);
480 for (uint32_t r = 0; r < r_size; ++r) {
481 results[r] = red(results[r], op(a_data[r], v_val));
483 a_data.advance_left();
487template <
typename SizeT>
488SizeT get_global_linear_id(
const uint32_t wpi,
const sycl::nd_item<1> &item)
490 auto sbgroup = item.get_sub_group();
491 const auto sg_loc_id = sbgroup.get_local_linear_id();
493 const SizeT sg_base_id = wpi * (item.get_global_linear_id() - sg_loc_id);
494 const SizeT
id = sg_base_id + sg_loc_id;
499template <
typename SizeT>
500uint32_t get_results_num(
const uint32_t wpi,
502 const SizeT global_id,
503 const sycl::nd_item<1> &item)
505 auto sbgroup = item.get_sub_group();
507 const auto sbg_size = sbgroup.get_max_local_range()[0];
508 const auto size_ = sycl::sub_sat(size, global_id);
509 return std::min(SizeT(wpi), CeilDiv(size_, sbg_size));
512template <uint32_t WorkPI,
519template <uint32_t WorkPI,
529 sycl::nd_range<1> nd_range,
533 nd_range, [=](sycl::nd_item<1> item) {
534 auto glid = get_global_linear_id<SizeT>(WorkPI, item);
539 auto results_num = get_results_num(WorkPI, out.size(), glid, item);
541 const auto *a_begin = a.begin();
542 const auto *a_end = a.end();
544 auto sbgroup = item.get_sub_group();
546 const auto chunks_count =
547 CeilDiv(v.size(), sbgroup.get_max_local_range()[0]);
549 const auto *a_ptr = &a.padded_begin()[glid];
551 auto _a_load_cond = [a_begin, a_end](
auto &&ptr) {
552 return ptr >= a_begin && ptr < a_end;
556 a_ptr = a_data.load(a_ptr, _a_load_cond, 0);
558 const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
559 auto v_size = v.size();
561 for (uint32_t b = 0; b < chunks_count; ++b) {
563 v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
565 uint32_t chunk_size_ =
566 std::min(v_size, SizeT(v_data.total_size()));
567 process_block(results, results_num, a_data, v_data, chunk_size_,
570 if (b != chunks_count - 1) {
571 a_ptr = a_data.load_lane(a_data.size_y() - 1, a_ptr,
573 v_size -= v_data.total_size();
577 auto *
const out_ptr = out.begin();
582 std::min(y_start + WorkPI * results.size_x(), out.size());
584 for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
585 out_ptr[y] = results[i++];
595template <uint32_t WorkPI,
602template <uint32_t WorkPI,
612 sycl::nd_range<1> nd_range,
616 nd_range, [=](sycl::nd_item<1> item) {
617 auto glid = get_global_linear_id<SizeT>(WorkPI, item);
622 auto sbgroup = item.get_sub_group();
623 auto sg_size = sbgroup.get_max_local_range()[0];
625 const uint32_t to_read = WorkPI * sg_size + v.size();
626 const auto *a_begin = a.begin();
628 const auto *a_ptr = &a.padded_begin()[glid];
629 const auto *a_end = std::min(a_ptr + to_read, a.end());
631 auto _a_load_cond = [a_begin, a_end](
auto &&ptr) {
632 return ptr >= a_begin && ptr < a_end;
636 a_data.load(a_ptr, _a_load_cond, 0);
638 const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
639 auto v_size = v.size();
642 v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
644 auto results_num = get_results_num(WorkPI, out.size(), glid, item);
646 process_block(results, results_num, a_data, v_data, v_size, op,
649 auto *
const out_ptr = out.begin();
654 std::min(y_start + WorkPI * results.size_x(), out.size());
656 for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
657 out_ptr[y] = results[i++];
667void validate(
const usm_ndarray &a,
668 const usm_ndarray &v,
669 const usm_ndarray &out,