DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
common.hpp
1//*****************************************************************************
2// Copyright (c) 2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <algorithm>
32#include <cstddef>
33#include <cstdint>
34#include <utility>
35#include <vector>
36
37#include <sycl/sycl.hpp>
38
39// dpctl tensor headers
40#include "kernels/alignment.hpp"
41#include "kernels/elementwise_functions/common.hpp"
42#include "utils/sycl_utils.hpp"
43
44namespace dpnp::extensions::py_internal::elementwise_common
45{
46using dpctl::tensor::kernels::alignment_utils::
47 disabled_sg_loadstore_wrapper_krn;
48using dpctl::tensor::kernels::alignment_utils::is_aligned;
49using dpctl::tensor::kernels::alignment_utils::required_alignment;
50
51using dpctl::tensor::kernels::elementwise_common::select_lws;
52
53using dpctl::tensor::sycl_utils::sub_group_load;
54using dpctl::tensor::sycl_utils::sub_group_store;
55
63template <typename argT,
64 typename resT1,
65 typename resT2,
66 typename UnaryTwoOutputsOpT,
67 std::uint8_t vec_sz = 4u,
68 std::uint8_t n_vecs = 2u,
69 bool enable_sg_loadstore = true>
70struct UnaryTwoOutputsContigFunctor
71{
72private:
73 const argT *in = nullptr;
74 resT1 *out1 = nullptr;
75 resT2 *out2 = nullptr;
76 std::size_t nelems_;
77
78public:
79 UnaryTwoOutputsContigFunctor(const argT *inp,
80 resT1 *res1,
81 resT2 *res2,
82 const std::size_t n_elems)
83 : in(inp), out1(res1), out2(res2), nelems_(n_elems)
84 {
85 }
86
87 void operator()(sycl::nd_item<1> ndit) const
88 {
89 static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
90 UnaryTwoOutputsOpT op{};
92 /* Each sub-group processes a contiguous block of elems_per_wi * sgSize elements; each work-item handles elems_per_wi of them */
92 /* NOTE: work-group size must be divisible by sub-group size */
93
94 if constexpr (enable_sg_loadstore &&
95 UnaryTwoOutputsOpT::is_constant::value) {
96 // the result of the operation is a known compile-time constant
97 constexpr resT1 const_val1 = UnaryTwoOutputsOpT::constant_value1;
98 constexpr resT2 const_val2 = UnaryTwoOutputsOpT::constant_value2;
99
100 auto sg = ndit.get_sub_group();
101 const std::uint16_t sgSize = sg.get_max_local_range()[0];
102
103 const std::size_t base =
104 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
105 sg.get_group_id()[0] * sgSize);
106 if (base + elems_per_wi * sgSize < nelems_) {
107 static constexpr sycl::vec<resT1, vec_sz> res1_vec(const_val1);
108 static constexpr sycl::vec<resT2, vec_sz> res2_vec(const_val2);
109#pragma unroll
110 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
111 const std::size_t offset = base + it * sgSize;
112 auto out1_multi_ptr = sycl::address_space_cast<
113 sycl::access::address_space::global_space,
114 sycl::access::decorated::yes>(&out1[offset]);
115 auto out2_multi_ptr = sycl::address_space_cast<
116 sycl::access::address_space::global_space,
117 sycl::access::decorated::yes>(&out2[offset]);
118
119 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
120 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
121 }
122 }
123 else {
124 const std::size_t lane_id = sg.get_local_id()[0];
125 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
126 out1[k] = const_val1;
127 out2[k] = const_val2;
128 }
129 }
130 }
131 else if constexpr (enable_sg_loadstore &&
132 UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
133 UnaryTwoOutputsOpT::supports_vec::value &&
134 (vec_sz > 1))
135 {
136 auto sg = ndit.get_sub_group();
137 const std::uint16_t sgSize = sg.get_max_local_range()[0];
138
139 const std::size_t base =
140 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
141 sg.get_group_id()[0] * sgSize);
142 if (base + elems_per_wi * sgSize < nelems_) {
143#pragma unroll
144 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
145 const std::size_t offset = base + it * sgSize;
146 auto in_multi_ptr = sycl::address_space_cast<
147 sycl::access::address_space::global_space,
148 sycl::access::decorated::yes>(&in[offset]);
149 auto out1_multi_ptr = sycl::address_space_cast<
150 sycl::access::address_space::global_space,
151 sycl::access::decorated::yes>(&out1[offset]);
152 auto out2_multi_ptr = sycl::address_space_cast<
153 sycl::access::address_space::global_space,
154 sycl::access::decorated::yes>(&out2[offset]);
155
156 const sycl::vec<argT, vec_sz> x =
157 sub_group_load<vec_sz>(sg, in_multi_ptr);
158 sycl::vec<resT2, vec_sz> res2_vec = {};
159 const sycl::vec<resT1, vec_sz> res1_vec = op(x, res2_vec);
160 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
161 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
162 }
163 }
164 else {
165 const std::size_t lane_id = sg.get_local_id()[0];
166 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
167 // scalar call
168 out1[k] = op(in[k], out2[k]);
169 }
170 }
171 }
172 else if constexpr (enable_sg_loadstore &&
173 UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
174 std::is_same_v<resT1, argT>)
175 {
176 // default: use scalar-value function
177
178 auto sg = ndit.get_sub_group();
179 const std::uint16_t sgSize = sg.get_max_local_range()[0];
180 const std::size_t base =
181 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
182 sg.get_group_id()[0] * sgSize);
183
184 if (base + elems_per_wi * sgSize < nelems_) {
185#pragma unroll
186 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
187 const std::size_t offset = base + it * sgSize;
188 auto in_multi_ptr = sycl::address_space_cast<
189 sycl::access::address_space::global_space,
190 sycl::access::decorated::yes>(&in[offset]);
191 auto out1_multi_ptr = sycl::address_space_cast<
192 sycl::access::address_space::global_space,
193 sycl::access::decorated::yes>(&out1[offset]);
194 auto out2_multi_ptr = sycl::address_space_cast<
195 sycl::access::address_space::global_space,
196 sycl::access::decorated::yes>(&out2[offset]);
197
198 sycl::vec<argT, vec_sz> arg_vec =
199 sub_group_load<vec_sz>(sg, in_multi_ptr);
200 sycl::vec<resT2, vec_sz> res2_vec = {};
201#pragma unroll
202 for (std::uint32_t k = 0; k < vec_sz; ++k) {
203 arg_vec[k] = op(arg_vec[k], res2_vec[k]);
204 }
205 sub_group_store<vec_sz>(sg, arg_vec, out1_multi_ptr);
206 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
207 }
208 }
209 else {
210 const std::size_t lane_id = sg.get_local_id()[0];
211 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
212 out1[k] = op(in[k], out2[k]);
213 }
214 }
215 }
216 else if constexpr (enable_sg_loadstore &&
217 UnaryTwoOutputsOpT::supports_sg_loadstore::value)
218 {
219 // default: use scalar-value function
220
221 auto sg = ndit.get_sub_group();
222 const std::uint16_t sgSize = sg.get_max_local_range()[0];
223 const std::size_t base =
224 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
225 sg.get_group_id()[0] * sgSize);
226
227 if (base + elems_per_wi * sgSize < nelems_) {
228#pragma unroll
229 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
230 const std::size_t offset = base + it * sgSize;
231 auto in_multi_ptr = sycl::address_space_cast<
232 sycl::access::address_space::global_space,
233 sycl::access::decorated::yes>(&in[offset]);
234 auto out1_multi_ptr = sycl::address_space_cast<
235 sycl::access::address_space::global_space,
236 sycl::access::decorated::yes>(&out1[offset]);
237 auto out2_multi_ptr = sycl::address_space_cast<
238 sycl::access::address_space::global_space,
239 sycl::access::decorated::yes>(&out2[offset]);
240
241 const sycl::vec<argT, vec_sz> arg_vec =
242 sub_group_load<vec_sz>(sg, in_multi_ptr);
243 sycl::vec<resT1, vec_sz> res1_vec = {};
244 sycl::vec<resT2, vec_sz> res2_vec = {};
245#pragma unroll
246 for (std::uint8_t k = 0; k < vec_sz; ++k) {
247 res1_vec[k] = op(arg_vec[k], res2_vec[k]);
248 }
249 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
250 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
251 }
252 }
253 else {
254 const std::size_t lane_id = sg.get_local_id()[0];
255 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
256 out1[k] = op(in[k], out2[k]);
257 }
258 }
259 }
260 else {
261 const std::uint16_t sgSize =
262 ndit.get_sub_group().get_local_range()[0];
263 const std::size_t gid = ndit.get_global_linear_id();
264 const std::uint16_t elems_per_sg = sgSize * elems_per_wi;
265
266 const std::size_t start =
267 (gid / sgSize) * (elems_per_sg - sgSize) + gid;
268 const std::size_t end = std::min(nelems_, start + elems_per_sg);
269 for (std::size_t offset = start; offset < end; offset += sgSize) {
270 out1[offset] = op(in[offset], out2[offset]);
271 }
272 }
273 }
274};
275
283template <typename argT,
284 typename resT1,
285 typename resT2,
286 typename IndexerT,
287 typename UnaryTwoOutputsOpT>
288struct UnaryTwoOutputsStridedFunctor
289{
290private:
291 const argT *inp_ = nullptr;
292 resT1 *res1_ = nullptr;
293 resT2 *res2_ = nullptr;
294 IndexerT inp_out_indexer_;
295
296public:
297 UnaryTwoOutputsStridedFunctor(const argT *inp_p,
298 resT1 *res1_p,
299 resT2 *res2_p,
300 const IndexerT &inp_out_indexer)
301 : inp_(inp_p), res1_(res1_p), res2_(res2_p),
302 inp_out_indexer_(inp_out_indexer)
303 {
304 }
305
306 void operator()(sycl::id<1> wid) const
307 {
308 const auto &offsets_ = inp_out_indexer_(wid.get(0));
309 const ssize_t &inp_offset = offsets_.get_first_offset();
310 const ssize_t &res1_offset = offsets_.get_second_offset();
311 const ssize_t &res2_offset = offsets_.get_third_offset();
312
313 UnaryTwoOutputsOpT op{};
314
315 res1_[res1_offset] = op(inp_[inp_offset], res2_[res2_offset]);
316 }
317};
318
326template <typename argTy,
327 template <typename T>
328 class UnaryTwoOutputsType,
329 template <typename A,
330 typename R1,
331 typename R2,
332 std::uint8_t vs,
333 std::uint8_t nv,
334 bool enable>
335 class UnaryTwoOutputsContigFunctorT,
336 template <typename A,
337 typename R1,
338 typename R2,
339 std::uint8_t vs,
340 std::uint8_t nv>
341 class kernel_name,
342 std::uint8_t vec_sz = 4u,
343 std::uint8_t n_vecs = 2u>
344sycl::event
345 unary_two_outputs_contig_impl(sycl::queue &exec_q,
346 std::size_t nelems,
347 const char *arg_p,
348 char *res1_p,
349 char *res2_p,
350 const std::vector<sycl::event> &depends = {})
351{
352 static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
353 const std::size_t n_work_items_needed = nelems / elems_per_wi;
354 const std::size_t lws =
355 select_lws(exec_q.get_device(), n_work_items_needed);
356
357 const std::size_t n_groups =
358 ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
359 const auto gws_range = sycl::range<1>(n_groups * lws);
360 const auto lws_range = sycl::range<1>(lws);
361
362 using resTy1 = typename UnaryTwoOutputsType<argTy>::value_type1;
363 using resTy2 = typename UnaryTwoOutputsType<argTy>::value_type2;
364 using BaseKernelName = kernel_name<argTy, resTy1, resTy2, vec_sz, n_vecs>;
365
366 const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
367 resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p);
368 resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p);
369
370 sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
371 cgh.depends_on(depends);
372
373 if (is_aligned<required_alignment>(arg_p) &&
374 is_aligned<required_alignment>(res1_p) &&
375 is_aligned<required_alignment>(res2_p))
376 {
377 static constexpr bool enable_sg_loadstore = true;
378 using KernelName = BaseKernelName;
379 using Impl =
380 UnaryTwoOutputsContigFunctorT<argTy, resTy1, resTy2, vec_sz,
381 n_vecs, enable_sg_loadstore>;
382
383 cgh.parallel_for<KernelName>(
384 sycl::nd_range<1>(gws_range, lws_range),
385 Impl(arg_tp, res1_tp, res2_tp, nelems));
386 }
387 else {
388 static constexpr bool disable_sg_loadstore = false;
389 using KernelName =
390 disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
391 using Impl =
392 UnaryTwoOutputsContigFunctorT<argTy, resTy1, resTy2, vec_sz,
393 n_vecs, disable_sg_loadstore>;
394
395 cgh.parallel_for<KernelName>(
396 sycl::nd_range<1>(gws_range, lws_range),
397 Impl(arg_tp, res1_tp, res2_tp, nelems));
398 }
399 });
400
401 return comp_ev;
402}
403
411template <typename argTy,
412 template <typename T>
413 class UnaryTwoOutputsType,
414 template <typename A, typename R1, typename R2, typename I>
415 class UnaryTwoOutputsStridedFunctorT,
416 template <typename A, typename R1, typename R2, typename I>
417 class kernel_name>
418sycl::event unary_two_outputs_strided_impl(
419 sycl::queue &exec_q,
420 std::size_t nelems,
421 int nd,
422 const ssize_t *shape_and_strides,
423 const char *arg_p,
424 ssize_t arg_offset,
425 char *res1_p,
426 ssize_t res1_offset,
427 char *res2_p,
428 ssize_t res2_offset,
429 const std::vector<sycl::event> &depends,
430 const std::vector<sycl::event> &additional_depends)
431{
432 sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
433 cgh.depends_on(depends);
434 cgh.depends_on(additional_depends);
435
436 using res1Ty = typename UnaryTwoOutputsType<argTy>::value_type1;
437 using res2Ty = typename UnaryTwoOutputsType<argTy>::value_type2;
438 using IndexerT =
439 typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
440
441 const IndexerT indexer{nd, arg_offset, res1_offset, res2_offset,
442 shape_and_strides};
443
444 const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
445 res1Ty *res1_tp = reinterpret_cast<res1Ty *>(res1_p);
446 res2Ty *res2_tp = reinterpret_cast<res2Ty *>(res2_p);
447
448 using Impl =
449 UnaryTwoOutputsStridedFunctorT<argTy, res1Ty, res2Ty, IndexerT>;
450
451 cgh.parallel_for<kernel_name<argTy, res1Ty, res2Ty, IndexerT>>(
452 {nelems}, Impl(arg_tp, res1_tp, res2_tp, indexer));
453 });
454 return comp_ev;
455}
456
457// Typedefs for function pointers
458
459typedef sycl::event (*unary_two_outputs_contig_impl_fn_ptr_t)(
460 sycl::queue &,
461 std::size_t,
462 const char *,
463 char *,
464 char *,
465 const std::vector<sycl::event> &);
466
467typedef sycl::event (*unary_two_outputs_strided_impl_fn_ptr_t)(
468 sycl::queue &,
469 std::size_t,
470 int,
471 const ssize_t *,
472 const char *,
473 ssize_t,
474 char *,
475 ssize_t,
476 char *,
477 ssize_t,
478 const std::vector<sycl::event> &,
479 const std::vector<sycl::event> &);
480
481} // namespace dpnp::extensions::py_internal::elementwise_common
UnaryTwoOutputsContigFunctor: functor for evaluation of a unary function with two output arrays on contiguous arrays (definition: common.hpp:71).
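The operation type supplied as UnaryTwoOutputsOpT to this functor is expected to expose the traits the functor inspects (is_constant, supports_sg_loadstore, supports_vec, and, for constant operations, constant_value1/constant_value2), a scalar call operator that returns the first result and writes the second one through a reference, and optionally a sycl::vec overload. A minimal sketch, assuming a hypothetical sincos-style operation (the SincosOp name and its body are illustrative and not part of this header):

#include <type_traits>

#include <sycl/sycl.hpp>

// Hypothetical operation: sin(x) is the first output, cos(x) the second.
template <typename T>
struct SincosOp
{
    // traits inspected by UnaryTwoOutputsContigFunctor via if constexpr
    using is_constant = std::false_type;          // results are not compile-time constants
    using supports_sg_loadstore = std::true_type; // sub-group loads/stores are safe
    using supports_vec = std::true_type;          // a sycl::vec overload is provided

    // scalar form: returns the first result, writes the second through res2
    T operator()(const T &x, T &res2) const
    {
        res2 = sycl::cos(x);
        return sycl::sin(x);
    }

    // vector form, used when supports_vec is true and vec_sz > 1
    template <int vec_sz>
    sycl::vec<T, vec_sz> operator()(const sycl::vec<T, vec_sz> &x,
                                    sycl::vec<T, vec_sz> &res2) const
    {
        res2 = sycl::cos(x);
        return sycl::sin(x);
    }
};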
UnaryTwoOutputsStridedFunctor: functor for evaluation of a unary function with two output arrays on strided data (definition: common.hpp:289).
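Continuing the hypothetical SincosOp sketch above, an extension would typically bind the operation into the generic functor through an alias template and store the instantiated entry point through the function-pointer typedefs. The names SincosOutputType, SincosContigFunctor and sincos_contig_kernel below are illustrative assumptions, not names defined by this header:

#include <cstdint>
// assumes this common.hpp and the SincosOp sketch above are already available

namespace ew_cmn = dpnp::extensions::py_internal::elementwise_common;

// Hypothetical result-type map: both outputs share the input type.
template <typename T>
struct SincosOutputType
{
    using value_type1 = T; // first result (sin)
    using value_type2 = T; // second result (cos)
};

// Hypothetical SYCL kernel name matching the kernel_name template-template parameter.
template <typename A, typename R1, typename R2, std::uint8_t vs, std::uint8_t nv>
class sincos_contig_kernel;

// Bind the operation into the generic contiguous functor via an alias template.
template <typename argT,
          typename resT1,
          typename resT2,
          std::uint8_t vec_sz,
          std::uint8_t n_vecs,
          bool enable_sg_loadstore>
using SincosContigFunctor =
    ew_cmn::UnaryTwoOutputsContigFunctor<argT, resT1, resT2, SincosOp<argT>,
                                         vec_sz, n_vecs, enable_sg_loadstore>;

// A dispatch-table slot can then hold the instantiated implementation.
static ew_cmn::unary_two_outputs_contig_impl_fn_ptr_t sincos_contig_fn =
    ew_cmn::unary_two_outputs_contig_impl<float, SincosOutputType,
                                          SincosContigFunctor,
                                          sincos_contig_kernel>;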