// DPNP C++ backend kernel library 0.18.0dev1
// Data Parallel Extension for NumPy*
// partitioning.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28#include "utils/math_utils.hpp"
29#include <sycl/sycl.hpp>
30#include <type_traits>
31
32#include <stdio.h>
33
34#include "ext/common.hpp"
35
using dpctl::tensor::usm_ndarray;

using ext::common::IsNan;
using ext::common::Less;
using ext::common::make_ndrange;
42
43namespace statistics::partitioning
44{
45
47{
48 uint64_t *less_count;
49 uint64_t *equal_count;
50 uint64_t *greater_equal_count;
51 uint64_t *nan_count;
52
53 Counters(sycl::queue &queue)
54 {
55 less_count = sycl::malloc_device<uint64_t>(1, queue);
56 equal_count = sycl::malloc_device<uint64_t>(1, queue);
57 greater_equal_count = sycl::malloc_device<uint64_t>(1, queue);
58 nan_count = sycl::malloc_device<uint64_t>(1, queue);
59 };
60
61 void cleanup(sycl::queue &queue)
62 {
63 sycl::free(less_count, queue);
64 sycl::free(equal_count, queue);
65 sycl::free(greater_equal_count, queue);
66 sycl::free(nan_count, queue);
67 }
68};
69
/// Full device-side state of the iterative partition-based selection
/// algorithm. Owns several single-element device-USM allocations; there is
/// no RAII here — callers must invoke cleanup() to release them.
template <typename T>
struct State
{
    Counters counters;           // cumulative counters across iterations
    Counters iteration_counters; // counters produced by the current pass

    bool *stop;         // device flag: kernels early-out when set
    bool *target_found; // device flag: target element has been located
    bool *left;         // device flag: current pass works on the left segment

    T *pivot;  // device scalar: pivot value of the current pass
    T *values; // non-owning pointer to the caller-provided working buffer

    size_t *num_elems; // device scalar: element count of the active segment

    size_t n; // total number of input elements (host-side copy)

    /// Allocates the device scalars; values_buff is borrowed, not owned
    /// (cleanup() does not free it).
    State(sycl::queue &queue, size_t _n, T *values_buff)
        : counters(queue), iteration_counters(queue)
    {
        stop = sycl::malloc_device<bool>(1, queue);
        target_found = sycl::malloc_device<bool>(1, queue);
        left = sycl::malloc_device<bool>(1, queue);

        pivot = sycl::malloc_device<T>(1, queue);
        values = values_buff;

        num_elems = sycl::malloc_device<size_t>(1, queue);

        n = _n;
    }

    /// Asynchronously zero-initializes all state on the device, except
    /// counters.greater_equal_count which starts at n (initially every
    /// element counts as >= an as-yet-unset pivot). The fills are chained
    /// sequentially; the returned event completes when all are done.
    sycl::event init(sycl::queue &queue, const std::vector<sycl::event> &deps)
    {
        sycl::event fill_e =
            queue.fill<uint64_t>(counters.less_count, 0, 1, deps);
        fill_e = queue.fill<uint64_t>(counters.equal_count, 0, 1, {fill_e});
        fill_e =
            queue.fill<uint64_t>(counters.greater_equal_count, n, 1, {fill_e});
        fill_e = queue.fill<uint64_t>(counters.nan_count, 0, 1, {fill_e});
        fill_e = queue.fill<uint64_t>(num_elems, 0, 1, {fill_e});
        fill_e = queue.fill<bool>(stop, false, 1, {fill_e});
        fill_e = queue.fill<bool>(target_found, false, 1, {fill_e});
        fill_e = queue.fill<bool>(left, false, 1, {fill_e});
        fill_e = queue.fill<T>(pivot, 0, 1, {fill_e});

        return fill_e;
    }

    /// Folds the current pass's counters into the cumulative ones. When the
    /// pass kept the left segment, elements counted as >= move out of the
    /// cumulative "less" tally and into "greater_equal"; otherwise the
    /// mirror update is applied. equal_count is replaced (not accumulated) —
    /// it is only meaningful for the latest pivot.
    /// NOTE(review): dereferences device-USM pointers directly; this assumes
    /// it executes in a context where that memory is accessible (e.g. from
    /// within a submitted task) — confirm at call sites.
    void update_counters() const
    {
        if (*left) {
            counters.less_count[0] -= iteration_counters.greater_equal_count[0];
            counters.greater_equal_count[0] +=
                iteration_counters.greater_equal_count[0];
        }
        else {
            counters.less_count[0] += iteration_counters.less_count[0];
            counters.greater_equal_count[0] -= iteration_counters.less_count[0];
        }
        counters.equal_count[0] = iteration_counters.equal_count[0];
        counters.nan_count[0] += iteration_counters.nan_count[0];
    }

    /// Zeroes the per-pass counters before the next iteration.
    /// NOTE(review): same device-USM accessibility assumption as
    /// update_counters() — confirm at call sites.
    void reset_iteration_counters() const
    {
        iteration_counters.less_count[0] = 0;
        iteration_counters.equal_count[0] = 0;
        iteration_counters.greater_equal_count[0] = 0;
        iteration_counters.nan_count[0] = 0;
    }

    /// Releases every allocation made by the constructor (and by the two
    /// Counters members). `values` is deliberately not freed — it is owned
    /// by the caller.
    void cleanup(sycl::queue &queue)
    {
        counters.cleanup(queue);
        iteration_counters.cleanup(queue);

        sycl::free(stop, queue);
        sycl::free(target_found, queue);
        sycl::free(left, queue);

        sycl::free(num_elems, queue);
        sycl::free(pivot, queue);
    }
};
155
156template <typename T>
158{
159 Counters iteration_counters;
160
161 bool *stop;
162 bool *left;
163
164 T *pivot;
165
166 size_t n;
167 size_t *num_elems;
168
170 : iteration_counters(state.iteration_counters)
171 {
172 stop = state.stop;
173 left = state.left;
174
175 num_elems = state.num_elems;
176 pivot = state.pivot;
177
178 n = state.n;
179 }
180
181 sycl::event init(sycl::queue &queue, const std::vector<sycl::event> &deps)
182 {
183 sycl::event fill_e =
184 queue.fill<uint64_t>(iteration_counters.less_count, n, 1, deps);
185 fill_e = queue.fill<uint64_t>(iteration_counters.equal_count, 0, 1,
186 {fill_e});
187 fill_e = queue.fill<uint64_t>(iteration_counters.greater_equal_count, 0,
188 1, {fill_e});
189 fill_e =
190 queue.fill<uint64_t>(iteration_counters.nan_count, 0, 1, {fill_e});
191
192 return fill_e;
193 }
194};
195
// Kernel name tag for submit_partition_one_pivot's parallel_for.
// NOTE(review): the declaration line itself was lost in extraction and is
// restored here; the name is fixed by its use as the parallel_for template
// argument below.
template <typename T>
class partition_one_pivot_kernel;

/// Records, on the given handler, one pass of single-pivot partitioning:
/// elements of the active segment of `in` that are less than the pivot are
/// written to the front of `out`, the rest to the back, and the per-pass
/// counters in `state` are updated with device atomics. Each work-item
/// processes up to WorkPI elements, strided by sub-group size.
///
/// @param cgh     command-group handler the kernel is submitted to
/// @param work_sz nd-range covering ceil(num_elems / WorkPI) work-items
/// @param in      input buffer; active segment is at the front (left pass)
///                or at the back (right pass) — see the `_in` selection below
/// @param out     output buffer of at least state.n elements
/// @param state   per-pass partition state (captured by value in the kernel)
template <typename T, uint32_t WorkPI>
void submit_partition_one_pivot(sycl::handler &cgh,
                                sycl::nd_range<1> work_sz,
                                T *in,
                                T *out,
                                PartitionState<T> &state)
{
    // Work-group-local scratch: [0] less-than offset, [1] equal count,
    // [2] greater-equal offset. Slot [3] is allocated but unused.
    auto loc_counters =
        sycl::local_accessor<uint32_t, 1>(sycl::range<1>(4), cgh);
    // sycl::stream str(8192, 1024, cgh);
    cgh.parallel_for<partition_one_pivot_kernel<T>>(
        work_sz, [=](sycl::nd_item<1> item) {
            // Previous pass signalled convergence — nothing to do.
            if (state.stop[0])
                return;

            auto group = item.get_group();
            uint64_t items_per_group = group.get_local_range(0) * WorkPI;
            uint64_t num_elems = state.num_elems[0];

            // Whole group is past the active segment — bail out early.
            if (group.get_group_id(0) * items_per_group >= num_elems)
                return;

            // The active segment sits at the front of `in` when working the
            // left side, otherwise in the last num_elems slots.
            T *_in = nullptr;
            if (state.left[0]) {
                _in = in;
            }
            else {
                _in = in + state.n - num_elems;
            }

            auto value = state.pivot[0];

            auto sbg = item.get_sub_group();
            uint32_t sbg_size = sbg.get_max_local_range()[0];

            // Base index of this sub-group's WorkPI-wide tile; each lane
            // then reads elements strided by sbg_size for coalesced access.
            uint64_t i_base =
                (item.get_global_linear_id() - sbg.get_local_linear_id()) *
                WorkPI;

            if (group.leader()) {
                loc_counters[0] = 0;
                loc_counters[1] = 0;
                loc_counters[2] = 0;
            }

            sycl::group_barrier(group);

            uint32_t less_count = 0;
            uint32_t equal_count = 0;
            uint32_t greater_equal_count = 0;
            uint32_t nan_count = 0;

            // Per-lane staging of up to WorkPI input values; entries past
            // actual_count stay uninitialized (see NOTE below).
            T values[WorkPI];
            uint32_t actual_count = 0;
            uint64_t local_i_base = i_base + sbg.get_local_linear_id();

            // Load and classify this lane's elements.
            // NOTE(review): `nan_count` is accumulated here but never read
            // afterwards in this kernel — confirm whether it should be
            // folded into state.iteration_counters.nan_count.
            for (uint32_t _i = 0; _i < WorkPI; ++_i) {
                auto i = local_i_base + _i * sbg_size;
                if (i < num_elems) {
                    values[_i] = _in[i];
                    less_count += Less<T>{}(values[_i], value);
                    equal_count += values[_i] == value;
                    nan_count += IsNan<T>::isnan(values[_i]);
                    actual_count++;
                }
            }

            greater_equal_count = actual_count - less_count;

            // Sub-group totals for each class.
            auto sbg_less_equal =
                sycl::reduce_over_group(sbg, less_count, sycl::plus<>());
            auto sbg_equal =
                sycl::reduce_over_group(sbg, equal_count, sycl::plus<>());
            auto sbg_greater = sycl::reduce_over_group(sbg, greater_equal_count,
                                                       sycl::plus<>());

            // Sub-group leader reserves this sub-group's slice of the
            // work-group-local output ranges.
            uint32_t local_less_offset = 0;
            uint32_t local_gr_offset = 0;
            if (sbg.leader()) {
                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::work_group>
                    gr_less_eq(loc_counters[0]);
                local_less_offset = gr_less_eq.fetch_add(sbg_less_equal);

                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::work_group>
                    gr_eq(loc_counters[1]);
                gr_eq += sbg_equal;

                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::work_group>
                    gr_greater(loc_counters[2]);
                local_gr_offset = gr_greater.fetch_add(sbg_greater);
            }

            // Broadcast the leader's reserved offsets to every lane.
            local_less_offset =
                sycl::select_from_group(sbg, local_less_offset, 0);
            local_gr_offset = sycl::select_from_group(sbg, local_gr_offset, 0);

            sycl::group_barrier(group);

            // Group leader reserves this group's slice of the global output
            // ranges and republishes the global bases through local memory.
            if (group.leader()) {
                sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::device>
                    glbl_less_eq(state.iteration_counters.less_count[0]);
                auto global_less_eq_offset =
                    glbl_less_eq.fetch_add(loc_counters[0]);

                sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::device>
                    glbl_eq(state.iteration_counters.equal_count[0]);
                glbl_eq += loc_counters[1];

                sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::device>
                    glbl_greater(
                        state.iteration_counters.greater_equal_count[0]);
                auto global_gr_offset = glbl_greater.fetch_add(loc_counters[2]);

                loc_counters[0] = global_less_eq_offset;
                loc_counters[2] = global_gr_offset;
            }

            sycl::group_barrier(group);

            // Final output bases for this sub-group: "less" values fill the
            // array from the front, "greater-equal" values from the back.
            auto sbg_less_offset = loc_counters[0] + local_less_offset;
            auto sbg_gr_offset =
                state.n - (loc_counters[2] + local_gr_offset + sbg_greater);

            uint32_t le_item_offset = 0;
            uint32_t gr_item_offset = 0;

            // Scatter: a sub-group-wide scan packs each round's values
            // contiguously into the reserved ranges.
            // NOTE(review): this loop classifies with a raw `<` while the
            // counting loop above used Less<T>{}; if Less<T> differs from
            // operator< (e.g. NaN ordering), counts and placements can
            // disagree — confirm the two are equivalent for all T used.
            // NOTE(review): lanes with _i >= actual_count feed an
            // uninitialized values[_i] into the scan/reduction; their own
            // writes are guarded, but total_le/total_gr include those lanes
            // — confirm this is benign for partial tiles.
            for (uint32_t _i = 0; _i < WorkPI; ++_i) {
                uint32_t less = values[_i] < value;
                auto le_pos =
                    sycl::exclusive_scan_over_group(sbg, less, sycl::plus<>());
                auto ge_pos = sbg.get_local_linear_id() - le_pos;

                auto total_le =
                    sycl::reduce_over_group(sbg, less, sycl::plus<>());
                auto total_gr = sbg_size - total_le;

                if (_i < actual_count) {
                    if (less) {
                        out[sbg_less_offset + le_item_offset + le_pos] =
                            values[_i];
                    }
                    else {
                        out[sbg_gr_offset + gr_item_offset + ge_pos] =
                            values[_i];
                    }
                    le_item_offset += total_le;
                    gr_item_offset += total_gr;
                }
            }
        });
}
356
357} // namespace statistics::partitioning