DPNP C++ backend kernel library 0.20.0dev6
Data Parallel Extension for NumPy*
common.hpp
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
//   may be used to endorse or promote products derived from this software
//   without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits> // std::is_same_v (used below, missing from the listing)
#include <utility>
#include <vector>

#include <sycl/sycl.hpp>

// dpnp tensor headers
#include "kernels/alignment.hpp"
#include "kernels/elementwise_functions/common.hpp"
// NOTE: path assumed by analogy with the headers above; this header also uses
// the ThreeOffsets_/FourOffsets_StridedIndexer types declared there.
#include "utils/offset_utils.hpp"
#include "utils/sycl_utils.hpp"

namespace dpnp::extensions::py_internal::elementwise_common
{
using dpnp::tensor::kernels::alignment_utils::disabled_sg_loadstore_wrapper_krn;
using dpnp::tensor::kernels::alignment_utils::is_aligned;
using dpnp::tensor::kernels::alignment_utils::required_alignment;

using dpnp::tensor::kernels::elementwise_common::select_lws;

using dpnp::tensor::sycl_utils::sub_group_load;
using dpnp::tensor::sycl_utils::sub_group_store;

/*!
 * @brief Functor for evaluation of a unary function with two output arrays on
 * contiguous arrays.
 */
template <typename argT,
          typename resT1,
          typename resT2,
          typename UnaryTwoOutputsOpT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
class UnaryTwoOutputsContigFunctor
{
private:
    const argT *in = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    std::size_t nelems_;

public:
    UnaryTwoOutputsContigFunctor(const argT *inp,
                                 resT1 *res1,
                                 resT2 *res2,
                                 const std::size_t n_elems)
        : in(inp), out1(res1), out2(res2), nelems_(n_elems)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
        UnaryTwoOutputsOpT op{};
        /* Each work-item processes vec_sz elements, contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */

        if constexpr (enable_sg_loadstore &&
                      UnaryTwoOutputsOpT::is_constant::value) {
            // the value of the operator is a known constant
            constexpr resT1 const_val1 = UnaryTwoOutputsOpT::constant_value1;
            constexpr resT2 const_val2 = UnaryTwoOutputsOpT::constant_value2;

            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

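            // Each sub-group covers elems_per_wi * sgSize contiguous
            // elements starting at `base`; lanes cooperate on each
            // sycl::vec-sized chunk via sub-group stores.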
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);
            if (base + elems_per_wi * sgSize < nelems_) {
                static constexpr sycl::vec<resT1, vec_sz> res1_vec(const_val1);
                static constexpr sycl::vec<resT2, vec_sz> res2_vec(const_val2);
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = const_val1;
                    out2[k] = const_val2;
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
                           UnaryTwoOutputsOpT::supports_vec::value &&
                           (vec_sz > 1)) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);
            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT, vec_sz> x =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT2, vec_sz> res2_vec = {};
                    const sycl::vec<resT1, vec_sz> res1_vec = op(x, res2_vec);
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    // scalar call
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
                           std::is_same_v<resT1, argT>) {
            // apply the scalar-value function lane-wise; since resT1 == argT,
            // arg_vec is reused to hold the first result

            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    sycl::vec<argT, vec_sz> arg_vec =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT2, vec_sz> res2_vec = {};
#pragma unroll
                    for (std::uint32_t k = 0; k < vec_sz; ++k) {
                        arg_vec[k] = op(arg_vec[k], res2_vec[k]);
                    }
                    sub_group_store<vec_sz>(sg, arg_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value) {
            // default: apply the scalar-value function lane-wise

            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT, vec_sz> arg_vec =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT1, vec_sz> res1_vec = {};
                    sycl::vec<resT2, vec_sz> res2_vec = {};
#pragma unroll
                    for (std::uint8_t k = 0; k < vec_sz; ++k) {
                        res1_vec[k] = op(arg_vec[k], res2_vec[k]);
                    }
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else {
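            // Tail path without sub-group load/store: work-item `gid` belongs
            // to sub-group gid / sgSize; each sub-group covers a contiguous
            // chunk of elems_per_sg elements and its lanes stride through the
            // chunk by sgSize, so start = (gid / sgSize) * elems_per_sg +
            // (gid % sgSize).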
            const std::uint16_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                out1[offset] = op(in[offset], out2[offset]);
            }
        }
    }
};
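
// Illustrative sketch, not part of the original header: a minimal operation
// type satisfying the UnaryTwoOutputsOpT contract the functor above branches
// on. The trait names (is_constant, supports_vec, supports_sg_loadstore) are
// taken from their uses in operator(); the struct itself, including its name,
// is hypothetical. When supports_vec is std::true_type, an additional
// sycl::vec overload of operator() is expected; when is_constant is
// std::true_type, static constant_value1/constant_value2 members are expected
// instead.
template <typename argT, typename resT1, typename resT2>
struct ExampleDivmod10Op // hypothetical example only
{
    using is_constant = std::false_type;
    using supports_vec = std::false_type;
    using supports_sg_loadstore = std::true_type;

    // the first result is returned, the second is written through `rem`
    resT1 operator()(const argT &x, resT2 &rem) const
    {
        rem = static_cast<resT2>(x % argT(10));
        return static_cast<resT1>(x / argT(10));
    }
};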

/*!
 * @brief Functor for evaluation of a unary function with two output arrays on
 * strided data.
 */
template <typename argT,
          typename resT1,
          typename resT2,
          typename IndexerT,
          typename UnaryTwoOutputsOpT>
class UnaryTwoOutputsStridedFunctor
{
private:
    const argT *inp_ = nullptr;
    resT1 *res1_ = nullptr;
    resT2 *res2_ = nullptr;
    IndexerT inp_out_indexer_;

public:
    UnaryTwoOutputsStridedFunctor(const argT *inp_p,
                                  resT1 *res1_p,
                                  resT2 *res2_p,
                                  const IndexerT &inp_out_indexer)
        : inp_(inp_p), res1_(res1_p), res2_(res2_p),
          inp_out_indexer_(inp_out_indexer)
    {
    }

    void operator()(sycl::id<1> wid) const
    {
        const auto &offsets_ = inp_out_indexer_(wid.get(0));
        const ssize_t &inp_offset = offsets_.get_first_offset();
        const ssize_t &res1_offset = offsets_.get_second_offset();
        const ssize_t &res2_offset = offsets_.get_third_offset();

        UnaryTwoOutputsOpT op{};

        res1_[res1_offset] = op(inp_[inp_offset], res2_[res2_offset]);
    }
};

/*!
 * @brief Functor for evaluation of a binary function with two output arrays on
 * contiguous arrays.
 */
template <typename argT1,
          typename argT2,
          typename resT1,
          typename resT2,
          typename BinaryOperatorT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
class BinaryTwoOutputsContigFunctor
{
private:
    const argT1 *in1 = nullptr;
    const argT2 *in2 = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    std::size_t nelems_;

public:
    BinaryTwoOutputsContigFunctor(const argT1 *inp1,
                                  const argT2 *inp2,
                                  resT1 *res1,
                                  resT2 *res2,
                                  std::size_t n_elems)
        : in1(inp1), in2(inp2), out1(res1), out2(res2), nelems_(n_elems)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
        BinaryOperatorT op{};
        /* Each work-item processes vec_sz elements, contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */

        if constexpr (enable_sg_loadstore &&
                      BinaryOperatorT::supports_sg_loadstore::value &&
                      BinaryOperatorT::supports_vec::value && (vec_sz > 1)) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
                sycl::vec<resT1, vec_sz> res1_vec;
                sycl::vec<resT2, vec_sz> res2_vec;

#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in1[offset]);
                    auto in2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in2[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT1, vec_sz> arg1_vec =
                        sub_group_load<vec_sz>(sg, in1_multi_ptr);
                    const sycl::vec<argT2, vec_sz> arg2_vec =
                        sub_group_load<vec_sz>(sg, in2_multi_ptr);
                    res1_vec = op(arg1_vec, arg2_vec, res2_vec);
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in1[k], in2[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           BinaryOperatorT::supports_sg_loadstore::value) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in1[offset]);
                    auto in2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in2[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT1, vec_sz> arg1_vec =
                        sub_group_load<vec_sz>(sg, in1_multi_ptr);
                    const sycl::vec<argT2, vec_sz> arg2_vec =
                        sub_group_load<vec_sz>(sg, in2_multi_ptr);

                    sycl::vec<resT1, vec_sz> res1_vec;
                    sycl::vec<resT2, vec_sz> res2_vec;
#pragma unroll
                    for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
                        res1_vec[vec_id] = op(arg1_vec[vec_id],
                                              arg2_vec[vec_id],
                                              res2_vec[vec_id]);
                    }
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in1[k], in2[k], out2[k]);
                }
            }
        }
        else {
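            // Tail path without sub-group load/store; same indexing scheme as
            // in UnaryTwoOutputsContigFunctor above: each sub-group covers a
            // contiguous chunk of elems_per_sg elements, lanes striding
            // through the chunk by sgSize.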
            const std::size_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::size_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                out1[offset] = op(in1[offset], in2[offset], out2[offset]);
            }
        }
    }
};
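
// Illustrative sketch, not part of the original header: the BinaryOperatorT
// contract mirrors the unary one above, except that operator() takes two
// inputs and writes the second result through its last parameter. All names
// below are hypothetical.
template <typename argT1, typename argT2, typename resT1, typename resT2>
struct ExampleSumDiffOp // hypothetical example only
{
    using supports_vec = std::false_type;
    using supports_sg_loadstore = std::true_type;

    // returns x + y and writes x - y through `diff`
    resT1 operator()(const argT1 &x, const argT2 &y, resT2 &diff) const
    {
        diff = static_cast<resT2>(x - y);
        return static_cast<resT1>(x + y);
    }
};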

/*!
 * @brief Functor for evaluation of a binary function with two output arrays on
 * strided data.
 */
template <typename argT1,
          typename argT2,
          typename resT1,
          typename resT2,
          typename FourOffsets_IndexerT,
          typename BinaryOperatorT>
class BinaryTwoOutputsStridedFunctor
{
private:
    const argT1 *in1 = nullptr;
    const argT2 *in2 = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    FourOffsets_IndexerT four_offsets_indexer_;

public:
    BinaryTwoOutputsStridedFunctor(const argT1 *inp1_tp,
                                   const argT2 *inp2_tp,
                                   resT1 *res1_tp,
                                   resT2 *res2_tp,
                                   const FourOffsets_IndexerT &inps_res_indexer)
        : in1(inp1_tp), in2(inp2_tp), out1(res1_tp), out2(res2_tp),
          four_offsets_indexer_(inps_res_indexer)
    {
    }

    void operator()(sycl::id<1> wid) const
    {
        const auto &four_offsets_ =
            four_offsets_indexer_(static_cast<ssize_t>(wid.get(0)));

        const auto &inp1_offset = four_offsets_.get_first_offset();
        const auto &inp2_offset = four_offsets_.get_second_offset();
        const auto &out1_offset = four_offsets_.get_third_offset();
        const auto &out2_offset = four_offsets_.get_fourth_offset();

        BinaryOperatorT op{};
        out1[out1_offset] =
            op(in1[inp1_offset], in2[inp2_offset], out2[out2_offset]);
    }
};

/*!
 * @brief Submits a SYCL kernel evaluating a unary function with two output
 * arrays over contiguous arrays.
 */
template <typename argTy,
          template <typename T> class UnaryTwoOutputsType,
          template <typename A,
                    typename R1,
                    typename R2,
                    std::uint8_t vs,
                    std::uint8_t nv,
                    bool enable> class UnaryTwoOutputsContigFunctorT,
          template <typename A,
                    typename R1,
                    typename R2,
                    std::uint8_t vs,
                    std::uint8_t nv> class kernel_name,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u>
sycl::event
    unary_two_outputs_contig_impl(sycl::queue &exec_q,
                                  std::size_t nelems,
                                  const char *arg_p,
                                  char *res1_p,
                                  char *res2_p,
                                  const std::vector<sycl::event> &depends = {})
{
    static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
    const std::size_t n_work_items_needed = nelems / elems_per_wi;
    const std::size_t lws =
        select_lws(exec_q.get_device(), n_work_items_needed);

    const std::size_t n_groups =
        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    using resTy1 = typename UnaryTwoOutputsType<argTy>::value_type1;
    using resTy2 = typename UnaryTwoOutputsType<argTy>::value_type2;
    using BaseKernelName = kernel_name<argTy, resTy1, resTy2, vec_sz, n_vecs>;

    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
    resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p);
    resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p);

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        if (is_aligned<required_alignment>(arg_p) &&
            is_aligned<required_alignment>(res1_p) &&
            is_aligned<required_alignment>(res2_p)) {
            static constexpr bool enable_sg_loadstore = true;
            using KernelName = BaseKernelName;
            using Impl =
                UnaryTwoOutputsContigFunctorT<argTy, resTy1, resTy2, vec_sz,
                                              n_vecs, enable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg_tp, res1_tp, res2_tp, nelems));
        }
        else {
            static constexpr bool disable_sg_loadstore = false;
            using KernelName =
                disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
            using Impl =
                UnaryTwoOutputsContigFunctorT<argTy, resTy1, resTy2, vec_sz,
                                              n_vecs, disable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg_tp, res1_tp, res2_tp, nelems));
        }
    });

    return comp_ev;
}
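
// Illustrative sketch, not part of the original header: wiring a concrete
// operation through unary_two_outputs_contig_impl. Every name in this block
// is hypothetical. Since UnaryTwoOutputsContigFunctor carries the operation
// type as an extra template parameter, a caller binds it first with an alias
// template matching the UnaryTwoOutputsContigFunctorT signature.
namespace example_sketch
{
template <typename T>
struct ExampleOutputType
{
    using value_type1 = T; // type of the first output (quotient)
    using value_type2 = T; // type of the second output (remainder)
};

template <typename A,
          typename R1,
          typename R2,
          std::uint8_t vs,
          std::uint8_t nv,
          bool enable>
using ExampleContigFunctor =
    UnaryTwoOutputsContigFunctor<A, R1, R2, ExampleDivmod10Op<A, R1, R2>, vs,
                                 nv, enable>;

template <typename A,
          typename R1,
          typename R2,
          std::uint8_t vs,
          std::uint8_t nv>
class example_divmod10_contig_kernel;

inline sycl::event
    example_divmod10_contig_call(sycl::queue &q,
                                 std::size_t nelems,
                                 const char *arg_p,
                                 char *res1_p,
                                 char *res2_p,
                                 const std::vector<sycl::event> &depends = {})
{
    // dispatches to the aligned or the sg-loadstore-disabled kernel variant
    return unary_two_outputs_contig_impl<int, ExampleOutputType,
                                         ExampleContigFunctor,
                                         example_divmod10_contig_kernel>(
        q, nelems, arg_p, res1_p, res2_p, depends);
}
} // namespace example_sketch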

/*!
 * @brief Submits a SYCL kernel evaluating a unary function with two output
 * arrays over strided data.
 */
template <typename argTy,
          template <typename T> class UnaryTwoOutputsType,
          template <typename A,
                    typename R1,
                    typename R2,
                    typename I> class UnaryTwoOutputsStridedFunctorT,
          template <typename A,
                    typename R1,
                    typename R2,
                    typename I> class kernel_name>
sycl::event unary_two_outputs_strided_impl(
    sycl::queue &exec_q,
    std::size_t nelems,
    int nd,
    const ssize_t *shape_and_strides,
    const char *arg_p,
    ssize_t arg_offset,
    char *res1_p,
    ssize_t res1_offset,
    char *res2_p,
    ssize_t res2_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.depends_on(additional_depends);

        using res1Ty = typename UnaryTwoOutputsType<argTy>::value_type1;
        using res2Ty = typename UnaryTwoOutputsType<argTy>::value_type2;
        using IndexerT =
            typename dpnp::tensor::offset_utils::ThreeOffsets_StridedIndexer;

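        // Assumed layout of shape_and_strides (dpctl convention): nd shape
        // entries, followed by nd strides for the input and nd strides for
        // each of the two outputs.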
        const IndexerT indexer{nd, arg_offset, res1_offset, res2_offset,
                               shape_and_strides};

        const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
        res1Ty *res1_tp = reinterpret_cast<res1Ty *>(res1_p);
        res2Ty *res2_tp = reinterpret_cast<res2Ty *>(res2_p);

        using Impl =
            UnaryTwoOutputsStridedFunctorT<argTy, res1Ty, res2Ty, IndexerT>;

        cgh.parallel_for<kernel_name<argTy, res1Ty, res2Ty, IndexerT>>(
            {nelems}, Impl(arg_tp, res1_tp, res2_tp, indexer));
    });
    return comp_ev;
}

/*!
 * @brief Submits a SYCL kernel evaluating a binary function with two output
 * arrays over contiguous arrays.
 */
template <
    typename argTy1,
    typename argTy2,
    template <typename T1, typename T2> class BinaryTwoOutputsType,
    template <typename T1,
              typename T2,
              typename T3,
              typename T4,
              std::uint8_t vs,
              std::uint8_t nv,
              bool enable_sg_loadstore> class BinaryTwoOutputsContigFunctorT,
    template <typename T1,
              typename T2,
              typename T3,
              typename T4,
              std::uint8_t vs,
              std::uint8_t nv> class kernel_name,
    std::uint8_t vec_sz = 4u,
    std::uint8_t n_vecs = 2u>
sycl::event
    binary_two_outputs_contig_impl(sycl::queue &exec_q,
                                   std::size_t nelems,
                                   const char *arg1_p,
                                   ssize_t arg1_offset,
                                   const char *arg2_p,
                                   ssize_t arg2_offset,
                                   char *res1_p,
                                   ssize_t res1_offset,
                                   char *res2_p,
                                   ssize_t res2_offset,
                                   const std::vector<sycl::event> &depends = {})
{
    const std::size_t n_work_items_needed = nelems / (n_vecs * vec_sz);
    const std::size_t lws =
        select_lws(exec_q.get_device(), n_work_items_needed);

    const std::size_t n_groups =
        ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    using resTy1 = typename BinaryTwoOutputsType<argTy1, argTy2>::value_type1;
    using resTy2 = typename BinaryTwoOutputsType<argTy1, argTy2>::value_type2;
    using BaseKernelName =
        kernel_name<argTy1, argTy2, resTy1, resTy2, vec_sz, n_vecs>;

    const argTy1 *arg1_tp =
        reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
    const argTy2 *arg2_tp =
        reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
    resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p) + res1_offset;
    resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p) + res2_offset;

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        if (is_aligned<required_alignment>(arg1_tp) &&
            is_aligned<required_alignment>(arg2_tp) &&
            is_aligned<required_alignment>(res1_tp) &&
            is_aligned<required_alignment>(res2_tp)) {
            static constexpr bool enable_sg_loadstore = true;
            using KernelName = BaseKernelName;
            using Impl = BinaryTwoOutputsContigFunctorT<argTy1, argTy2, resTy1,
                                                        resTy2, vec_sz, n_vecs,
                                                        enable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, nelems));
        }
        else {
            static constexpr bool disable_sg_loadstore = false;
            using KernelName =
                disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
            using Impl = BinaryTwoOutputsContigFunctorT<argTy1, argTy2, resTy1,
                                                        resTy2, vec_sz, n_vecs,
                                                        disable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, nelems));
        }
    });
    return comp_ev;
}

/*!
 * @brief Submits a SYCL kernel evaluating a binary function with two output
 * arrays over strided data.
 */
template <typename argTy1,
          typename argTy2,
          template <typename T1, typename T2> class BinaryTwoOutputsType,
          template <typename T1,
                    typename T2,
                    typename T3,
                    typename T4,
                    typename IndT> class BinaryTwoOutputsStridedFunctorT,
          template <typename T1,
                    typename T2,
                    typename T3,
                    typename T4,
                    typename IndT> class kernel_name>
sycl::event binary_two_outputs_strided_impl(
    sycl::queue &exec_q,
    std::size_t nelems,
    int nd,
    const ssize_t *shape_and_strides,
    const char *arg1_p,
    ssize_t arg1_offset,
    const char *arg2_p,
    ssize_t arg2_offset,
    char *res1_p,
    ssize_t res1_offset,
    char *res2_p,
    ssize_t res2_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.depends_on(additional_depends);

        using resTy1 =
            typename BinaryTwoOutputsType<argTy1, argTy2>::value_type1;
        using resTy2 =
            typename BinaryTwoOutputsType<argTy1, argTy2>::value_type2;

        using IndexerT =
            typename dpnp::tensor::offset_utils::FourOffsets_StridedIndexer;

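        // Assumed layout of shape_and_strides (dpctl convention): nd shape
        // entries, followed by nd strides for each of the two inputs and the
        // two outputs, in argument order.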
        const IndexerT indexer{nd,          arg1_offset, arg2_offset,
                               res1_offset, res2_offset, shape_and_strides};

        const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
        const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
        resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p);
        resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p);

        using Impl = BinaryTwoOutputsStridedFunctorT<argTy1, argTy2, resTy1,
                                                     resTy2, IndexerT>;

        cgh.parallel_for<kernel_name<argTy1, argTy2, resTy1, resTy2, IndexerT>>(
            {nelems}, Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, indexer));
    });
    return comp_ev;
}

// Typedefs for function pointers

typedef sycl::event (*unary_two_outputs_contig_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    const char *,
    char *,
    char *,
    const std::vector<sycl::event> &);

typedef sycl::event (*unary_two_outputs_strided_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    int,
    const ssize_t *,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &,
    const std::vector<sycl::event> &);

typedef sycl::event (*binary_two_outputs_contig_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    const char *,
    ssize_t,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &);

typedef sycl::event (*binary_two_outputs_strided_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    int,
    const ssize_t *,
    const char *,
    ssize_t,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &,
    const std::vector<sycl::event> &);

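// Illustrative sketch, not part of the original header: the function-pointer
// typedefs above are typically used to populate per-dtype dispatch tables at
// extension initialization time. The table size and the type-to-index mapping
// below are hypothetical.
//
//   static unary_two_outputs_contig_impl_fn_ptr_t
//       divmod_contig_dispatch_vector[num_supported_types];
//
//   // populated once at module init and indexed by the input array's
//   // type id at call time:
//   sycl::event ev = divmod_contig_dispatch_vector[arg_typeid](
//       q, nelems, arg_p, res1_p, res2_p, depends);
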
} // namespace dpnp::extensions::py_internal::elementwise_common