DPNP C++ backend kernel library 0.20.0dev4
Data Parallel Extension for NumPy*
common.hpp
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
// may be used to endorse or promote products derived from this software
// without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

#include <sycl/sycl.hpp>

// dpctl tensor headers
#include "kernels/alignment.hpp"
#include "kernels/elementwise_functions/common.hpp"
#include "utils/sycl_utils.hpp"

namespace dpnp::extensions::py_internal::elementwise_common
{
using dpctl::tensor::kernels::alignment_utils::
    disabled_sg_loadstore_wrapper_krn;
using dpctl::tensor::kernels::alignment_utils::is_aligned;
using dpctl::tensor::kernels::alignment_utils::required_alignment;

using dpctl::tensor::kernels::elementwise_common::select_lws;

using dpctl::tensor::sycl_utils::sub_group_load;
using dpctl::tensor::sycl_utils::sub_group_store;

/*!
 * @brief Functor for evaluation of a unary function with two output arrays
 * on contiguous arrays.
 */
template <typename argT,
          typename resT1,
          typename resT2,
          typename UnaryTwoOutputsOpT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
struct UnaryTwoOutputsContigFunctor
{
private:
    const argT *in = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    std::size_t nelems_;

public:
    UnaryTwoOutputsContigFunctor(const argT *inp,
                                 resT1 *res1,
                                 resT2 *res2,
                                 const std::size_t n_elems)
        : in(inp), out1(res1), out2(res2), nelems_(n_elems)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
        UnaryTwoOutputsOpT op{};
        /* Each work-item processes elems_per_wi = n_vecs * vec_sz elements,
           contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */

        if constexpr (enable_sg_loadstore &&
                      UnaryTwoOutputsOpT::is_constant::value) {
            // the results of the operator are known compile-time constants
            constexpr resT1 const_val1 = UnaryTwoOutputsOpT::constant_value1;
            constexpr resT2 const_val2 = UnaryTwoOutputsOpT::constant_value2;

            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);
            if (base + elems_per_wi * sgSize < nelems_) {
                static constexpr sycl::vec<resT1, vec_sz> res1_vec(const_val1);
                static constexpr sycl::vec<resT2, vec_sz> res2_vec(const_val2);
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = const_val1;
                    out2[k] = const_val2;
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
                           UnaryTwoOutputsOpT::supports_vec::value &&
                           (vec_sz > 1)) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);
            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT, vec_sz> x =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT2, vec_sz> res2_vec = {};
                    const sycl::vec<resT1, vec_sz> res1_vec = op(x, res2_vec);
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    // scalar call
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
                           std::is_same_v<resT1, argT>) {
            // apply the scalar function per vector lane, reusing the loaded
            // vector for the first result (possible since resT1 == argT)

            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    sycl::vec<argT, vec_sz> arg_vec =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT2, vec_sz> res2_vec = {};
#pragma unroll
                    for (std::uint32_t k = 0; k < vec_sz; ++k) {
                        arg_vec[k] = op(arg_vec[k], res2_vec[k]);
                    }
                    sub_group_store<vec_sz>(sg, arg_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value) {
            // default: use scalar-value function

            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT, vec_sz> arg_vec =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT1, vec_sz> res1_vec = {};
                    sycl::vec<resT2, vec_sz> res2_vec = {};
#pragma unroll
                    for (std::uint8_t k = 0; k < vec_sz; ++k) {
                        res1_vec[k] = op(arg_vec[k], res2_vec[k]);
                    }
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else {
            const std::uint16_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                out1[offset] = op(in[offset], out2[offset]);
            }
        }
    }
};
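
// Illustrative sketch (not part of the original header): a modf-style
// operator type showing the static interface UnaryTwoOutputsContigFunctor
// expects from UnaryTwoOutputsOpT. The scalar form returns the first result
// and writes the second through the reference parameter; the vector form is
// used on the sub-group fast path. ModfOutputsOp is a hypothetical name, T
// is assumed to be a real floating-point type, and std::true_type /
// std::false_type come from <type_traits>.
template <typename T>
struct ModfOutputsOp
{
    // results are computed per element, not compile-time constants; a
    // constant-valued operator would instead define is_constant as
    // std::true_type and expose static constexpr members constant_value1
    // and constant_value2
    using is_constant = typename std::false_type;
    using supports_sg_loadstore = typename std::true_type;
    using supports_vec = typename std::true_type;

    // scalar form: fractional part is returned, integral part goes to res2
    T operator()(const T &x, T &res2) const
    {
        res2 = sycl::trunc(x); // integral part
        return x - res2;       // fractional part
    }

    // vector form operating on vec_sz lanes at once
    template <int vec_sz>
    sycl::vec<T, vec_sz> operator()(const sycl::vec<T, vec_sz> &x,
                                    sycl::vec<T, vec_sz> &res2) const
    {
        res2 = sycl::trunc(x);
        return x - res2;
    }
};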

/*!
 * @brief Functor for evaluation of a unary function with two output arrays
 * on strided data.
 */
template <typename argT,
          typename resT1,
          typename resT2,
          typename IndexerT,
          typename UnaryTwoOutputsOpT>
struct UnaryTwoOutputsStridedFunctor
{
private:
    const argT *inp_ = nullptr;
    resT1 *res1_ = nullptr;
    resT2 *res2_ = nullptr;
    IndexerT inp_out_indexer_;

public:
    UnaryTwoOutputsStridedFunctor(const argT *inp_p,
                                  resT1 *res1_p,
                                  resT2 *res2_p,
                                  const IndexerT &inp_out_indexer)
        : inp_(inp_p), res1_(res1_p), res2_(res2_p),
          inp_out_indexer_(inp_out_indexer)
    {
    }

    void operator()(sycl::id<1> wid) const
    {
        const auto &offsets_ = inp_out_indexer_(wid.get(0));
        const ssize_t &inp_offset = offsets_.get_first_offset();
        const ssize_t &res1_offset = offsets_.get_second_offset();
        const ssize_t &res2_offset = offsets_.get_third_offset();

        UnaryTwoOutputsOpT op{};

        res1_[res1_offset] = op(inp_[inp_offset], res2_[res2_offset]);
    }
};
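
// Illustrative sketch (not part of the original header): the only contract
// UnaryTwoOutputsStridedFunctor places on IndexerT is a call operator that
// maps a flat index to an object exposing get_first_offset(),
// get_second_offset() and get_third_offset(). In production code dpctl's
// ThreeOffsets_StridedIndexer fills this role; a trivial stand-in for data
// that is actually contiguous could look like this (hypothetical names).
struct TrivialThreeOffsets
{
    ssize_t offset;
    ssize_t get_first_offset() const { return offset; }
    ssize_t get_second_offset() const { return offset; }
    ssize_t get_third_offset() const { return offset; }
};

struct TrivialThreeOffsetsIndexer
{
    TrivialThreeOffsets operator()(ssize_t gid) const
    {
        // identical linear offset into the input and both outputs
        return TrivialThreeOffsets{gid};
    }
};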

/*!
 * @brief Functor for evaluation of a binary function with two output arrays
 * on contiguous arrays.
 */
template <typename argT1,
          typename argT2,
          typename resT1,
          typename resT2,
          typename BinaryOperatorT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
struct BinaryTwoOutputsContigFunctor
{
private:
    const argT1 *in1 = nullptr;
    const argT2 *in2 = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    std::size_t nelems_;

public:
    BinaryTwoOutputsContigFunctor(const argT1 *inp1,
                                  const argT2 *inp2,
                                  resT1 *res1,
                                  resT2 *res2,
                                  std::size_t n_elems)
        : in1(inp1), in2(inp2), out1(res1), out2(res2), nelems_(n_elems)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
        BinaryOperatorT op{};
        /* Each work-item processes elems_per_wi = n_vecs * vec_sz elements,
           contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */

        if constexpr (enable_sg_loadstore &&
                      BinaryOperatorT::supports_sg_loadstore::value &&
                      BinaryOperatorT::supports_vec::value && (vec_sz > 1)) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
                sycl::vec<resT1, vec_sz> res1_vec;
                sycl::vec<resT2, vec_sz> res2_vec;

#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    std::size_t offset = base + it * sgSize;
                    auto in1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in1[offset]);
                    auto in2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in2[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT1, vec_sz> arg1_vec =
                        sub_group_load<vec_sz>(sg, in1_multi_ptr);
                    const sycl::vec<argT2, vec_sz> arg2_vec =
                        sub_group_load<vec_sz>(sg, in2_multi_ptr);
                    res1_vec = op(arg1_vec, arg2_vec, res2_vec);
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in1[k], in2[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           BinaryOperatorT::supports_sg_loadstore::value) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in1[offset]);
                    auto in2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in2[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT1, vec_sz> arg1_vec =
                        sub_group_load<vec_sz>(sg, in1_multi_ptr);
                    const sycl::vec<argT2, vec_sz> arg2_vec =
                        sub_group_load<vec_sz>(sg, in2_multi_ptr);

                    sycl::vec<resT1, vec_sz> res1_vec;
                    sycl::vec<resT2, vec_sz> res2_vec;
#pragma unroll
                    for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
                        res1_vec[vec_id] =
                            op(arg1_vec[vec_id], arg2_vec[vec_id],
                               res2_vec[vec_id]);
                    }
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in1[k], in2[k], out2[k]);
                }
            }
        }
        else {
            const std::size_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::size_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                out1[offset] = op(in1[offset], in2[offset], out2[offset]);
            }
        }
    }
};
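
// Illustrative sketch (not part of the original header): a divmod-style
// operator type showing the interface BinaryTwoOutputsContigFunctor expects
// from BinaryOperatorT: the first result is returned, the second is written
// through the trailing reference parameter. With supports_vec set to
// std::false_type the functor applies the scalar call per vector lane, so no
// vector overload is needed. DivmodOutputsOp is a hypothetical name and T is
// assumed to be a real floating-point type.
template <typename T>
struct DivmodOutputsOp
{
    using supports_sg_loadstore = typename std::true_type;
    using supports_vec = typename std::false_type;

    T operator()(const T &x, const T &y, T &rem) const
    {
        const T q = sycl::floor(x / y); // floor quotient
        rem = x - q * y;                // remainder matching the quotient
        return q;
    }
};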

/*!
 * @brief Functor for evaluation of a binary function with two output arrays
 * on strided data.
 */
template <typename argT1,
          typename argT2,
          typename resT1,
          typename resT2,
          typename FourOffsets_IndexerT,
          typename BinaryOperatorT>
struct BinaryTwoOutputsStridedFunctor
{
private:
    const argT1 *in1 = nullptr;
    const argT2 *in2 = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    FourOffsets_IndexerT four_offsets_indexer_;

public:
    BinaryTwoOutputsStridedFunctor(const argT1 *inp1_tp,
                                   const argT2 *inp2_tp,
                                   resT1 *res1_tp,
                                   resT2 *res2_tp,
                                   const FourOffsets_IndexerT &inps_res_indexer)
        : in1(inp1_tp), in2(inp2_tp), out1(res1_tp), out2(res2_tp),
          four_offsets_indexer_(inps_res_indexer)
    {
    }

    void operator()(sycl::id<1> wid) const
    {
        const auto &four_offsets_ =
            four_offsets_indexer_(static_cast<ssize_t>(wid.get(0)));

        const auto &inp1_offset = four_offsets_.get_first_offset();
        const auto &inp2_offset = four_offsets_.get_second_offset();
        const auto &out1_offset = four_offsets_.get_third_offset();
        const auto &out2_offset = four_offsets_.get_fourth_offset();

        BinaryOperatorT op{};
        out1[out1_offset] =
            op(in1[inp1_offset], in2[inp2_offset], out2[out2_offset]);
    }
};

/*!
 * @brief Submits a SYCL kernel that evaluates a unary elementwise function
 * producing two output arrays over contiguous data.
 */
template <typename argTy,
          template <typename T> class UnaryTwoOutputsType,
          template <typename A,
                    typename R1,
                    typename R2,
                    std::uint8_t vs,
                    std::uint8_t nv,
                    bool enable> class UnaryTwoOutputsContigFunctorT,
          template <typename A,
                    typename R1,
                    typename R2,
                    std::uint8_t vs,
                    std::uint8_t nv> class kernel_name,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u>
sycl::event
    unary_two_outputs_contig_impl(sycl::queue &exec_q,
                                  std::size_t nelems,
                                  const char *arg_p,
                                  char *res1_p,
                                  char *res2_p,
                                  const std::vector<sycl::event> &depends = {})
{
    static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
    const std::size_t n_work_items_needed = nelems / elems_per_wi;
    const std::size_t lws =
        select_lws(exec_q.get_device(), n_work_items_needed);

    const std::size_t n_groups =
        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    using resTy1 = typename UnaryTwoOutputsType<argTy>::value_type1;
    using resTy2 = typename UnaryTwoOutputsType<argTy>::value_type2;
    using BaseKernelName = kernel_name<argTy, resTy1, resTy2, vec_sz, n_vecs>;

    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
    resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p);
    resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p);

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        if (is_aligned<required_alignment>(arg_p) &&
            is_aligned<required_alignment>(res1_p) &&
            is_aligned<required_alignment>(res2_p)) {
            static constexpr bool enable_sg_loadstore = true;
            using KernelName = BaseKernelName;
            using Impl =
                UnaryTwoOutputsContigFunctorT<argTy, resTy1, resTy2, vec_sz,
                                              n_vecs, enable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg_tp, res1_tp, res2_tp, nelems));
        }
        else {
            static constexpr bool disable_sg_loadstore = false;
            using KernelName =
                disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
            using Impl =
                UnaryTwoOutputsContigFunctorT<argTy, resTy1, resTy2, vec_sz,
                                              n_vecs, disable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg_tp, res1_tp, res2_tp, nelems));
        }
    });

    return comp_ev;
}
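
// Illustrative sketch (not part of the original header): how a concrete
// function wires an operator into the generic functor and instantiates the
// contiguous impl. ModfOutputsType, ModfContigFunctor and modf_contig_kernel
// are hypothetical names; ModfOutputsOp refers to the sketch further above.
// Since UnaryTwoOutputsContigFunctor also takes the operator type, callers
// pass an alias template that binds it, matching the template-template
// parameter of unary_two_outputs_contig_impl.
template <typename T>
struct ModfOutputsType
{
    using value_type1 = T; // fractional parts
    using value_type2 = T; // integral parts
};

template <typename argTy,
          typename resTy1,
          typename resTy2,
          std::uint8_t vec_sz,
          std::uint8_t n_vecs,
          bool enable_sg_loadstore>
using ModfContigFunctor =
    UnaryTwoOutputsContigFunctor<argTy, resTy1, resTy2, ModfOutputsOp<argTy>,
                                 vec_sz, n_vecs, enable_sg_loadstore>;

template <typename A,
          typename R1,
          typename R2,
          std::uint8_t vs,
          std::uint8_t nv>
class modf_contig_kernel;

// usage (given a queue q, element count n and raw char* pointers):
//   sycl::event ev = unary_two_outputs_contig_impl<float, ModfOutputsType,
//       ModfContigFunctor, modf_contig_kernel>(q, n, arg_p, res1_p, res2_p);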

/*!
 * @brief Submits a SYCL kernel that evaluates a unary elementwise function
 * producing two output arrays over strided data.
 */
template <typename argTy,
          template <typename T> class UnaryTwoOutputsType,
          template <typename A,
                    typename R1,
                    typename R2,
                    typename I> class UnaryTwoOutputsStridedFunctorT,
          template <typename A,
                    typename R1,
                    typename R2,
                    typename I> class kernel_name>
sycl::event unary_two_outputs_strided_impl(
    sycl::queue &exec_q,
    std::size_t nelems,
    int nd,
    const ssize_t *shape_and_strides,
    const char *arg_p,
    ssize_t arg_offset,
    char *res1_p,
    ssize_t res1_offset,
    char *res2_p,
    ssize_t res2_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.depends_on(additional_depends);

        using res1Ty = typename UnaryTwoOutputsType<argTy>::value_type1;
        using res2Ty = typename UnaryTwoOutputsType<argTy>::value_type2;
        using IndexerT =
            typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;

        const IndexerT indexer{nd, arg_offset, res1_offset, res2_offset,
                               shape_and_strides};

        const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
        res1Ty *res1_tp = reinterpret_cast<res1Ty *>(res1_p);
        res2Ty *res2_tp = reinterpret_cast<res2Ty *>(res2_p);

        using Impl =
            UnaryTwoOutputsStridedFunctorT<argTy, res1Ty, res2Ty, IndexerT>;

        cgh.parallel_for<kernel_name<argTy, res1Ty, res2Ty, IndexerT>>(
            {nelems}, Impl(arg_tp, res1_tp, res2_tp, indexer));
    });
    return comp_ev;
}
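
// Note (added for clarity; layout as documented for dpctl's
// ThreeOffsets_StridedIndexer in dpctl's utils/offset_utils.hpp): for an
// nd-dimensional problem, shape_and_strides is expected to pack 4*nd
// device-accessible ssize_t values,
//
//   [ common_shape[0..nd-1] | arg_strides[0..nd-1] |
//     res1_strides[0..nd-1] | res2_strides[0..nd-1] ],
//
// while arg_offset/res1_offset/res2_offset carry each array's base offset
// in elements.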

/*!
 * @brief Submits a SYCL kernel that evaluates a binary elementwise function
 * producing two output arrays over contiguous data.
 */
template <
    typename argTy1,
    typename argTy2,
    template <typename T1, typename T2> class BinaryTwoOutputsType,
    template <typename T1,
              typename T2,
              typename T3,
              typename T4,
              std::uint8_t vs,
              std::uint8_t nv,
              bool enable_sg_loadstore> class BinaryTwoOutputsContigFunctorT,
    template <typename T1,
              typename T2,
              typename T3,
              typename T4,
              std::uint8_t vs,
              std::uint8_t nv> class kernel_name,
    std::uint8_t vec_sz = 4u,
    std::uint8_t n_vecs = 2u>
sycl::event
    binary_two_outputs_contig_impl(sycl::queue &exec_q,
                                   std::size_t nelems,
                                   const char *arg1_p,
                                   ssize_t arg1_offset,
                                   const char *arg2_p,
                                   ssize_t arg2_offset,
                                   char *res1_p,
                                   ssize_t res1_offset,
                                   char *res2_p,
                                   ssize_t res2_offset,
                                   const std::vector<sycl::event> &depends = {})
{
    const std::size_t n_work_items_needed = nelems / (n_vecs * vec_sz);
    const std::size_t lws =
        select_lws(exec_q.get_device(), n_work_items_needed);

    const std::size_t n_groups =
        ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    using resTy1 = typename BinaryTwoOutputsType<argTy1, argTy2>::value_type1;
    using resTy2 = typename BinaryTwoOutputsType<argTy1, argTy2>::value_type2;
    using BaseKernelName =
        kernel_name<argTy1, argTy2, resTy1, resTy2, vec_sz, n_vecs>;

    const argTy1 *arg1_tp =
        reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
    const argTy2 *arg2_tp =
        reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
    resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p) + res1_offset;
    resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p) + res2_offset;

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        if (is_aligned<required_alignment>(arg1_tp) &&
            is_aligned<required_alignment>(arg2_tp) &&
            is_aligned<required_alignment>(res1_tp) &&
            is_aligned<required_alignment>(res2_tp)) {
            static constexpr bool enable_sg_loadstore = true;
            using KernelName = BaseKernelName;
            using Impl = BinaryTwoOutputsContigFunctorT<argTy1, argTy2, resTy1,
                                                        resTy2, vec_sz, n_vecs,
                                                        enable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, nelems));
        }
        else {
            static constexpr bool disable_sg_loadstore = false;
            using KernelName =
                disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
            using Impl = BinaryTwoOutputsContigFunctorT<argTy1, argTy2, resTy1,
                                                        resTy2, vec_sz, n_vecs,
                                                        disable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, nelems));
        }
    });
    return comp_ev;
}

/*!
 * @brief Submits a SYCL kernel that evaluates a binary elementwise function
 * producing two output arrays over strided data.
 */
template <typename argTy1,
          typename argTy2,
          template <typename T1, typename T2> class BinaryTwoOutputsType,
          template <typename T1,
                    typename T2,
                    typename T3,
                    typename T4,
                    typename IndT> class BinaryTwoOutputsStridedFunctorT,
          template <typename T1,
                    typename T2,
                    typename T3,
                    typename T4,
                    typename IndT> class kernel_name>
sycl::event binary_two_outputs_strided_impl(
    sycl::queue &exec_q,
    std::size_t nelems,
    int nd,
    const ssize_t *shape_and_strides,
    const char *arg1_p,
    ssize_t arg1_offset,
    const char *arg2_p,
    ssize_t arg2_offset,
    char *res1_p,
    ssize_t res1_offset,
    char *res2_p,
    ssize_t res2_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.depends_on(additional_depends);

        using resTy1 =
            typename BinaryTwoOutputsType<argTy1, argTy2>::value_type1;
        using resTy2 =
            typename BinaryTwoOutputsType<argTy1, argTy2>::value_type2;

        using IndexerT =
            typename dpctl::tensor::offset_utils::FourOffsets_StridedIndexer;

        const IndexerT indexer{nd, arg1_offset, arg2_offset,
                               res1_offset, res2_offset, shape_and_strides};

        const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
        const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
        resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p);
        resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p);

        using Impl = BinaryTwoOutputsStridedFunctorT<argTy1, argTy2, resTy1,
                                                     resTy2, IndexerT>;

        cgh.parallel_for<kernel_name<argTy1, argTy2, resTy1, resTy2, IndexerT>>(
            {nelems}, Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, indexer));
    });
    return comp_ev;
}
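
// Note (added for clarity; layout as documented for dpctl's
// FourOffsets_StridedIndexer): here shape_and_strides packs 5*nd
// device-accessible ssize_t values, the common shape followed by the strides
// of arg1, arg2, res1 and res2, in that order.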

// Typedefs for function pointers

typedef sycl::event (*unary_two_outputs_contig_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    const char *,
    char *,
    char *,
    const std::vector<sycl::event> &);

typedef sycl::event (*unary_two_outputs_strided_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    int,
    const ssize_t *,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &,
    const std::vector<sycl::event> &);

typedef sycl::event (*binary_two_outputs_contig_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    const char *,
    ssize_t,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &);

typedef sycl::event (*binary_two_outputs_strided_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    int,
    const ssize_t *,
    const char *,
    ssize_t,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &,
    const std::vector<sycl::event> &);
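
// Illustrative sketch (not part of the original header): the typedefs above
// are typically used to build per-dtype dispatch tables when the Python
// extension is initialized. The names below (ModfContigFactory,
// modf_contig_dispatch_vector) are hypothetical and reuse the sketches
// further above; the builder utilities come from dpctl's type_dispatch
// machinery. Shown commented out, since the required aliases live outside
// this header.
//
// namespace td_ns = dpctl::tensor::type_dispatch;
//
// static unary_two_outputs_contig_impl_fn_ptr_t
//     modf_contig_dispatch_vector[td_ns::num_types];
//
// template <typename fnT, typename T>
// struct ModfContigFactory
// {
//     fnT get()
//     {
//         if constexpr (std::is_same_v<
//                           typename ModfOutputsType<T>::value_type1, void>) {
//             return nullptr; // dtype not supported
//         }
//         else {
//             return unary_two_outputs_contig_impl<T, ModfOutputsType,
//                                                  ModfContigFunctor,
//                                                  modf_contig_kernel>;
//         }
//     }
// };
//
// td_ns::DispatchVectorBuilder<unary_two_outputs_contig_impl_fn_ptr_t,
//                              ModfContigFactory, td_ns::num_types>
//     dvb;
// dvb.populate_dispatch_vector(modf_contig_dispatch_vector);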
} // namespace dpnp::extensions::py_internal::elementwise_common