#include "utils/math_utils.hpp"
#include <sycl/sycl.hpp>

using dpctl::tensor::usm_ndarray;

namespace sliding_window1d
{
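
// _RegistryDataStorage models a 2D tile of values held in sub-group private
// registers: the x axis is the sub-group lane (one column per work-item) and
// the y axis indexes the Size elements each work-item keeps, so the sub-group
// collectively owns a Size x sub-group-size block.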
template <typename T, uint32_t Size>
struct _RegistryDataStorage
{
    using ncT = typename std::remove_const_t<T>;
    using SizeT = decltype(Size);
    static constexpr SizeT _size = Size;

    _RegistryDataStorage(const sycl::nd_item<1> &item)
        : sbgroup(item.get_sub_group())
    {
    }

    template <typename yT>
    T &operator[](const yT &idx)
    {
        static_assert(std::is_integral_v<yT>,
                      "idx must be of an integral type");
        return data[idx];
    }

    template <typename yT>
    const T &operator[](const yT &idx) const
    {
        static_assert(std::is_integral_v<yT>,
                      "idx must be of an integral type");
        return data[idx];
    }

    T &value()
    {
        static_assert(Size == 1,
                      "Size is not equal to 1. Use value(idx) instead");
        return data[0];
    }

    const T &value() const
    {
        static_assert(Size == 1,
                      "Size is not equal to 1. Use value(idx) instead");
        return data[0];
    }
    template <typename yT, typename xT>
    T broadcast(const yT &y, const xT &x) const
    {
        static_assert(std::is_integral_v<std::remove_reference_t<yT>>,
                      "y must be of an integral type");
        static_assert(std::is_integral_v<std::remove_reference_t<xT>>,
                      "x must be of an integral type");

        return sycl::select_from_group(sbgroup, data[y], x);
    }

    template <typename iT>
    T broadcast(const iT &idx) const
    {
        if constexpr (Size == 1) {
            return broadcast(0, idx);
        }
        else {
            return broadcast(idx / size_x(), idx % size_x());
        }
    }
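
    // shift_left/shift_right move row y across lanes by x positions via
    // sycl::shift_group_left/shift_group_right; values shifted past the edge
    // of the sub-group are lost and the vacated lanes are unspecified.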
    template <typename yT, typename xT>
    T shift_left(const yT &y, const xT &x) const
    {
        static_assert(std::is_integral_v<yT>, "y must be of an integral type");
        static_assert(std::is_integral_v<xT>, "x must be of an integral type");
        return sycl::shift_group_left(sbgroup, data[y], x);
    }

    template <typename yT, typename xT>
    T shift_right(const yT &y, const xT &x) const
    {
        static_assert(std::is_integral_v<yT>, "y must be of an integral type");
        static_assert(std::is_integral_v<xT>, "x must be of an integral type");
        return sycl::shift_group_right(sbgroup, data[y], x);
    }

    constexpr SizeT size_y() const
    {
        return Size;
    }

    SizeT size_x() const
    {
        return sbgroup.get_max_local_range()[0];
    }

    SizeT total_size() const
    {
        return size_x() * size_y();
    }

    SizeT x() const
    {
        return sbgroup.get_local_linear_id();
    }

    const sycl::sub_group sbgroup;
    // Backing registers; non-const so loads can write even when T is const.
    ncT data[Size];
};
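
// RegistryData layers masked fill/load/store operations over the raw
// storage. Loads and stores are cooperative: each lane accesses the pointer
// it is given, and every lane-level call advances the pointer by one
// sub-group width, so consecutive rows map to consecutive strips of memory.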
template <typename T, uint32_t Size = 1>
struct RegistryData : public _RegistryDataStorage<T, Size>
{
    using SizeT = typename _RegistryDataStorage<T, Size>::SizeT;

    using _RegistryDataStorage<T, Size>::_RegistryDataStorage;

    template <typename LaneIdT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, SizeT>>>
    void fill_lane(const LaneIdT &lane_id, const T &value, Condition &&mask)
    {
        static_assert(std::is_integral_v<LaneIdT>,
                      "lane_id must be of an integral type");
        if (mask(this->x())) {
            this->data[lane_id] = value;
        }
    }

    template <typename LaneIdT>
    void fill_lane(const LaneIdT &lane_id, const T &value, const bool &mask)
    {
        fill_lane(lane_id, value, [mask](auto &&) { return mask; });
    }

    template <typename LaneIdT>
    void fill_lane(const LaneIdT &lane_id, const T &value)
    {
        fill_lane(lane_id, value, true);
    }

    template <typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, SizeT, SizeT>>>
    void fill(const T &value, Condition &&mask)
    {
        for (SizeT i = 0; i < Size; ++i) {
            fill_lane(i, value, mask(i, this->x()));
        }
    }

    void fill(const T &value)
    {
        fill(value, [](auto &&, auto &&) { return true; });
    }
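
    // Per-row loads: lane i copies *data into register row lane_id when
    // mask(data) holds, otherwise stores default_v; the returned pointer is
    // advanced by one sub-group width for the next row.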
    template <typename LaneIdT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *load_lane(const LaneIdT &lane_id,
                 const T *const data,
                 Condition &&mask,
                 const T &default_v)
    {
        static_assert(std::is_integral_v<LaneIdT>,
                      "lane_id must be of an integral type");
        this->data[lane_id] = mask(data) ? data[0] : default_v;

        return data + this->size_x();
    }

    template <typename LaneIdT>
    T *load_lane(const LaneIdT &lane_id,
                 const T *const data,
                 const bool &mask,
                 const T &default_v)
    {
        return load_lane(
            lane_id, data, [mask](auto &&) { return mask; }, default_v);
    }

    template <typename LaneIdT>
    T *load_lane(const LaneIdT &lane_id, const T *const data)
    {
        constexpr T default_v = 0;
        return load_lane(lane_id, data, true, default_v);
    }

    template <typename yStrideT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *load(const T *const data,
            const yStrideT &y_stride,
            Condition &&mask,
            const T &default_v)
    {
        auto it = data;
        for (SizeT i = 0; i < Size; ++i) {
            load_lane(i, it, mask, default_v);
            it += y_stride;
        }

        return it;
    }

    template <typename yStrideT>
    T *load(const T *const data,
            const yStrideT &y_stride,
            const bool &mask,
            const T &default_v)
    {
        return load(
            data, y_stride, [mask](auto &&) { return mask; }, default_v);
    }

    template <typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *load(const T *const data, Condition &&mask, const T &default_v)
    {
        return load(data, this->size_x(), mask, default_v);
    }

    T *load(const T *const data, const bool &mask, const T &default_v)
    {
        return load(
            data, [mask](auto &&) { return mask; }, default_v);
    }

    T *load(const T *const data)
    {
        constexpr T default_v = 0;
        return load(data, true, default_v);
    }
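
    // Stores mirror the loads: row lane_id is written back to memory only
    // where the mask holds, again advancing by one sub-group width per row.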
    template <typename LaneIdT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *store_lane(const LaneIdT &lane_id, T *const data, Condition &&mask)
    {
        static_assert(std::is_integral_v<LaneIdT>,
                      "lane_id must be of an integral type");

        if (mask(data)) {
            data[0] = this->data[lane_id];
        }

        return data + this->size_x();
    }

    template <typename LaneIdT>
    T *store_lane(const LaneIdT &lane_id, T *const data, const bool &mask)
    {
        return store_lane(lane_id, data, [mask](auto &&) { return mask; });
    }

    template <typename LaneIdT>
    T *store_lane(const LaneIdT &lane_id, T *const data)
    {
        return store_lane(lane_id, data, true);
    }

    template <typename yStrideT,
              typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *store(T *const data, const yStrideT &y_stride, Condition &&condition)
    {
        auto it = data;
        for (SizeT i = 0; i < Size; ++i) {
            store_lane(i, it, condition);
            it += y_stride;
        }

        return it;
    }

    template <typename yStrideT>
    T *store(T *const data, const yStrideT &y_stride, const bool &mask)
    {
        return store(data, y_stride, [mask](auto &&) { return mask; });
    }

    template <typename Condition,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<bool, Condition, const T *const>>>
    T *store(T *const data, Condition &&condition)
    {
        return store(data, this->size_x(), condition);
    }

    T *store(T *const data, const bool &mask)
    {
        return store(data, [mask](auto &&) { return mask; });
    }

    T *store(T *const data)
    {
        return store(data, true);
    }
};
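
// RegistryWindow adds advance_left(), which slides the whole tile `shift`
// positions to the left across lanes: each row is shifted left, and the
// lanes at the right edge are refilled from the next row (or from
// fill_value for the last row).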
template <typename T, uint32_t Size>
struct RegistryWindow : public RegistryData<T, Size>
{
    using SizeT = typename RegistryData<T, Size>::SizeT;

    using RegistryData<T, Size>::RegistryData;

    template <typename shT>
    void advance_left(const shT &shift, const T &fill_value)
    {
        static_assert(std::is_integral_v<shT>,
                      "shift must be of an integral type");

        uint32_t shift_r = this->size_x() - shift;
        for (SizeT i = 0; i < Size; ++i) {
            this->data[i] = this->shift_left(i, shift);
            T border =
                i < Size - 1 ? this->shift_right(i + 1, shift_r) : fill_value;
            if (this->x() >= shift_r) {
                this->data[i] = border;
            }
        }
    }

    void advance_left(const T &fill_value)
    {
        advance_left(1, fill_value);
    }

    void advance_left()
    {
        constexpr T fill_value = 0;
        advance_left(fill_value);
    }
};
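
// A minimal non-owning view of device memory in the spirit of std::span,
// with a configurable size type.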
template <typename T, typename SizeT = size_t>
class Span
{
public:
    using value_type = T;
    using size_type = SizeT;

    Span(T *const data, const SizeT size) : data_(data), size_(size) {}

    T *begin() const { return data(); }

    T *end() const
    {
        return data() + size();
    }

    SizeT size() const { return size_; }
    T *data() const { return data_; }

protected:
    T *const data_;
    const SizeT size_;
};

// Convenience factory, counterpart of make_padded_span below.
template <typename T, typename SizeT = size_t>
Span<T, SizeT> make_span(T *const data, const SizeT size)
{
    return Span<T, SizeT>(data, size);
}
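
// PaddedSpan describes an array with `pad` accessible elements in front of
// begin(): padded_begin() points at the start of the padding, which lets
// the sliding window start left of the first valid element.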
template <typename T, typename SizeT = size_t>
class PaddedSpan : public Span<T, SizeT>
{
public:
    using value_type = T;
    using size_type = SizeT;

    PaddedSpan(T *const data, const SizeT size, const SizeT pad)
        : Span<T, SizeT>(data, size), pad_(pad)
    {
    }

    T *padded_begin() const
    {
        return this->begin() - pad();
    }

    SizeT pad() const { return pad_; }

protected:
    const SizeT pad_;
};

template <typename T, typename SizeT = size_t>
PaddedSpan<T, SizeT>
    make_padded_span(T *const data, const SizeT size, const SizeT offset)
{
    return PaddedSpan<T, SizeT>(data + offset, size, offset);
}
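
// Inner loop shared by both kernels: broadcast one window element v[i] to
// all lanes, accumulate op(a_data[r], v[i]) into each of the r_size per-lane
// results, then slide the a window one element left for the next i.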
template <typename Results,
          typename AData,
          typename VData,
          typename Op,
          typename Red>
void process_block(Results &results,
                   uint32_t r_size,
                   AData &a_data,
                   const VData &v_data,
                   uint32_t block_size,
                   Op op,
                   Red red)
{
    for (uint32_t i = 0; i < block_size; ++i) {
        auto v_val = v_data.broadcast(i);
        for (uint32_t r = 0; r < r_size; ++r) {
            results[r] = red(results[r], op(a_data[r], v_val));
        }
        a_data.advance_left();
    }
}
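
// Each work-item produces up to wpi outputs. A sub-group owns a contiguous
// block of wpi * sub-group-size outputs; lane i starts at sg_base_id + i
// and steps by the sub-group size (see the epilogue loops below).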
template <typename SizeT>
SizeT get_global_linear_id(const uint32_t wpi, const sycl::nd_item<1> &item)
{
    auto sbgroup = item.get_sub_group();
    const auto sg_loc_id = sbgroup.get_local_linear_id();

    const SizeT sg_base_id = wpi * (item.get_global_linear_id() - sg_loc_id);
    const SizeT id = sg_base_id + sg_loc_id;

    return id;
}

template <typename SizeT>
uint32_t get_results_num(const uint32_t wpi,
                         const SizeT size,
                         const SizeT global_id,
                         const sycl::nd_item<1> &item)
{
    auto sbgroup = item.get_sub_group();

    const auto sbg_size = sbgroup.get_max_local_range()[0];
    const auto size_ = sycl::sub_sat(size, global_id);
    return std::min(SizeT(wpi), CeilDiv(size_, sbg_size));
}
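
// General kernel: the window v may exceed one registry tile, so it is
// consumed in sub-group-sized chunks; after each chunk the a window is
// refilled by streaming one more row before the next process_block call.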
template <uint32_t WorkPI,
          typename T,
          typename SizeT,
          typename Op,
          typename Red>
class sliding_window1d_kernel;

template <uint32_t WorkPI,
          typename T,
          typename SizeT,
          typename Op,
          typename Red>
void submit_sliding_window1d(const PaddedSpan<const T, SizeT> &a,
                             const Span<const T, SizeT> &v,
                             const Op &op,
                             const Red &red,
                             Span<T, SizeT> &out,
                             sycl::nd_range<1> nd_range,
                             sycl::handler &cgh)
{
    cgh.parallel_for<sliding_window1d_kernel<WorkPI, T, SizeT, Op, Red>>(
        nd_range, [=](sycl::nd_item<1> item) {
            auto glid = get_global_linear_id<SizeT>(WorkPI, item);

            // Zero-initialize the per-lane accumulators.
            RegistryData<T, WorkPI> results(item);
            results.fill(0);

            auto results_num = get_results_num(WorkPI, out.size(), glid, item);

            const auto *a_begin = a.begin();
            const auto *a_end = a.end();

            auto sbgroup = item.get_sub_group();

            const auto chunks_count =
                CeilDiv(v.size(), sbgroup.get_max_local_range()[0]);

            const auto *a_ptr = &a.padded_begin()[glid];

            auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
                return ptr >= a_begin && ptr < a_end;
            };

            // One extra row so the window can shift in new elements.
            RegistryWindow<const T, WorkPI + 1> a_data(item);
            a_ptr = a_data.load(a_ptr, _a_load_cond, 0);

            const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
            auto v_size = v.size();

            RegistryData<const T> v_data(item);

            for (uint32_t b = 0; b < chunks_count; ++b) {
                v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);

                uint32_t chunk_size_ =
                    std::min(v_size, SizeT(v_data.total_size()));
                process_block(results, results_num, a_data, v_data,
                              chunk_size_, op, red);

                if (b != chunks_count - 1) {
                    a_ptr = a_data.load_lane(a_data.size_y() - 1, a_ptr,
                                             _a_load_cond, 0);
                    v_size -= v_data.total_size();
                }
            }

            auto *const out_ptr = out.begin();

            uint32_t i = 0;
            const auto y_start = glid;
            const auto y_stop =
                std::min(y_start + WorkPI * results.size_x(), out.size());
            for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
                out_ptr[y] = results[i++];
            }
        });
}
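
// Small-window kernel: assumes v fits into a single registry tile
// (v.size() <= sub-group size), so a and v are loaded once and a single
// process_block call produces all WorkPI results per work-item.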
template <uint32_t WorkPI,
          typename T,
          typename SizeT,
          typename Op,
          typename Red>
class sliding_window1d_small_kernel;

template <uint32_t WorkPI,
          typename T,
          typename SizeT,
          typename Op,
          typename Red>
void submit_sliding_window1d_small_kernel(const PaddedSpan<const T, SizeT> &a,
                                          const Span<const T, SizeT> &v,
                                          const Op &op,
                                          const Red &red,
                                          Span<T, SizeT> &out,
                                          sycl::nd_range<1> nd_range,
                                          sycl::handler &cgh)
{
    cgh.parallel_for<
        sliding_window1d_small_kernel<WorkPI, T, SizeT, Op, Red>>(
        nd_range, [=](sycl::nd_item<1> item) {
            auto glid = get_global_linear_id<SizeT>(WorkPI, item);

            RegistryData<T, WorkPI> results(item);
            results.fill(0);

            auto sbgroup = item.get_sub_group();
            auto sg_size = sbgroup.get_max_local_range()[0];

            const uint32_t to_read = WorkPI * sg_size + v.size();
            const auto *a_begin = a.begin();

            const auto *a_ptr = &a.padded_begin()[glid];
            const auto *a_end = std::min(a_ptr + to_read, a.end());

            auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
                return ptr >= a_begin && ptr < a_end;
            };

            RegistryWindow<const T, WorkPI + 1> a_data(item);
            a_data.load(a_ptr, _a_load_cond, 0);

            const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
            auto v_size = v.size();

            RegistryData<const T> v_data(item);
            v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);

            auto results_num = get_results_num(WorkPI, out.size(), glid, item);

            process_block(results, results_num, a_data, v_data, v_size, op,
                          red);

            auto *const out_ptr = out.begin();

            uint32_t i = 0;
            const auto y_start = glid;
            const auto y_stop =
                std::min(y_start + WorkPI * results.size_x(), out.size());
            for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
                out_ptr[y] = results[i++];
            }
        });
}
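
// Host-side validation of the kernel arguments.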
void validate(const usm_ndarray &a,
              const usm_ndarray &v,
              const usm_ndarray &out,