28#include "utils/math_utils.hpp"
29#include <sycl/sycl.hpp>
34#include "ext/common.hpp"
36using dpctl::tensor::usm_ndarray;
41using ext::common::make_ndrange;
43namespace statistics::partitioning
// NOTE(review): fragment of a counters aggregate — its declaration header and
// remaining members (less_count, nan_count) lie outside this chunk. Each tally
// is a single uint64_t in USM device memory so kernels can update it with
// device-scope atomics (see the partition kernel later in this file).
49     uint64_t *equal_count;
50     uint64_t *greater_equal_count;
// Allocate one device-resident uint64_t per counter. The enclosing member
// (constructor, presumably — signature not visible here) receives the queue
// whose context owns these allocations.
55         less_count = sycl::malloc_device<uint64_t>(1, queue);
56         equal_count = sycl::malloc_device<uint64_t>(1, queue);
57         greater_equal_count = sycl::malloc_device<uint64_t>(1, queue);
58         nan_count = sycl::malloc_device<uint64_t>(1, queue);
// Release every allocation made above. Per the SYCL spec, sycl::free must be
// called with a queue on the same context the memory was allocated against.
// Caller is responsible for ensuring no kernel still uses these pointers.
61     void cleanup(sycl::queue &queue)
63         sycl::free(less_count, queue);
64         sycl::free(equal_count, queue);
65         sycl::free(greater_equal_count, queue);
66         sycl::free(nan_count, queue);
// NOTE(review): State constructor fragment (class header, template parameter T,
// and several member initializations fall in the gaps of this listing).
// Allocates the per-run device-side scalars driving the iterative partition
// search; the matching sycl::free calls are in this class's cleanup() below.
87     State(sycl::queue &queue,
           size_t _n, T *values_buff)
88         : counters(queue), iteration_counters(queue)
// Device-side control flags: stop / target_found / left steer the host-driven
// iteration loop; exact usage is not visible in this chunk — verify at callers.
90         stop = sycl::malloc_device<bool>(1, queue);
91         target_found = sycl::malloc_device<bool>(1, queue);
92         left = sycl::malloc_device<bool>(1, queue);
// Current pivot value (type T) and the count of elements still in play.
94         pivot = sycl::malloc_device<T>(1, queue);
97         num_elems = sycl::malloc_device<size_t>(1, queue);
// Reset all device state to its starting values as a serial chain of fill
// events, each depending on the previous; returns the tail event (return
// statement not visible in this chunk). greater_equal_count starts at n —
// before the first pivot, every element counts as >= (TODO confirm against
// the missing lines / algorithm docs).
102     sycl::event init(sycl::queue &queue,
                     const std::vector<sycl::event> &deps)
// NOTE(review): this fill's event is not captured into fill_e on the visible
// line, yet fill_e is consumed on the next one — the assignment is presumably
// on a wrapped line lost by extraction. Verify the dependency chain is intact.
105         queue.fill<uint64_t>(counters.less_count, 0, 1, deps);
106         fill_e = queue.fill<uint64_t>(counters.equal_count, 0, 1, {fill_e});
// Same caveat here: `fill_e =` for this fill likely sits on the elided line 107.
108             queue.fill<uint64_t>(counters.greater_equal_count, n, 1, {fill_e});
109         fill_e = queue.fill<uint64_t>(counters.nan_count, 0, 1, {fill_e});
110         fill_e = queue.fill<uint64_t>(num_elems, 0, 1, {fill_e});
// Clear the three control flags and the pivot value.
111         fill_e = queue.fill<
bool>(stop,
false, 1, {fill_e});
112         fill_e = queue.fill<
bool>(target_found,
false, 1, {fill_e});
113         fill_e = queue.fill<
bool>(left,
false, 1, {fill_e});
114         fill_e = queue.fill<T>(pivot, 0, 1, {fill_e});
// Fold the last iteration's tallies into the cumulative counters. Declared
// const: the pointer members are not reassigned, only their device pointees
// are written — this therefore assumes the memory is host-accessible at this
// point (TODO confirm: copied back, or shared/host USM).
119     void update_counters()
const
// Branch condition is in an elided line — presumably "which half the search
// descends into". This arm: search continues among smaller elements, so the
// discarded >= elements move from less_count's complement into the >= total.
122             counters.less_count[0] -= iteration_counters.greater_equal_count[0];
123             counters.greater_equal_count[0] +=
124                 iteration_counters.greater_equal_count[0];
// Opposite arm: search continues among larger elements.
127             counters.less_count[0] += iteration_counters.less_count[0];
128             counters.greater_equal_count[0] -= iteration_counters.less_count[0];
// equal_count is OVERWRITTEN (count for the current pivot only), while
// nan_count accumulates across iterations — note the asymmetry.
130         counters.equal_count[0] = iteration_counters.equal_count[0];
131         counters.nan_count[0] += iteration_counters.nan_count[0];
// Zero the per-iteration tallies before launching the next partition pass.
// const for the same reason as update_counters(): only pointees are mutated.
134     void reset_iteration_counters()
const
136         iteration_counters.less_count[0] = 0;
137         iteration_counters.equal_count[0] = 0;
138         iteration_counters.greater_equal_count[0] = 0;
139         iteration_counters.nan_count[0] = 0;
// Free everything this State allocated: both counter sets, then the scalar
// control/device buffers from the constructor. Must be called exactly once,
// after all kernels using these pointers have completed.
142     void cleanup(sycl::queue &queue)
144         counters.cleanup(queue);
145         iteration_counters.cleanup(queue);
147         sycl::free(stop, queue);
148         sycl::free(target_found, queue);
149         sycl::free(left, queue);
151         sycl::free(num_elems, queue);
152         sycl::free(pivot, queue);
// NOTE(review): fragment of a second, lighter state type (its name/header is
// elided). It copies the Counters struct and the num_elems pointer from a full
// State — since Counters holds raw device pointers, this copy ALIASES the same
// device memory rather than owning it. Verify that cleanup() is never called
// on both objects (double sycl::free), or that this type has no cleanup.
170         : iteration_counters(state.iteration_counters)
175         num_elems = state.num_elems;
// Reset the aliased per-iteration counters as a chained sequence of fills.
// Unlike State::init, less_count is seeded with n here rather than 0 —
// intentional asymmetry, presumably (TODO confirm with the missing context).
181     sycl::event init(sycl::queue &queue,
                     const std::vector<sycl::event> &deps)
// NOTE(review): no visible `fill_e =` capture on this line or on line 190;
// the assignments likely sit on wrapped lines lost by extraction — verify.
184         queue.fill<uint64_t>(iteration_counters.less_count, n, 1, deps);
185         fill_e = queue.fill<uint64_t>(iteration_counters.equal_count, 0, 1,
187         fill_e = queue.fill<uint64_t>(iteration_counters.greater_equal_count, 0,
190             queue.fill<uint64_t>(iteration_counters.nan_count, 0, 1, {fill_e});
// Device kernel: partition the active window of `in` around a single pivot
// (state.pivot[0]) into `out` — values < pivot are packed at the front, the
// rest at the back — while tallying less/equal/greater-equal/nan counts into
// state.iteration_counters via device atomics. Each work-item handles WorkPI
// elements, strided by sub-group size. Three cooperative phases are visible
// below: (1) per-sub-group counting + reduction, (2) work-group aggregation
// through local-memory atomics, (3) group leader reserves global output
// offsets, then every item scatters its values.
// NOTE(review): "u" / "int32_t" below is one token ("uint32_t WorkPI") split
// across lines by the extraction that produced this listing.
199template <
    typename T, u
int32_t WorkPI>
200void submit_partition_one_pivot(sycl::handler &cgh,
201                                sycl::nd_range<1> work_sz,
// Four work-group-local uint32 slots: [0] less, [1] equal, [2] greater-equal
// aggregates / later the group's global base offsets (see lines 318-319).
207         sycl::local_accessor<uint32_t, 1>(sycl::range<1>(4), cgh);
210         work_sz, [=](sycl::nd_item<1> item) {
214             auto group = item.get_group();
215             uint64_t items_per_group = group.get_local_range(0) * WorkPI;
216             uint64_t num_elems = state.num_elems[0];
// Groups entirely past the active element window exit early.
218             if (group.get_group_id(0) * items_per_group >= num_elems)
// The active window is the LAST num_elems elements of the n-element buffer.
226                 _in = in + state.n - num_elems;
229             auto value = state.pivot[0];
231             auto sbg = item.get_sub_group();
232             uint32_t sbg_size = sbg.get_max_local_range()[0];
// Base index of this sub-group's tile (first lane's global id * stride).
235                 (item.get_global_linear_id() - sbg.get_local_linear_id()) *
// Leader zeroes loc_counters (elided lines), then all wait.
238             if (group.leader()) {
244             sycl::group_barrier(group);
// Phase 1: per-work-item private tallies over up to WorkPI elements.
246             uint32_t less_count = 0;
247             uint32_t equal_count = 0;
248             uint32_t greater_equal_count = 0;
249             uint32_t nan_count = 0;
// Number of this item's WorkPI slots that are actually in range.
252             uint32_t actual_count = 0;
253             uint64_t local_i_base = i_base + sbg.get_local_linear_id();
// Sub-group-strided access: lane k reads local_i_base + _i * sbg_size, so
// consecutive lanes touch consecutive addresses (coalesced).
255             for (uint32_t _i = 0; _i < WorkPI; ++_i) {
256                 auto i = local_i_base + _i * sbg_size;
259                     less_count +=
Less<T>{}(values[_i], value);
260                     equal_count += values[_i] == value;
// Everything in range that wasn't "less" is ">= pivot" (includes equals).
266             greater_equal_count = actual_count - less_count;
// Sub-group reductions of the private tallies.
// NOTE(review): sbg_less_equal actually reduces less_count (strictly-less per
// Less<T>) — the name is misleading; flag for rename in a follow-up.
268             auto sbg_less_equal =
269                 sycl::reduce_over_group(sbg, less_count, sycl::plus<>());
271                 sycl::reduce_over_group(sbg, equal_count, sycl::plus<>());
272             auto sbg_greater = sycl::reduce_over_group(sbg, greater_equal_count,
// Phase 2: one lane per sub-group (guard elided) reserves this sub-group's
// slice inside the work-group via local-memory atomics; fetch_add's return is
// the sub-group's offset within the group's output run.
275             uint32_t local_less_offset = 0;
276             uint32_t local_gr_offset = 0;
278                 sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
279                                  sycl::memory_scope::work_group>
280                     gr_less_eq(loc_counters[0]);
281                 local_less_offset = gr_less_eq.fetch_add(sbg_less_equal);
283                 sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
284                                  sycl::memory_scope::work_group>
285                     gr_eq(loc_counters[1]);
288                 sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
289                                  sycl::memory_scope::work_group>
290                     gr_greater(loc_counters[2]);
291                 local_gr_offset = gr_greater.fetch_add(sbg_greater);
// Broadcast the offsets obtained by lane 0 to the whole sub-group.
295                 sycl::select_from_group(sbg, local_less_offset, 0);
296             local_gr_offset = sycl::select_from_group(sbg, local_gr_offset, 0);
298             sycl::group_barrier(group);
// Phase 3: the group leader reserves this group's slice of the GLOBAL output
// with device-scope atomics on the iteration counters, then re-purposes
// loc_counters[0]/[2] to carry the group's global base offsets back to all
// items (lines 318-319).
300             if (group.leader()) {
301                 sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
302                                  sycl::memory_scope::device>
303                     glbl_less_eq(state.iteration_counters.less_count[0]);
304                 auto global_less_eq_offset =
305                     glbl_less_eq.fetch_add(loc_counters[0]);
307                 sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
308                                  sycl::memory_scope::device>
309                     glbl_eq(state.iteration_counters.equal_count[0]);
// Equal count only accumulates; no offset needed, so plain atomic +=.
310                 glbl_eq += loc_counters[1];
312                 sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
313                                  sycl::memory_scope::device>
315                         state.iteration_counters.greater_equal_count[0]);
316                 auto global_gr_offset = glbl_greater.fetch_add(loc_counters[2]);
318                 loc_counters[0] = global_less_eq_offset;
319                 loc_counters[2] = global_gr_offset;
// Barrier publishes the leader's writes to loc_counters before anyone reads.
322             sycl::group_barrier(group);
// Less-than values go at the front; greater-equal values are written from the
// BACK of the n-element buffer (hence state.n - ...), growing toward the middle.
324             auto sbg_less_offset = loc_counters[0] + local_less_offset;
326                 state.n - (loc_counters[2] + local_gr_offset + sbg_greater);
328             uint32_t le_item_offset = 0;
329             uint32_t gr_item_offset = 0;
// Scatter pass: per WorkPI round, an exclusive scan over the sub-group gives
// each lane its rank among the "less" lanes (le_pos); lanes holding >= values
// use lane_id - le_pos as their rank on the other side.
331             for (uint32_t _i = 0; _i < WorkPI; ++_i) {
332                 uint32_t less = values[_i] < value;
334                     sycl::exclusive_scan_over_group(sbg, less, sycl::plus<>());
335                 auto ge_pos = sbg.get_local_linear_id() - le_pos;
338                     sycl::reduce_over_group(sbg, less, sycl::plus<>());
339                 auto total_gr = sbg_size - total_le;
// Only in-range slots are written; out-of-range lanes still take part in the
// scans/reductions above so the collectives stay sub-group-uniform.
341                 if (_i < actual_count) {
343                     out[sbg_less_offset + le_item_offset + le_pos] =
347                     out[sbg_gr_offset + gr_item_offset + ge_pos] =
350                 le_item_offset += total_le;
351                 gr_item_offset += total_gr;