28#include "utils/math_utils.hpp"
29#include <sycl/sycl.hpp>
34#include "ext/common.hpp"
36#include "partitioning.hpp"
38using dpctl::tensor::usm_ndarray;
43using ext::common::make_ndrange;
45namespace statistics::partitioning
// NOTE(review): this region is a partial extraction — the original line numbers
// fused onto each line (48, 51, 52, ...) jump, so many statements (declarations of
// `values`, `is_nan`, `state`, `in`, `out`, `loc_counters`, loop bodies, closing
// braces) are missing from this view. Comments below describe only what the
// visible lines establish; hedged claims are marked.
// Dangling template header: its declaration (original lines 49-50) is not visible.
// "u" + "int32_t" here is a line-wrap artifact — presumably `uint32_t WorkPI`; confirm.
48template <
typename T, u
int32_t WorkPI>
51template <
typename T, u
int32_t WorkPI>
// Builds the SYCL kernel functor for one pass of a single-pivot partition step
// (CPU variant). Judging by the counters below, elements are classified per
// sub-group as less / equal / greater-or-equal relative to the pivot, with NaNs
// tallied separately, then scattered into `out` from both ends.
// WorkPI appears to be the number of elements each work-item processes per pass.
52auto partition_one_pivot_func_cpu(sycl::handler &cgh,
// Work-group-local scratch: 4 uint32_t counters (indices 0..2 used below as
// less/equal/greater accumulators; slot usage for [3] not visible here).
58 sycl::local_accessor<uint32_t, 1>(sycl::range<1>(4), cgh);
60 return [=](sycl::nd_item<1> item) {
64 auto group = item.get_group();
// Each work-item handles WorkPI elements, so a group covers local_range * WorkPI.
65 uint64_t items_per_group = group.get_local_range(0) * WorkPI;
// Number of elements still to be partitioned this iteration (from device state).
66 uint64_t num_elems = state.num_elems[0];
// Groups whose whole span lies past the remaining elements have no work.
68 if (group.get_group_id(0) * items_per_group >= num_elems)
// The active window sits at the tail of `in`: the last num_elems entries.
76 _in = in + state.n - num_elems;
// Current pivot value, published by a previous step via device state.
79 auto value = state.pivot[0];
81 auto sbg = item.get_sub_group();
82 uint32_t sbg_size = sbg.get_max_local_range()[0];
// Base index for this sub-group's strip (sub-group-leader-aligned), scaled by WorkPI.
85 (item.get_global_linear_id() - sbg.get_local_linear_id()) * WorkPI;
93 sycl::group_barrier(group);
// --- Phase 1: per-work-item classification counts over its WorkPI elements. ---
95 uint32_t less_count = 0;
96 uint32_t equal_count = 0;
97 uint32_t greater_equal_count = 0;
98 uint32_t nan_count = 0;
// actual_count presumably ends up as the number of in-range elements this
// work-item loaded (bounds logic not visible) — confirm against full source.
101 uint32_t actual_count = 0;
// Elements are strided by sub-group size so loads within the sub-group coalesce.
102 uint64_t local_i_base = i_base + sbg.get_local_linear_id();
104 for (uint32_t _i = 0; _i < WorkPI; ++_i) {
105 auto i = local_i_base + _i * sbg_size;
// NaNs are excluded from both the less and equal tallies.
109 less_count += (
Less<T>{}(values[_i], value) && !is_nan);
110 equal_count += (values[_i] == value && !is_nan);
// Remaining in-range elements are >= pivot (NaNs already counted apart).
116 greater_equal_count = actual_count - less_count - nan_count;
// --- Phase 2: reduce the per-item counts across the sub-group. ---
118 auto sbg_less_equal =
119 sycl::reduce_over_group(sbg, less_count, sycl::plus<>());
121 sycl::reduce_over_group(sbg, equal_count, sycl::plus<>());
123 sycl::reduce_over_group(sbg, greater_equal_count, sycl::plus<>());
// --- Phase 3: reserve this sub-group's slots in the group-local counters. ---
125 uint32_t local_less_offset = 0;
126 uint32_t local_gr_offset = 0;
// Work-group-scope relaxed atomics on the local scratch counters; fetch_add
// returns the sub-group's starting offset within the group's output span.
128 sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
129 sycl::memory_scope::work_group>
130 gr_less_eq(loc_counters[0]);
131 local_less_offset = gr_less_eq.fetch_add(sbg_less_equal);
133 sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
134 sycl::memory_scope::work_group>
135 gr_eq(loc_counters[1]);
138 sycl::atomic_ref<uint32_t, sycl::memory_order::relaxed,
139 sycl::memory_scope::work_group>
140 gr_greater(loc_counters[2]);
141 local_gr_offset = gr_greater.fetch_add(sbg_greater);
// Broadcast the offsets obtained by lane 0 to the whole sub-group (the
// fetch_adds above were presumably leader-only — guard not visible; confirm).
144 local_less_offset = sycl::select_from_group(sbg, local_less_offset, 0);
145 local_gr_offset = sycl::select_from_group(sbg, local_gr_offset, 0);
147 sycl::group_barrier(group);
// --- Phase 4: one work-item per group claims the group's global output span. ---
149 if (group.leader()) {
// Device-scope relaxed atomics on the iteration-wide counters in state.
150 sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
151 sycl::memory_scope::device>
152 glbl_less_eq(state.iteration_counters.less_count[0]);
153 auto global_less_eq_offset =
154 glbl_less_eq.fetch_add(loc_counters[0]);
156 sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
157 sycl::memory_scope::device>
158 glbl_eq(state.iteration_counters.equal_count[0]);
// += on atomic_ref is an atomic fetch_add whose result is discarded — the
// equal tally is accumulated globally but no offset is needed for it here.
159 glbl_eq += loc_counters[1];
161 sycl::atomic_ref<uint64_t, sycl::memory_order::relaxed,
162 sycl::memory_scope::device>
163 glbl_greater(state.iteration_counters.greater_equal_count[0]);
164 auto global_gr_offset = glbl_greater.fetch_add(loc_counters[2]);
// Re-purpose the local scratch to broadcast the group's global base offsets.
166 loc_counters[0] = global_less_eq_offset;
167 loc_counters[2] = global_gr_offset;
170 sycl::group_barrier(group);
// Final per-sub-group destinations: "less" values go at the front of `out`,
// "greater" values are placed back from the end (state.n - ...).
172 auto sbg_less_offset = loc_counters[0] + local_less_offset;
174 state.n - (loc_counters[2] + local_gr_offset + sbg_greater);
// --- Phase 5: scatter. Running offsets within this sub-group's spans. ---
176 uint32_t le_item_offset = 0;
177 uint32_t gr_item_offset = 0;
179 for (uint32_t _i = 0; _i < WorkPI; ++_i) {
181 uint32_t less = (!is_nan &&
Less<T>{}(values[_i], value));
// Exclusive scan gives each lane its rank among the sub-group's "less" hits;
// non-less lanes take the complementary rank on the "greater" side.
183 sycl::exclusive_scan_over_group(sbg, less, sycl::plus<>());
184 auto ge_pos = sbg.get_local_linear_id() - le_pos;
186 auto total_le = sycl::reduce_over_group(sbg, less, sycl::plus<>());
188 sycl::reduce_over_group(sbg, is_nan, sycl::plus<>());
189 auto total_gr = sbg_size - total_le - total_nan;
// Only lanes whose _i-th element was actually in range write out.
191 if (_i < actual_count) {
193 out[sbg_less_offset + le_item_offset + le_pos] = values[_i];
196 out[sbg_gr_offset + gr_item_offset + ge_pos] = values[_i];
// Advance this sub-group's running offsets past the round's writes.
198 le_item_offset += total_le;
199 gr_item_offset += total_gr;
// Submits one single-pivot partition pass to `exec_q` and returns its event.
// NOTE(review): this definition is truncated in the visible extraction — the
// parameter lines declaring `in`, `out`, and `group_size` (original lines
// 207-212) and the closing braces/return are not shown. "u" + "int32_t" is a
// line-wrap artifact — presumably `uint32_t WorkPI`; confirm against full source.
205template <
typename T, u
int32_t WorkPI>
206sycl::event run_partition_one_pivot_cpu(sycl::queue &exec_q,
// Device-resident partition state (pivot, counters) shared across iterations.
209 PartitionState<T> &state,
// Events this pass must wait on before touching `state`/buffers.
210 const std::vector<sycl::event> &deps,
213 auto e = exec_q.submit([&](sycl::handler &cgh) {
214 cgh.depends_on(deps);
// nd-range sized so each work-item covers WorkPI elements of the n-element input.
216 auto work_range = make_ndrange(state.n, group_size, WorkPI);
// Named kernel instantiation; body built by partition_one_pivot_func_cpu above.
218 cgh.parallel_for<partition_one_pivot_kernel_cpu<T, WorkPI>>(
220 partition_one_pivot_func_cpu<T, WorkPI>(cgh, in, out, state));