DPNP C++ backend kernel library 0.18.0rc1
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
partitioning.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28#include "utils/math_utils.hpp"
29#include <sycl/sycl.hpp>
30#include <type_traits>
31
32#include <stdio.h>
33
34#include "ext/common.hpp"
35
36using dpctl::tensor::usm_ndarray;
37
41using ext::common::make_ndrange;
42
43namespace statistics::partitioning
44{
45
47{
48 uint64_t *less_count;
49 uint64_t *equal_count;
50 uint64_t *greater_equal_count;
51 uint64_t *nan_count;
52
53 Counters(sycl::queue &queue)
54 {
55 less_count = sycl::malloc_device<uint64_t>(1, queue);
56 equal_count = sycl::malloc_device<uint64_t>(1, queue);
57 greater_equal_count = sycl::malloc_device<uint64_t>(1, queue);
58 nan_count = sycl::malloc_device<uint64_t>(1, queue);
59 };
60
61 void cleanup(sycl::queue &queue)
62 {
63 sycl::free(less_count, queue);
64 sycl::free(equal_count, queue);
65 sycl::free(greater_equal_count, queue);
66 sycl::free(nan_count, queue);
67 }
68};
69
70template <typename T>
71struct State
72{
73 Counters counters;
74 Counters iteration_counters;
75
76 bool *stop;
77 bool *target_found;
78 bool *left;
79
80 T *pivot;
81 T *values;
82
83 size_t *num_elems;
84
85 size_t n;
86
87 State(sycl::queue &queue, size_t _n, T *values_buff)
88 : counters(queue), iteration_counters(queue)
89 {
90 stop = sycl::malloc_device<bool>(1, queue);
91 target_found = sycl::malloc_device<bool>(1, queue);
92 left = sycl::malloc_device<bool>(1, queue);
93
94 pivot = sycl::malloc_device<T>(1, queue);
95 values = values_buff;
96
97 num_elems = sycl::malloc_device<size_t>(1, queue);
98
99 n = _n;
100 }
101
102 sycl::event init(sycl::queue &queue, const std::vector<sycl::event> &deps)
103 {
104 sycl::event fill_e =
105 queue.fill<uint64_t>(counters.less_count, 0, 1, deps);
106 fill_e = queue.fill<uint64_t>(counters.equal_count, 0, 1, {fill_e});
107 fill_e =
108 queue.fill<uint64_t>(counters.greater_equal_count, n, 1, {fill_e});
109 fill_e = queue.fill<uint64_t>(counters.nan_count, 0, 1, {fill_e});
110 fill_e = queue.fill<uint64_t>(num_elems, 0, 1, {fill_e});
111 fill_e = queue.fill<bool>(stop, false, 1, {fill_e});
112 fill_e = queue.fill<bool>(target_found, false, 1, {fill_e});
113 fill_e = queue.fill<bool>(left, false, 1, {fill_e});
114 fill_e = queue.fill<T>(pivot, 0, 1, {fill_e});
115
116 return fill_e;
117 }
118
119 void update_counters() const
120 {
121 if (*left) {
122 counters.less_count[0] -= iteration_counters.greater_equal_count[0];
123 counters.greater_equal_count[0] +=
124 iteration_counters.greater_equal_count[0];
125 }
126 else {
127 counters.less_count[0] += iteration_counters.less_count[0];
128 counters.greater_equal_count[0] -= iteration_counters.less_count[0];
129 }
130 counters.equal_count[0] = iteration_counters.equal_count[0];
131 counters.nan_count[0] += iteration_counters.nan_count[0];
132 }
133
134 void reset_iteration_counters() const
135 {
136 iteration_counters.less_count[0] = 0;
137 iteration_counters.equal_count[0] = 0;
138 iteration_counters.greater_equal_count[0] = 0;
139 iteration_counters.nan_count[0] = 0;
140 }
141
142 void cleanup(sycl::queue &queue)
143 {
144 counters.cleanup(queue);
145 iteration_counters.cleanup(queue);
146
147 sycl::free(stop, queue);
148 sycl::free(target_found, queue);
149 sycl::free(left, queue);
150
151 sycl::free(num_elems, queue);
152 sycl::free(pivot, queue);
153 }
154};
155
156template <typename T>
158{
159 Counters iteration_counters;
160
161 bool *stop;
162 bool *left;
163
164 T *pivot;
165
166 size_t n;
167 size_t *num_elems;
168
170 : iteration_counters(state.iteration_counters)
171 {
172 stop = state.stop;
173 left = state.left;
174
175 num_elems = state.num_elems;
176 pivot = state.pivot;
177
178 n = state.n;
179 }
180
181 sycl::event init(sycl::queue &queue, const std::vector<sycl::event> &deps)
182 {
183 sycl::event fill_e =
184 queue.fill<uint64_t>(iteration_counters.less_count, n, 1, deps);
185 fill_e = queue.fill<uint64_t>(iteration_counters.equal_count, 0, 1,
186 {fill_e});
187 fill_e = queue.fill<uint64_t>(iteration_counters.greater_equal_count, 0,
188 1, {fill_e});
189 fill_e =
190 queue.fill<uint64_t>(iteration_counters.nan_count, 0, 1, {fill_e});
191
192 return fill_e;
193 }
194};
195
196} // namespace statistics::partitioning
197
198#include "partitioning_one_pivot_kernel_cpu.hpp"
199#include "partitioning_one_pivot_kernel_gpu.hpp"
200
201namespace statistics::partitioning
202{
203template <typename T>
204sycl::event run_partition_one_pivot(sycl::queue &exec_q,
205 T *in,
206 T *out,
207 PartitionState<T> &state,
208 const std::vector<sycl::event> &deps)
209{
210 auto device = exec_q.get_device();
211
212 if (device.is_gpu()) {
213 constexpr uint32_t WorkPI = 8;
214 constexpr uint32_t group_size = 128;
215
216 return run_partition_one_pivot_gpu<T>(exec_q, in, out, state, deps,
217 group_size, WorkPI);
218 }
219 else {
220 constexpr uint32_t WorkPI = 4;
221 constexpr uint32_t group_size = 128;
222
223 return run_partition_one_pivot_cpu<T, WorkPI>(exec_q, in, out, state,
224 deps, group_size);
225 }
226}
227
228void validate(const usm_ndarray &a,
229 const usm_ndarray &partitioned,
230 const size_t k);
231} // namespace statistics::partitioning