partitioning_one_pivot_kernel_gpu.hpp
//*****************************************************************************
// Copyright (c) 2024-2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include "utils/math_utils.hpp"
#include <sycl/sycl.hpp>
#include <type_traits>
#include <vector>

#include <stdio.h>

#include "ext/common.hpp"

#include "partitioning.hpp"

using dpctl::tensor::usm_ndarray;

using ext::common::make_ndrange;

namespace statistics::partitioning
{

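// This header implements one GPU partitioning pass around a single pivot:
// every active element is classified as "< pivot" or ">= pivot", staged in
// work-group local memory, and scattered to the front or back of the output
// buffer, while per-pass counters record how many elements fell on each
// side. Repeated passes of this kind, driven by PartitionState, narrow the
// active range in the manner of a quickselect-style selection.
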
// Forward declaration of the kernel name type used by parallel_for below.
template <typename T>
class partition_one_pivot_kernel_gpu;
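
// NOTE: PartitionState<T> is defined in partitioning.hpp. A minimal sketch
// of the members this kernel relies on, inferred from the accesses below
// (names match the usage in this file; the actual types and layout may
// differ):
//
//   template <typename T>
//   struct PartitionState
//   {
//       T *pivot;             // pivot value for the current pass
//       uint64_t *num_elems;  // number of elements still in play
//       uint64_t n;           // total capacity of the in/out buffers
//       bool *left;           // active range starts at in[0] if true
//       bool *stop;           // set once the selection has converged
//       struct
//       {
//           uint64_t *less_count;          // items < pivot this pass
//           uint64_t *equal_count;         // items == pivot this pass
//           uint64_t *greater_equal_count; // items >= pivot this pass
//       } iteration_counters;
//   };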
template <typename T>
auto partition_one_pivot_func_gpu(sycl::handler &cgh,
                                  T *in,
                                  T *out,
                                  PartitionState<T> &state,
                                  uint32_t group_size,
                                  uint32_t WorkPI)
{
    // Work-group-local counters: [0] items < pivot, [1] items == pivot,
    // [2] items >= pivot; slot [3] is unused.
    auto loc_counters =
        sycl::local_accessor<uint32_t, 1>(sycl::range<1>(4), cgh);
    // Per-group base offsets into the global output, published by the
    // group leader after the global atomics below.
    auto loc_global_counters =
        sycl::local_accessor<uint32_t, 1>(sycl::range<1>(2), cgh);
    // Staging buffer: each of the group_size work-items handles WorkPI
    // elements per pass.
    auto loc_items =
        sycl::local_accessor<T, 1>(sycl::range<1>(WorkPI * group_size), cgh);

    return [=](sycl::nd_item<1> item) {
        // A previous pass may already have found the requested element.
        if (state.stop[0])
            return;

        auto group = item.get_group();
        auto group_range = group.get_local_range(0);
        uint64_t items_per_group = group_range * WorkPI;
        uint64_t num_elems = state.num_elems[0];

        // Groups that start past the end of the active range have no work.
        if (group.get_group_id(0) * items_per_group >= num_elems)
            return;

        // The still-active elements occupy either the leftmost or the
        // rightmost num_elems entries of the buffer, depending on which
        // side of the previous pivot is being refined.
        T *_in = nullptr;
        if (state.left[0]) {
            _in = in;
        }
        else {
            _in = in + state.n - num_elems;
        }

        auto value = state.pivot[0];

        auto sbg = item.get_sub_group();

        uint32_t sbg_size = sbg.get_max_local_range()[0];
        uint32_t sbg_llid = sbg.get_local_linear_id();
        // Base index of this sub-group's WorkPI-wide stripe of the input.
        uint64_t i_base = (item.get_global_linear_id() - sbg_llid) * WorkPI;

        if (group.leader()) {
            loc_counters[0] = 0;
            loc_counters[1] = 0;
            loc_counters[2] = 0;
        }

        sycl::group_barrier(group);

        // Classify WorkPI elements per work-item against the pivot and
        // stage them in local memory: "< pivot" items grow from the front
        // of loc_items, ">= pivot" items grow from the back.
        for (uint32_t _i = 0; _i < WorkPI; ++_i) {
            auto i = i_base + _i * sbg_size + sbg_llid;
            uint32_t valid = i < num_elems;
            auto val = valid ? _in[i] : 0;
            uint32_t less = (val < value) && valid;
            uint32_t equal = (val == value) && valid;

            // Rank of this lane among the sub-group's "less" lanes, and
            // among the remaining lanes, respectively. Invalid lanes only
            // occur at the tail, so ranks of valid lanes are unaffected.
            auto le_pos =
                sycl::exclusive_scan_over_group(sbg, less, sycl::plus<>());
            auto ge_pos = sbg.get_local_linear_id() - le_pos;
            auto sbg_less_equal =
                sycl::reduce_over_group(sbg, less, sycl::plus<>());
            auto sbg_equal =
                sycl::reduce_over_group(sbg, equal, sycl::plus<>());
            auto tot_valid =
                sycl::reduce_over_group(sbg, valid, sycl::plus<>());
            auto sbg_greater = tot_valid - sbg_less_equal;

            // The sub-group leader reserves room in the group-local
            // counters; the resulting offsets are broadcast to all lanes.
            uint32_t local_less_offset = 0;
            uint32_t local_gr_offset = 0;

            if (sbg.leader()) {
                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::work_group>
                    gr_less_eq(loc_counters[0]);
                local_less_offset = gr_less_eq.fetch_add(sbg_less_equal);

                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::work_group>
                    gr_eq(loc_counters[1]);
                gr_eq += sbg_equal;

                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
                                 sycl::memory_scope::work_group>
                    gr_greater(loc_counters[2]);
                local_gr_offset = gr_greater.fetch_add(sbg_greater);
            }

            uint32_t local_less_offset_ =
                sycl::select_from_group(sbg, local_less_offset, 0);
            uint32_t local_gr_offset_ =
                sycl::select_from_group(sbg, local_gr_offset, 0);

            if (valid) {
                if (less) {
                    uint32_t ll_offset = local_less_offset_ + le_pos;
                    loc_items[ll_offset] = val;
                }
                else {
                    // ">= pivot" items fill loc_items from the end.
                    auto loc_gr_offset = group_range * WorkPI -
                                         local_gr_offset_ - sbg_greater +
                                         ge_pos;
                    loc_items[loc_gr_offset] = val;
                }
            }
        }

        sycl::group_barrier(group);

        // The group leader claims this group's slice of the global output
        // by atomically advancing the device-wide iteration counters, then
        // publishes the offsets through local memory. Note the per-group
        // offsets are held as 32-bit values.
        if (group.leader()) {
            sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
                             sycl::memory_scope::device>
                glbl_less_eq(state.iteration_counters.less_count[0]);
            auto global_less_eq_offset =
                glbl_less_eq.fetch_add(loc_counters[0]);

            sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
                             sycl::memory_scope::device>
                glbl_eq(state.iteration_counters.equal_count[0]);
            glbl_eq += loc_counters[1];

            sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
                             sycl::memory_scope::device>
                glbl_greater(state.iteration_counters.greater_equal_count[0]);
            auto global_gr_offset = glbl_greater.fetch_add(loc_counters[2]);

            loc_global_counters[0] = global_less_eq_offset;
            loc_global_counters[1] = global_gr_offset + loc_counters[2];
        }

        sycl::group_barrier(group);

        // "< pivot" items land at the front of out, ">= pivot" items at
        // the back; out has room for all state.n elements.
        auto global_less_eq_offset = loc_global_counters[0];
        auto global_gr_offset = state.n - loc_global_counters[1];

        uint32_t sbg_id = sbg.get_group_id();
        // Flush the staged items from local memory to their final global
        // positions.
        for (uint32_t _i = 0; _i < WorkPI; ++_i) {
            uint32_t i = sbg_id * sbg_size * WorkPI + _i * sbg_size + sbg_llid;
            if (i < loc_counters[0]) {
                // Front of the staging buffer: items < pivot.
                out[global_less_eq_offset + i] = loc_items[i];
            }
            else if (i < loc_counters[0] + loc_counters[2]) {
                // Back of the staging buffer: items >= pivot.
                auto global_gr_offset_ = global_gr_offset + i - loc_counters[0];
                uint32_t local_buff_offset = WorkPI * group_range -
                                             loc_counters[2] + i -
                                             loc_counters[0];

                out[global_gr_offset_] = loc_items[local_buff_offset];
            }
        }
    };
}

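// make_ndrange comes from ext/common.hpp. A minimal sketch of what
// make_ndrange(n, group_size, WorkPI) is assumed to compute, inferred from
// how the kernel above indexes its data (each work-item covers WorkPI
// elements); see ext/common.hpp for the authoritative definition:
//
//   inline sycl::nd_range<1>
//   make_ndrange_sketch(uint64_t n, uint32_t group_size, uint32_t WorkPI)
//   {
//       uint64_t items = (n + WorkPI - 1) / WorkPI; // one item per WorkPI
//       uint64_t groups = (items + group_size - 1) / group_size;
//       return {sycl::range<1>(groups * group_size),
//               sycl::range<1>(group_size)};
//   }
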
template <typename T>
sycl::event run_partition_one_pivot_gpu(sycl::queue &exec_q,
                                        T *in,
                                        T *out,
                                        PartitionState<T> &state,
                                        const std::vector<sycl::event> &deps,
                                        uint32_t group_size,
                                        uint32_t WorkPI)
{
    auto e = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);

        auto work_range = make_ndrange(state.n, group_size, WorkPI);

        cgh.parallel_for<partition_one_pivot_kernel_gpu<T>>(
            work_range, partition_one_pivot_func_gpu<T>(cgh, in, out, state,
                                                        group_size, WorkPI));
    });

    return e;
}
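
// A hypothetical host-side usage sketch (not part of this header), assuming
// USM device allocations and a PartitionState<float> prepared by the caller
// (make_state below is a stand-in for that caller-provided setup):
//
//   sycl::queue q{sycl::gpu_selector_v};
//   float *in = sycl::malloc_device<float>(n, q);
//   float *out = sycl::malloc_device<float>(n, q);
//   PartitionState<float> state = make_state(q, n);
//   auto ev = run_partition_one_pivot_gpu<float>(q, in, out, state, {},
//                                                /*group_size=*/256,
//                                                /*WorkPI=*/4);
//   ev.wait();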

} // namespace statistics::partitioning