28#include "utils/math_utils.hpp"
29#include <sycl/sycl.hpp>
34#include "ext/common.hpp"
36#include "partitioning.hpp"
38using dpctl::tensor::usm_ndarray;
43using ext::common::make_ndrange;
45namespace statistics::partitioning
// Kernel factory: builds the device lambda for one pivot-partition pass of a
// GPU selection/partition algorithm. Elements strictly less than the pivot are
// packed to the front of the output, elements greater to the back; per-group
// counts are accumulated into device-global iteration counters.
//
// NOTE(review): this view is a partial extraction — several original lines are
// missing (declarations of `loc_counters`, `loc_items`, `_in`, `le_pos`,
// `sbg_equal`, `tot_valid`, and most closing braces are not visible), and the
// original file's line numbers are fused onto the text. Comments below
// describe only what the visible lines establish; confirm against the full
// source before acting on them.
52auto partition_one_pivot_func_gpu(sycl::handler &cgh,
// Work-group-local scratch (SYCL local accessors, allocated per group):
//   - 4 x uint32 counters — indices 0..2 are used below as less/equal/greater
//     tallies; use of index 3 is not visible here.
//   - 2 x uint32 cache for globally reserved output offsets.
//   - WorkPI * group_size staging buffer of T for coalesced writes.
 60        sycl::local_accessor<uint32_t, 1>(sycl::range<1>(4), cgh);
 61    auto loc_global_counters =
 62        sycl::local_accessor<uint32_t, 1>(sycl::range<1>(2), cgh);
 64        sycl::local_accessor<T, 1>(sycl::range<1>(WorkPI * group_size), cgh);
// The returned lambda is the kernel body; captures everything by copy as
// required for SYCL device lambdas.
 66    return [=](sycl::nd_item<1> item) {
 70        auto group = item.get_group();
 71        auto group_range = group.get_local_range(0);
 72        auto llid = item.get_local_linear_id();
// Each work-item processes WorkPI elements, so a group covers
// local_range * WorkPI elements of the remaining (active) input.
 73        uint64_t items_per_group = group.get_local_range(0) * WorkPI;
 74        uint64_t num_elems = state.num_elems[0];
// Groups whose whole tile lies past the active element count exit early.
 76        if (group.get_group_id(0) * items_per_group >= num_elems)
// Presumably the active window is the tail of `in` of length num_elems —
// TODO(review) confirm against the missing branch around this line.
 84            _in = in + state.n - num_elems;
// Pivot value for this iteration, read from device state.
 87        auto value = state.pivot[0];
 89        auto sbg = item.get_sub_group();
 91        uint32_t sbg_size = sbg.get_max_local_range()[0];
 92        uint32_t sbg_work_size = sbg_size * WorkPI;
 93        uint32_t sbg_llid = sbg.get_local_linear_id();
// Base global index of this sub-group's tile; items within the tile are
// visited sub-group-strided (i_base + _i * sbg_size + sbg_llid) below.
 94        uint64_t i_base = (item.get_global_linear_id() - sbg_llid) * WorkPI;
// Barrier: local counters must be initialized (in the non-visible lines
// above) before any work-item starts fetch_add-ing into them.
102        sycl::group_barrier(group);
104        for (uint32_t _i = 0; _i < WorkPI; ++_i) {
105            uint32_t less_count = 0;
106            uint32_t equal_count = 0;
107            uint32_t greater_equal_count = 0;
109            uint32_t actual_count = 0;
110            auto i = i_base + _i * sbg_size + sbg_llid;
// Out-of-range lanes still participate in group collectives but carry
// valid == 0 so they contribute nothing to counts.
111            uint32_t valid = i < num_elems;
112            auto val = valid ? _in[i] : 0;
113            uint32_t less = (val < value) && valid;
114            uint32_t equal = (val == value) && valid;
// le_pos (declared in a non-visible line): rank of this lane among the
// sub-group's "less" lanes, via exclusive scan.
117                sycl::exclusive_scan_over_group(sbg, less, sycl::plus<>());
// ge_pos: rank among the remaining ("not less") lanes of the sub-group.
118            auto ge_pos = sbg.get_local_linear_id() - le_pos;
// Sub-group totals: number of "less" lanes, "equal" lanes, and valid lanes.
119            auto sbg_less_equal =
120                sycl::reduce_over_group(sbg, less, sycl::plus<>());
122                sycl::reduce_over_group(sbg, equal, sycl::plus<>());
124                sycl::reduce_over_group(sbg, valid, sycl::plus<>());
125            auto sbg_greater = tot_valid - sbg_less_equal;
127            uint32_t local_less_offset = 0;
128            uint32_t local_gr_offset = 0;
// One lane (presumably the sub-group leader — the guard line is not
// visible) reserves contiguous slots in the group-local buffer by
// atomically bumping the work-group counters.
131                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
132                                 sycl::memory_scope::work_group>
133                    gr_less_eq(loc_counters[0]);
134                local_less_offset = gr_less_eq.fetch_add(sbg_less_equal);
136                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
137                                 sycl::memory_scope::work_group>
138                    gr_eq(loc_counters[1]);
141                sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
142                                 sycl::memory_scope::work_group>
143                    gr_greater(loc_counters[2]);
144                local_gr_offset = gr_greater.fetch_add(sbg_greater);
// Broadcast the reserved base offsets from lane 0 to the whole sub-group.
147            uint32_t local_less_offset_ =
148                sycl::select_from_group(sbg, local_less_offset, 0);
149            uint32_t local_gr_offset_ =
150                sycl::select_from_group(sbg, local_gr_offset, 0);
// "Less" elements are staged at the front of the local buffer...
154                uint32_t ll_offset = local_less_offset_ + le_pos;
155                loc_items[ll_offset] = val;
// ...and "greater" elements are staged growing down from the back
// (group_range * WorkPI is the buffer size).
158                auto loc_gr_offset = group_range * WorkPI -
159                                     local_gr_offset_ - sbg_greater +
161                loc_items[loc_gr_offset] = val;
// All staging writes and local-counter updates must complete before the
// leader publishes global offsets and items are flushed to `out`.
166        sycl::group_barrier(group);
168        if (group.leader()) {
// Group leader reserves this group's ranges in the device-global output
// by atomically adding the group's totals to the iteration counters.
169            sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
170                             sycl::memory_scope::device>
171                glbl_less_eq(state.iteration_counters.less_count[0]);
172            auto global_less_eq_offset =
173                glbl_less_eq.fetch_add(loc_counters[0]);
// Equal count is only accumulated — no offset is needed for equals here.
175            sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
176                             sycl::memory_scope::device>
177                glbl_eq(state.iteration_counters.equal_count[0]);
178            glbl_eq += loc_counters[1];
180            sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
181                             sycl::memory_scope::device>
182                glbl_greater(state.iteration_counters.greater_equal_count[0]);
183            auto global_gr_offset = glbl_greater.fetch_add(loc_counters[2]);
// Publish the reserved offsets to the whole group via local memory.
185            loc_global_counters[0] = global_less_eq_offset;
186            loc_global_counters[1] = global_gr_offset + loc_counters[2];
// Barrier: every work-item needs the leader's published offsets.
189        sycl::group_barrier(group);
191        auto global_less_eq_offset = loc_global_counters[0];
// Greater elements are written from the END of the output array backwards,
// hence state.n minus the accumulated greater count.
192        auto global_gr_offset = state.n - loc_global_counters[1];
194        uint32_t sbg_id = sbg.get_group_id();
// Flush phase: each work-item copies its share of the staged local buffer
// to the globally reserved ranges in `out`.
195        for (uint32_t _i = 0; _i < WorkPI; ++_i) {
196            uint32_t i = sbg_id * sbg_size * WorkPI + _i * sbg_size + sbg_llid;
// First loc_counters[0] staged slots are the "less" run → front of out.
197            if (i < loc_counters[0]) {
198                out[global_less_eq_offset + i] = loc_items[i];
// Next loc_counters[2] logical slots map to the "greater" run staged at the
// back of the local buffer → back of out.
200            else if (i < loc_counters[0] + loc_counters[2]) {
201                auto global_gr_offset_ = global_gr_offset + i - loc_counters[0];
202                uint32_t local_buff_offset = WorkPI * group_range -
203                                             loc_counters[2] + i -
206                out[global_gr_offset_] = loc_items[local_buff_offset];
// Host-side launcher: submits one pivot-partition pass to `exec_q` and
// returns the submission event. NOTE(review): the definition continues past
// this extracted view (the submit lambda / function are closed in lines not
// shown); parameter lines `in`, `out`, `group_size`, `WorkPI` are also not
// visible here — confirm against the full source.
213sycl::event run_partition_one_pivot_gpu(sycl::queue &exec_q,
216                                        PartitionState<T> &state,
217                                        const std::vector<sycl::event> &deps,
221    auto e = exec_q.submit([&](sycl::handler &cgh) {
// Order this pass after the caller-supplied dependency events.
222        cgh.depends_on(deps);
// nd_range sized so state.n elements are covered with WorkPI items per
// work-item (helper from ext::common, see `using` at top of file).
224        auto work_range = make_ndrange(state.n, group_size, WorkPI);
// Launch the kernel built by the factory above, named by the
// partition_one_pivot_kernel_gpu<T> kernel-name tag.
226        cgh.parallel_for<partition_one_pivot_kernel_gpu<T>>(
227            work_range, partition_one_pivot_func_gpu<T>(cgh, in, out, state,
228                                 group_size, WorkPI));