33#include "utils/math_utils.hpp"
34#include <sycl/sycl.hpp>
39#include "ext/common.hpp"
41using dpctl::tensor::usm_ndarray;
43using ext::common::Align;
44using ext::common::CeilDiv;
46namespace statistics::sliding_window1d
49template <
typename T, u
int32_t Size>
53 using ncT =
typename std::remove_const_t<T>;
54 using SizeT =
decltype(Size);
55 static constexpr SizeT _size = Size;
58 : sbgroup(item.get_sub_group())
62 template <
typename yT>
63 T &operator[](
const yT &idx)
65 static_assert(std::is_integral_v<yT>,
66 "idx must be of an integral type");
70 template <
typename yT>
71 const T &operator[](
const yT &idx)
const
73 static_assert(std::is_integral_v<yT>,
74 "idx must be of an integral type");
80 static_assert(Size == 1,
81 "Size is not equal to 1. Use value(idx) instead");
85 const T &value()
const
87 static_assert(Size == 1,
88 "Size is not equal to 1. Use value(idx) instead");
92 template <
typename yT,
typename xT>
93 T broadcast(
const yT &y,
const xT &x)
const
95 static_assert(std::is_integral_v<std::remove_reference_t<yT>>,
96 "y must be of an integral type");
97 static_assert(std::is_integral_v<std::remove_reference_t<xT>>,
98 "x must be of an integral type");
100 return sycl::select_from_group(sbgroup, data[y], x);
103 template <
typename iT>
104 T broadcast(
const iT &idx)
const
106 if constexpr (Size == 1) {
107 return broadcast(0, idx);
110 return broadcast(idx / size_x(), idx % size_x());
114 template <
typename yT,
typename xT>
115 T shift_left(
const yT &y,
const xT &x)
const
117 static_assert(std::is_integral_v<yT>,
"y must be of an integral type");
118 static_assert(std::is_integral_v<xT>,
"x must be of an integral type");
120 return sycl::shift_group_left(sbgroup, data[y], x);
123 template <
typename yT,
typename xT>
124 T shift_right(
const yT &y,
const xT &x)
const
126 static_assert(std::is_integral_v<yT>,
"y must be of an integral type");
127 static_assert(std::is_integral_v<xT>,
"x must be of an integral type");
129 return sycl::shift_group_right(sbgroup, data[y], x);
132 constexpr SizeT size_y()
const
139 return sbgroup.get_max_local_range()[0];
142 SizeT total_size()
const
144 return size_x() * size_y();
154 return sbgroup.get_local_linear_id();
158 const sycl::sub_group sbgroup;
162template <
typename T, u
int32_t Size = 1>
165 using SizeT =
typename _RegistryDataStorage<T, Size>::SizeT;
169 template <
typename LaneIdT,
171 typename = std::enable_if_t<
172 std::is_invocable_r_v<bool, Condition, SizeT>>>
173 void fill_lane(
const LaneIdT &lane_id,
const T &value, Condition &&mask)
175 static_assert(std::is_integral_v<LaneIdT>,
176 "lane_id must be of an integral type");
177 if (mask(this->x())) {
178 this->data[lane_id] = value;
182 template <
typename LaneIdT>
183 void fill_lane(
const LaneIdT &lane_id,
const T &value,
const bool &mask)
185 fill_lane(lane_id, value, [mask](
auto &&) {
return mask; });
188 template <
typename LaneIdT>
189 void fill_lane(
const LaneIdT &lane_id,
const T &value)
191 fill_lane(lane_id, value,
true);
194 template <
typename Condition,
195 typename = std::enable_if_t<
196 std::is_invocable_r_v<bool, Condition, SizeT, SizeT>>>
197 void fill(
const T &value, Condition &&mask)
199 for (SizeT i = 0; i < Size; ++i) {
200 fill_lane(i, value, mask(i, this->x()));
204 void fill(
const T &value)
206 fill(value, [](
auto &&,
auto &&) {
return true; });
209 template <
typename LaneIdT,
211 typename = std::enable_if_t<
212 std::is_invocable_r_v<bool, Condition, const T *const>>>
213 T *load_lane(
const LaneIdT &lane_id,
218 static_assert(std::is_integral_v<LaneIdT>,
219 "lane_id must be of an integral type");
220 this->data[lane_id] = mask(data) ? data[0] : default_v;
222 return data + this->size_x();
225 template <
typename LaneIdT>
226 T *load_lane(
const LaneIdT &laned_id,
232 laned_id, data, [mask](
auto &&) {
return mask; }, default_v);
235 template <
typename LaneIdT>
236 T *load_lane(
const LaneIdT &laned_id,
const T *
const data)
238 constexpr T default_v = 0;
239 return load_lane(laned_id, data,
true, default_v);
242 template <
typename yStrideT,
244 typename = std::enable_if_t<
245 std::is_invocable_r_v<bool, Condition, const T *const>>>
246 T *load(
const T *
const data,
247 const yStrideT &y_stride,
252 for (SizeT i = 0; i < Size; ++i) {
253 load_lane(i, it, mask, default_v);
260 template <
typename yStr
ideT>
261 T *load(
const T *
const data,
262 const yStrideT &y_stride,
267 data, y_stride, [mask](
auto &&) {
return mask; }, default_v);
270 template <
typename Condition,
271 typename = std::enable_if_t<
272 std::is_invocable_r_v<bool, Condition, const T *const>>>
273 T *load(
const T *
const data, Condition &&mask,
const T &default_v)
275 return load(data, this->size_x(), mask, default_v);
278 T *load(
const T *
const data,
const bool &mask,
const T &default_v)
281 data, [mask](
auto &&) {
return mask; }, default_v);
284 T *load(
const T *
const data)
286 constexpr T default_v = 0;
287 return load(data,
true, default_v);
290 template <
typename LaneIdT,
292 typename = std::enable_if_t<
293 std::is_invocable_r_v<bool, Condition, const T *const>>>
294 T *store_lane(
const LaneIdT &lane_id, T *
const data, Condition &&mask)
296 static_assert(std::is_integral_v<LaneIdT>,
297 "lane_id must be of an integral type");
300 data[0] = this->data[lane_id];
303 return data + this->size_x();
306 template <
typename LaneIdT>
307 T *store_lane(
const LaneIdT &lane_id, T *
const data,
const bool &mask)
309 return store_lane(lane_id, data, [mask](
auto &&) {
return mask; });
312 template <
typename LaneIdT>
313 T *store_lane(
const LaneIdT &lane_id, T *
const data)
315 return store_lane(lane_id, data,
true);
318 template <
typename yStrideT,
320 typename = std::enable_if_t<
321 std::is_invocable_r_v<bool, Condition, const T *const>>>
322 T *store(T *
const data,
const yStrideT &y_stride, Condition &&condition)
325 for (SizeT i = 0; i < Size; ++i) {
326 store_lane(i, it, condition);
333 template <
typename yStr
ideT>
334 T *store(T *
const data,
const yStrideT &y_stride,
const bool &mask)
336 return store(data, y_stride, [mask](
auto &&) {
return mask; });
339 template <
typename Condition,
340 typename = std::enable_if_t<
341 std::is_invocable_r_v<bool, Condition, const T *const>>>
342 T *store(T *
const data, Condition &&condition)
344 return store(data, this->size_x(), condition);
347 T *store(T *
const data,
const bool &mask)
349 return store(data, [mask](
auto &&) {
return mask; });
352 T *store(T *
const data)
354 return store(data,
true);
358template <
typename T, u
int32_t Size>
361 using SizeT =
typename RegistryData<T, Size>::SizeT;
365 template <
typename shT>
366 void advance_left(
const shT &shift,
const T &fill_value)
368 static_assert(std::is_integral_v<shT>,
369 "shift must be of an integral type");
371 uint32_t shift_r = this->size_x() - shift;
372 for (SizeT i = 0; i < Size; ++i) {
373 this->data[i] = this->shift_left(i, shift);
375 i < Size - 1 ? this->shift_right(i + 1, shift_r) : fill_value;
376 if (this->x() >= shift_r) {
377 this->data[i] = border;
382 void advance_left(
const T &fill_value)
384 advance_left(1, fill_value);
389 constexpr T fill_value = 0;
390 advance_left(fill_value);
// Non-owning view over a contiguous 1D buffer (minimal std::span analogue).
// NOTE(review): accessor bodies reconstructed from a mangled extraction of
// this header; verify against the original file.
template <typename T, typename SizeT = size_t>
class Span
{
public:
    using value_type = T;
    using size_type = SizeT;

    Span(T *const data, const SizeT size) : data_(data), size_(size) {}

    T *begin() const
    {
        return data();
    }

    T *end() const
    {
        return data() + size();
    }

    T *data() const
    {
        return data_;
    }

    SizeT size() const
    {
        return size_;
    }

protected:
    T *data_;
    SizeT size_;
};
428template <
typename T,
typename SizeT =
size_t>
434template <
typename T,
typename SizeT =
size_t>
438 using value_type = T;
439 using size_type = SizeT;
441 PaddedSpan(T *
const data,
const SizeT size,
const SizeT pad)
446 T *padded_begin()
const
448 return this->begin() - pad();
460template <
typename T,
typename SizeT =
size_t>
462 make_padded_span(T *
const data,
const SizeT size,
const SizeT offset)
// Accumulate one chunk of the sliding-window reduction:
// results[r] = red(results[r], op(a_data[r], v[i])) for each of the
// block_size window elements i, sliding the `a` window left between steps.
// `r_size` is how many of this work item's result lanes are still live.
template <typename Results,
          typename AData,
          typename VData,
          typename Op,
          typename Red>
void process_block(Results &results,
                   uint32_t r_size,
                   AData &a_data,
                   VData &v_data,
                   uint32_t block_size,
                   Op op,
                   Red red)
{
    for (uint32_t i = 0; i < block_size; ++i) {
        // Window element i, broadcast from whichever work item holds it.
        auto v_val = v_data.broadcast(i);
        for (uint32_t r = 0; r < r_size; ++r) {
            results[r] = red(results[r], op(a_data[r], v_val));
        }
        a_data.advance_left();
    }
}
489template <
typename SizeT>
490SizeT get_global_linear_id(
const uint32_t wpi,
const sycl::nd_item<1> &item)
492 auto sbgroup = item.get_sub_group();
493 const auto sg_loc_id = sbgroup.get_local_linear_id();
495 const SizeT sg_base_id = wpi * (item.get_global_linear_id() - sg_loc_id);
496 const SizeT
id = sg_base_id + sg_loc_id;
501template <
typename SizeT>
502uint32_t get_results_num(
const uint32_t wpi,
504 const SizeT global_id,
505 const sycl::nd_item<1> &item)
507 auto sbgroup = item.get_sub_group();
509 const auto sbg_size = sbgroup.get_max_local_range()[0];
510 const auto size_ = sycl::sub_sat(size, global_id);
511 return std::min(SizeT(wpi), CeilDiv(size_, sbg_size));
514template <uint32_t WorkPI,
521template <uint32_t WorkPI,
531 sycl::nd_range<1> nd_range,
535 nd_range, [=](sycl::nd_item<1> item) {
536 auto glid = get_global_linear_id<SizeT>(WorkPI, item);
541 auto results_num = get_results_num(WorkPI, out.size(), glid, item);
543 const auto *a_begin = a.begin();
544 const auto *a_end = a.end();
546 auto sbgroup = item.get_sub_group();
548 const auto chunks_count =
549 CeilDiv(v.size(), sbgroup.get_max_local_range()[0]);
551 const auto *a_ptr = &a.padded_begin()[glid];
553 auto _a_load_cond = [a_begin, a_end](
auto &&ptr) {
554 return ptr >= a_begin && ptr < a_end;
558 a_ptr = a_data.load(a_ptr, _a_load_cond, 0);
560 const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
561 auto v_size = v.size();
563 for (uint32_t b = 0; b < chunks_count; ++b) {
565 v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
567 uint32_t chunk_size_ =
568 std::min(v_size, SizeT(v_data.total_size()));
569 process_block(results, results_num, a_data, v_data, chunk_size_,
572 if (b != chunks_count - 1) {
573 a_ptr = a_data.load_lane(a_data.size_y() - 1, a_ptr,
575 v_size -= v_data.total_size();
579 auto *
const out_ptr = out.begin();
584 std::min(y_start + WorkPI * results.size_x(), out.size());
586 for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
587 out_ptr[y] = results[i++];
597template <uint32_t WorkPI,
604template <uint32_t WorkPI,
614 sycl::nd_range<1> nd_range,
618 nd_range, [=](sycl::nd_item<1> item) {
619 auto glid = get_global_linear_id<SizeT>(WorkPI, item);
624 auto sbgroup = item.get_sub_group();
625 auto sg_size = sbgroup.get_max_local_range()[0];
627 const uint32_t to_read = WorkPI * sg_size + v.size();
628 const auto *a_begin = a.begin();
630 const auto *a_ptr = &a.padded_begin()[glid];
631 const auto *a_end = std::min(a_ptr + to_read, a.end());
633 auto _a_load_cond = [a_begin, a_end](
auto &&ptr) {
634 return ptr >= a_begin && ptr < a_end;
638 a_data.load(a_ptr, _a_load_cond, 0);
640 const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
641 auto v_size = v.size();
644 v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
646 auto results_num = get_results_num(WorkPI, out.size(), glid, item);
648 process_block(results, results_num, a_data, v_data, v_size, op,
651 auto *
const out_ptr = out.begin();
656 std::min(y_start + WorkPI * results.size_x(), out.size());
658 for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
659 out_ptr[y] = results[i++];
669void validate(
const usm_ndarray &a,
670 const usm_ndarray &v,
671 const usm_ndarray &out,