33#include "utils/math_utils.hpp"
34#include <sycl/sycl.hpp>
39#include "ext/common.hpp"
41using dpctl::tensor::usm_ndarray;
43using ext::common::Align;
44using ext::common::CeilDiv;
46namespace statistics::sliding_window1d
49template <
typename T, u
int32_t Size>
53 using ncT =
typename std::remove_const_t<T>;
54 using SizeT =
decltype(Size);
55 static constexpr SizeT _size = Size;
58 : sbgroup(item.get_sub_group())
62 template <
typename yT>
63 T &operator[](
const yT &idx)
65 static_assert(std::is_integral_v<yT>,
66 "idx must be of an integral type");
70 template <
typename yT>
71 const T &operator[](
const yT &idx)
const
73 static_assert(std::is_integral_v<yT>,
74 "idx must be of an integral type");
80 static_assert(Size == 1,
81 "Size is not equal to 1. Use value(idx) instead");
85 const T &value()
const
87 static_assert(Size == 1,
88 "Size is not equal to 1. Use value(idx) instead");
92 template <
typename yT,
typename xT>
93 T broadcast(
const yT &y,
const xT &x)
const
95 static_assert(std::is_integral_v<std::remove_reference_t<yT>>,
96 "y must be of an integral type");
97 static_assert(std::is_integral_v<std::remove_reference_t<xT>>,
98 "x must be of an integral type");
100 return sycl::select_from_group(sbgroup, data[y], x);
103 template <
typename iT>
104 T broadcast(
const iT &idx)
const
106 if constexpr (Size == 1) {
107 return broadcast(0, idx);
110 return broadcast(idx / size_x(), idx % size_x());
114 template <
typename yT,
typename xT>
115 T shift_left(
const yT &y,
const xT &x)
const
117 static_assert(std::is_integral_v<yT>,
"y must be of an integral type");
118 static_assert(std::is_integral_v<xT>,
"x must be of an integral type");
120 return sycl::shift_group_left(sbgroup, data[y], x);
123 template <
typename yT,
typename xT>
124 T shift_right(
const yT &y,
const xT &x)
const
126 static_assert(std::is_integral_v<yT>,
"y must be of an integral type");
127 static_assert(std::is_integral_v<xT>,
"x must be of an integral type");
129 return sycl::shift_group_right(sbgroup, data[y], x);
132 constexpr SizeT size_y()
const {
return _size; }
134 SizeT size_x()
const {
return sbgroup.get_max_local_range()[0]; }
136 SizeT total_size()
const {
return size_x() * size_y(); }
138 ncT *ptr() {
return data; }
140 SizeT x()
const {
return sbgroup.get_local_linear_id(); }
143 const sycl::sub_group sbgroup;
// Fragment of a registry-data type layered on _RegistryDataStorage (its
// SizeT alias is taken from that base below). Provides overload families
// fill/fill_lane, load/load_lane and store/store_lane; in each family a
// predicate-mask form is the primary implementation and bool / no-mask
// forms forward to it. Several declaration lines (class header, some
// parameters, braces) are missing from this extraction — do not assume
// the visible lines are contiguous in the original file.
147template <
typename T, u
int32_t Size = 1>
 150 using SizeT =
typename _RegistryDataStorage<T, Size>::SizeT;
// fill_lane(lane_id, value, mask): write `value` into row lane_id of this
// work item's registers only when mask(lane x-id) is true.
 154 template <
typename LaneIdT,
 156 typename = std::enable_if_t<
 157 std::is_invocable_r_v<bool, Condition, SizeT>>>
 158 void fill_lane(
const LaneIdT &lane_id,
const T &value, Condition &&mask)
 160 static_assert(std::is_integral_v<LaneIdT>,
 161 "lane_id must be of an integral type");
 162 if (mask(this->x())) {
 163 this->data[lane_id] = value;
// bool-mask convenience overload: wraps the flag in a constant predicate.
 167 template <
typename LaneIdT>
 168 void fill_lane(
const LaneIdT &lane_id,
const T &value,
const bool &mask)
 170 fill_lane(lane_id, value, [mask](
auto &&) {
return mask; });
// Unconditional overload.
 173 template <
typename LaneIdT>
 174 void fill_lane(
const LaneIdT &lane_id,
const T &value)
 176 fill_lane(lane_id, value,
true);
// fill(value, mask): fill every row; mask receives (row, lane).
 179 template <
typename Condition,
 180 typename = std::enable_if_t<
 181 std::is_invocable_r_v<bool, Condition, SizeT, SizeT>>>
 182 void fill(
const T &value, Condition &&mask)
 184 for (SizeT i = 0; i < Size; ++i) {
 185 fill_lane(i, value, mask(i, this->x()));
// Unconditional fill.
 189 void fill(
const T &value)
 191 fill(value, [](
auto &&,
auto &&) {
return true; });
// load_lane: read one element per work item into row lane_id when
// mask(ptr) holds, else store default_v; returns the pointer advanced by
// one sub-group width (next row's base for this lane).
 194 template <
typename LaneIdT,
 196 typename = std::enable_if_t<
 197 std::is_invocable_r_v<bool, Condition, const T *const>>>
 198 T *load_lane(
const LaneIdT &lane_id,
 203 static_assert(std::is_integral_v<LaneIdT>,
 204 "lane_id must be of an integral type");
 205 this->data[lane_id] = mask(data) ? data[0] : default_v;
 207 return data + this->size_x();
// bool-mask overload (note: parameter is spelled "laned_id" in the
// original — kept as-is; presumably a typo for "lane_id").
 210 template <
typename LaneIdT>
 211 T *load_lane(
const LaneIdT &laned_id,
 217 laned_id, data, [mask](
auto &&) {
return mask; }, default_v);
// Unmasked overload with default fill value 0.
 220 template <
typename LaneIdT>
 221 T *load_lane(
const LaneIdT &laned_id,
const T *
const data)
 223 constexpr T default_v = 0;
 224 return load_lane(laned_id, data,
true, default_v);
// load(data, y_stride, mask, default_v): load all Size rows; the loop
// body advances an iterator `it` (declaration not visible in this view).
 227 template <
typename yStrideT,
 229 typename = std::enable_if_t<
 230 std::is_invocable_r_v<bool, Condition, const T *const>>>
 231 T *load(
const T *
const data,
 232 const yStrideT &y_stride,
 237 for (SizeT i = 0; i < Size; ++i) {
 238 load_lane(i, it, mask, default_v);
// bool-mask overload of strided load.
 245 template <
typename yStr
ideT>
 246 T *load(
const T *
const data,
 247 const yStrideT &y_stride,
 252 data, y_stride, [mask](
auto &&) {
return mask; }, default_v);
// load with implicit row stride = sub-group width.
 255 template <
typename Condition,
 256 typename = std::enable_if_t<
 257 std::is_invocable_r_v<bool, Condition, const T *const>>>
 258 T *load(
const T *
const data, Condition &&mask,
const T &default_v)
 260 return load(data, this->size_x(), mask, default_v);
 263 T *load(
const T *
const data,
const bool &mask,
const T &default_v)
 265 return load(data, [mask](
auto &&) {
return mask; }, default_v);
// Fully defaulted load: unconditional, default value 0.
 268 T *load(
const T *
const data)
 270 constexpr T default_v = 0;
 271 return load(data,
true, default_v);
// store_lane: mirror of load_lane — writes row lane_id out (the guarding
// mask check around the assignment is not visible in this extraction) and
// returns the pointer advanced by one sub-group width.
 274 template <
typename LaneIdT,
 276 typename = std::enable_if_t<
 277 std::is_invocable_r_v<bool, Condition, const T *const>>>
 278 T *store_lane(
const LaneIdT &lane_id, T *
const data, Condition &&mask)
 280 static_assert(std::is_integral_v<LaneIdT>,
 281 "lane_id must be of an integral type");
 284 data[0] = this->data[lane_id];
 287 return data + this->size_x();
 290 template <
typename LaneIdT>
 291 T *store_lane(
const LaneIdT &lane_id, T *
const data,
const bool &mask)
 293 return store_lane(lane_id, data, [mask](
auto &&) {
return mask; });
 296 template <
typename LaneIdT>
 297 T *store_lane(
const LaneIdT &lane_id, T *
const data)
 299 return store_lane(lane_id, data,
true);
// store(data, y_stride, condition): write all Size rows via store_lane
// (iterator `it` declared on a line not visible here).
 302 template <
typename yStrideT,
 304 typename = std::enable_if_t<
 305 std::is_invocable_r_v<bool, Condition, const T *const>>>
 306 T *store(T *
const data,
const yStrideT &y_stride, Condition &&condition)
 309 for (SizeT i = 0; i < Size; ++i) {
 310 store_lane(i, it, condition);
 317 template <
typename yStr
ideT>
 318 T *store(T *
const data,
const yStrideT &y_stride,
const bool &mask)
 320 return store(data, y_stride, [mask](
auto &&) {
return mask; });
 323 template <
typename Condition,
 324 typename = std::enable_if_t<
 325 std::is_invocable_r_v<bool, Condition, const T *const>>>
 326 T *store(T *
const data, Condition &&condition)
 328 return store(data, this->size_x(), condition);
 331 T *store(T *
const data,
const bool &mask)
 333 return store(data, [mask](
auto &&) {
return mask; });
// Fully defaulted store: unconditional.
 336 T *store(T *
const data) {
return store(data,
true); }
// Fragment of a sliding-window registry type (SizeT taken from
// RegistryData below). advance_left shifts the whole tile left by `shift`
// elements across the sub-group: each row is shuffled left, and the lanes
// that would receive garbage from the shuffle edge (x() >= size_x()-shift)
// instead take a border value pulled from the next row via shift_right —
// or fill_value for the last row. Some lines (class header, the `border`
// declaration start, closing braces) are missing from this extraction.
339template <
typename T, u
int32_t Size>
 342 using SizeT =
typename RegistryData<T, Size>::SizeT;
 346 template <
typename shT>
 347 void advance_left(
const shT &shift,
const T &fill_value)
 349 static_assert(std::is_integral_v<shT>,
 350 "shift must be of an integral type");
// shift_r: complementary right-shift distance within the sub-group width.
 352 uint32_t shift_r = this->size_x() - shift;
 353 for (SizeT i = 0; i < Size; ++i) {
 354 this->data[i] = this->shift_left(i, shift);
// border value: from row i+1 for interior rows, fill_value for the last.
 356 i < Size - 1 ? this->shift_right(i + 1, shift_r) : fill_value;
 357 if (this->x() >= shift_r) {
 358 this->data[i] = border;
// Convenience overload: advance by one element.
 363 void advance_left(
const T &fill_value) { advance_left(1, fill_value); }
// No-argument form fragment: advances with a zero fill value.
 367 constexpr T fill_value = 0;
 368 advance_left(fill_value);
// Fragment of a minimal non-owning span: pointer + element count. The
// class declaration line and data members' declarations are missing from
// this extraction; only the constructor and accessors are visible.
372template <
typename T,
typename SizeT =
size_t>
 376 using value_type = T;
 377 using size_type = SizeT;
// Borrowing constructor — does not take ownership of `data`.
 379 Span(T *
const data,
const SizeT size) : data_(data), size_(size) {}
// Iteration support: half-open range [begin, end).
 381 T *begin()
const {
return data(); }
 383 T *end()
const {
return data() + size(); }
 385 SizeT size()
const {
return size_; }
 387 T *data()
const {
return data_; }
// Fragments: a make_span-style helper declaration (394...), then a
// PaddedSpan — a Span whose true allocation begins `pad` elements before
// begin(), so kernels can safely read a left halo region. Class headers,
// member declarations and the helper bodies are missing from this view.
394template <
typename T,
typename SizeT =
size_t>
400template <
typename T,
typename SizeT =
size_t>
 404 using value_type = T;
 405 using size_type = SizeT;
 407 PaddedSpan(T *
const data,
const SizeT size,
const SizeT pad)
// padded_begin(): start of the padded (halo-inclusive) region.
 412 T *padded_begin()
const {
return this->begin() - pad(); }
 414 SizeT pad()
const {
return pad_; }
// Factory: presumably wraps (data, size) with left padding `offset` —
// body not visible here; TODO confirm against the full file.
420template <
typename T,
typename SizeT =
size_t>
 422 make_padded_span(T *
const data,
const SizeT size,
const SizeT offset)
// Core sliding-window inner loop fragment: for each of block_size window
// elements, broadcast the shared v value to all lanes, combine it with
// each per-lane accumulator via op and fold with red, then slide the
// a-window one element left. Several parameter lines are missing from
// this extraction (op, red, a_data, v_data, r_size are used but not all
// declared in the visible lines).
427template <
typename Results,
432void process_block(Results &results,
 440 for (uint32_t i = 0; i < block_size; ++i) {
 441 auto v_val = v_data.broadcast(i);
 442 for (uint32_t r = 0; r < r_size; ++r) {
 443 results[r] = red(results[r], op(a_data[r], v_val));
 445 a_data.advance_left();
// Computes this work item's global output id under a work-per-item (wpi)
// blocking scheme: the sub-group's base id is scaled by wpi, then the
// lane id is added back, so consecutive lanes cover consecutive elements
// within a wpi-strided sub-group block. The final return statement is not
// visible in this extraction.
449template <
typename SizeT>
450SizeT get_global_linear_id(
const uint32_t wpi,
const sycl::nd_item<1> &item)
 452 auto sbgroup = item.get_sub_group();
 453 const auto sg_loc_id = sbgroup.get_local_linear_id();
 455 const SizeT sg_base_id = wpi * (item.get_global_linear_id() - sg_loc_id);
 456 const SizeT
id = sg_base_id + sg_loc_id;
// Number of valid results this work item produces: min(wpi, remaining
// elements / sub-group width). sycl::sub_sat clamps size - global_id at
// zero for tail work items. NOTE(review): a `size` parameter is used at
// the sub_sat call but its declaration line (original line ~463) is
// missing from this extraction — confirm against the full file.
461template <
typename SizeT>
462uint32_t get_results_num(
const uint32_t wpi,
 464 const SizeT global_id,
 465 const sycl::nd_item<1> &item)
 467 auto sbgroup = item.get_sub_group();
 469 const auto sbg_size = sbgroup.get_max_local_range()[0];
 470 const auto size_ = sycl::sub_sat(size, global_id);
 471 return std::min(SizeT(wpi), CeilDiv(size_, sbg_size));
// Fragment of a kernel-submission routine (WorkPI work per item): the
// nd_range lambda loads a padded window of `a` into registers, then walks
// `v` in sub-group-width chunks — loading each chunk, combining it into
// per-item accumulators via process_block, topping up the trailing lane
// of the a-window between chunks — and finally writes the results out
// strided by the sub-group width. Many lines (kernel declarations of
// a_data/v_data/results, the sycl::handler wrapper, bounds of the output
// loop) are missing from this extraction.
474template <uint32_t WorkPI,
481template <uint32_t WorkPI,
 491 sycl::nd_range<1> nd_range,
 495 nd_range, [=](sycl::nd_item<1> item) {
 496 auto glid = get_global_linear_id<SizeT>(WorkPI, item);
 501 auto results_num = get_results_num(WorkPI, out.size(), glid, item);
 503 const auto *a_begin = a.begin();
 504 const auto *a_end = a.end();
 506 auto sbgroup = item.get_sub_group();
// Number of sub-group-width chunks needed to cover v.
 508 const auto chunks_count =
 509 CeilDiv(v.size(), sbgroup.get_max_local_range()[0]);
// Start reading from the padded (halo) region so out-of-window lanes are
// masked off by the bounds predicate below instead of reading OOB.
 511 const auto *a_ptr = &a.padded_begin()[glid];
 513 auto _a_load_cond = [a_begin, a_end](
auto &&ptr) {
 514 return ptr >= a_begin && ptr < a_end;
 518 a_ptr = a_data.load(a_ptr, _a_load_cond, 0);
 520 const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
 521 auto v_size = v.size();
 523 for (uint32_t b = 0; b < chunks_count; ++b) {
// Tail lanes of the last chunk are zero-filled via the x() < v_size mask.
 525 v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
 527 uint32_t chunk_size_ =
 528 std::min(v_size, SizeT(v_data.total_size()));
 529 process_block(results, results_num, a_data, v_data, chunk_size_,
// Between chunks, refill the last a-window row consumed by advance_left.
 532 if (b != chunks_count - 1) {
 533 a_ptr = a_data.load_lane(a_data.size_y() - 1, a_ptr,
 535 v_size -= v_data.total_size();
// Write accumulators to the output, one result per sub-group stride.
 539 auto *
const out_ptr = out.begin();
 544 std::min(y_start + WorkPI * results.size_x(), out.size());
 546 for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
 547 out_ptr[y] = results[i++];
// Fragment of a second kernel-submission routine: single-pass variant
// where all of `v` fits in one register chunk, so there is no chunk loop —
// the a-window is loaded once (clamped to the exact `to_read` extent),
// v is loaded once, process_block runs once, and results are written out.
// As above, several declaration lines are missing from this extraction.
557template <uint32_t WorkPI,
564template <uint32_t WorkPI,
 574 sycl::nd_range<1> nd_range,
 578 nd_range, [=](sycl::nd_item<1> item) {
 579 auto glid = get_global_linear_id<SizeT>(WorkPI, item);
 584 auto sbgroup = item.get_sub_group();
 585 auto sg_size = sbgroup.get_max_local_range()[0];
// Elements this item's sub-group must see: WorkPI outputs per lane plus
// the window length.
 587 const uint32_t to_read = WorkPI * sg_size + v.size();
 588 const auto *a_begin = a.begin();
 590 const auto *a_ptr = &a.padded_begin()[glid];
// Clamp the effective end so the bounds predicate stops at min(needed, a.end()).
 591 const auto *a_end = std::min(a_ptr + to_read, a.end());
 593 auto _a_load_cond = [a_begin, a_end](
auto &&ptr) {
 594 return ptr >= a_begin && ptr < a_end;
 598 a_data.load(a_ptr, _a_load_cond, 0);
 600 const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
 601 auto v_size = v.size();
 604 v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
 606 auto results_num = get_results_num(WorkPI, out.size(), glid, item);
 608 process_block(results, results_num, a_data, v_data, v_size, op,
// Strided write-out of per-item accumulators, identical in shape to the
// chunked variant above.
 611 auto *
const out_ptr = out.begin();
 616 std::min(y_start + WorkPI * results.size_x(), out.size());
 618 for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
 619 out_ptr[y] = results[i++];
// Signature fragment only — the body continues past the end of this
// extraction. Presumably checks shape/type/queue compatibility of the
// input, window and output usm_ndarrays before kernel submission;
// TODO confirm against the full file.
629void validate(
const usm_ndarray &a,
 630 const usm_ndarray &v,
 631 const usm_ndarray &out,