#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

#include <sycl/sycl.hpp>

// dpctl tensor headers
#include "kernels/alignment.hpp"
#include "kernels/elementwise_functions/common.hpp"
#include "utils/offset_utils.hpp"
#include "utils/sycl_utils.hpp"
namespace dpnp::extensions::py_internal::elementwise_common
{

using dpctl::tensor::kernels::alignment_utils::
    disabled_sg_loadstore_wrapper_krn;
using dpctl::tensor::kernels::alignment_utils::is_aligned;
using dpctl::tensor::kernels::alignment_utils::required_alignment;

using dpctl::tensor::kernels::elementwise_common::select_lws;

using dpctl::tensor::sycl_utils::sub_group_load;
using dpctl::tensor::sycl_utils::sub_group_store;
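
/* Work partitioning used by the contiguous functors below: each work-item
   processes elems_per_wi = n_vecs * vec_sz elements, and a sub-group of
   sgSize work-items cooperatively covers a contiguous block of
   elems_per_wi * sgSize elements starting at

       base = elems_per_wi * (group_id * local_range + sg_id * sgSize).

   As a worked example (illustrative numbers, not from the source): with the
   defaults vec_sz = 4 and n_vecs = 2, a sub-group of 32 work-items covers
   8 * 32 = 256 contiguous elements; the inner loop performs
   elems_per_wi / vec_sz = 2 sub-group loads/stores, each moving
   vec_sz * sgSize = 4 * 32 = 128 contiguous elements. */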
/*! @brief Functor for evaluation of a unary function with two output arrays
    on contiguous arrays */
template <typename argT,
          typename resT1,
          typename resT2,
          typename UnaryTwoOutputsOpT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
struct UnaryTwoOutputsContigFunctor
{
private:
    const argT *in = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    std::size_t nelems_;

public:
    UnaryTwoOutputsContigFunctor(const argT *inp,
                                 resT1 *res1,
                                 resT2 *res2,
                                 const std::size_t n_elems)
        : in(inp), out1(res1), out2(res2), nelems_(n_elems)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
        UnaryTwoOutputsOpT op{};
        /* Each work-item processes vec_sz elements, contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */
        if constexpr (enable_sg_loadstore &&
                      UnaryTwoOutputsOpT::is_constant::value)
        {
            // the operation returns compile-time constant values,
            // so only stores into both outputs are needed
            constexpr resT1 const_val1 = UnaryTwoOutputsOpT::constant_value1;
            constexpr resT2 const_val2 = UnaryTwoOutputsOpT::constant_value2;

            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);
            if (base + elems_per_wi * sgSize < nelems_) {
                const sycl::vec<resT1, vec_sz> res1_vec(const_val1);
                const sycl::vec<resT2, vec_sz> res2_vec(const_val2);
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = const_val1;
                    out2[k] = const_val2;
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
                           UnaryTwoOutputsOpT::supports_vec::value &&
                           (vec_sz > 1))
        {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);
            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT, vec_sz> x =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT2, vec_sz> res2_vec = {};
                    const sycl::vec<resT1, vec_sz> res1_vec = op(x, res2_vec);
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    // scalar overload writes the second result through
                    // the reference argument
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
                           std::is_same_v<resT1, argT>)
        {
            // the first output has the same type as the argument, so the
            // loaded vector can be overwritten in place with the results
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    sycl::vec<argT, vec_sz> arg_vec =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT2, vec_sz> res2_vec = {};
#pragma unroll
                    for (std::uint8_t k = 0; k < vec_sz; ++k) {
                        arg_vec[k] = op(arg_vec[k], res2_vec[k]);
                    }
                    sub_group_store<vec_sz>(sg, arg_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           UnaryTwoOutputsOpT::supports_sg_loadstore::value)
        {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT, vec_sz> arg_vec =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    sycl::vec<resT1, vec_sz> res1_vec = {};
                    sycl::vec<resT2, vec_sz> res2_vec = {};
#pragma unroll
                    for (std::uint8_t k = 0; k < vec_sz; ++k) {
                        res1_vec[k] = op(arg_vec[k], res2_vec[k]);
                    }
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in[k], out2[k]);
                }
            }
        }
        else {
            const std::uint16_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                out1[offset] = op(in[offset], out2[offset]);
            }
        }
    }
};
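
/* Sketch of the operator interface the functor above expects. The name
   "ExampleOp" is hypothetical, used only for illustration; concrete
   operators (e.g. a frexp-like function returning mantissa and exponent)
   live with each elementwise function definition:

   template <typename argT, typename resT1, typename resT2>
   struct ExampleOp
   {
       // compile-time switches inspected via if-constexpr above
       using is_constant = typename std::false_type;
       using supports_sg_loadstore = typename std::true_type;
       using supports_vec = typename std::true_type;

       // required only when is_constant is std::true_type:
       //   static constexpr resT1 constant_value1 = ...;
       //   static constexpr resT2 constant_value2 = ...;

       // scalar overload: returns the first result, writes the second
       // result through the reference argument
       resT1 operator()(const argT &x, resT2 &res2) const;

       // vector overload; required only when supports_vec is true
       template <int vec_sz>
       sycl::vec<resT1, vec_sz>
           operator()(const sycl::vec<argT, vec_sz> &x,
                      sycl::vec<resT2, vec_sz> &res2) const;
   };
*/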
/*! @brief Functor for evaluation of a unary function with two output arrays
    on strided data */
template <typename argT,
          typename resT1,
          typename resT2,
          typename IndexerT,
          typename UnaryTwoOutputsOpT>
struct UnaryTwoOutputsStridedFunctor
{
private:
    const argT *inp_ = nullptr;
    resT1 *res1_ = nullptr;
    resT2 *res2_ = nullptr;
    IndexerT inp_out_indexer_;

public:
    UnaryTwoOutputsStridedFunctor(const argT *inp_p,
                                  resT1 *res1_p,
                                  resT2 *res2_p,
                                  const IndexerT &inp_out_indexer)
        : inp_(inp_p), res1_(res1_p), res2_(res2_p),
          inp_out_indexer_(inp_out_indexer)
    {
    }

    void operator()(sycl::id<1> wid) const
    {
        const auto &offsets_ = inp_out_indexer_(wid.get(0));
        const ssize_t &inp_offset = offsets_.get_first_offset();
        const ssize_t &res1_offset = offsets_.get_second_offset();
        const ssize_t &res2_offset = offsets_.get_third_offset();

        UnaryTwoOutputsOpT op{};
        res1_[res1_offset] = op(inp_[inp_offset], res2_[res2_offset]);
    }
};
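
/* The IndexerT above is expected to be callable with a linear element index
   and to return an offsets triple exposing get_first_offset(),
   get_second_offset() and get_third_offset(); the strided implementation
   below instantiates it with
   dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer. */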
/*! @brief Functor for evaluation of a binary function with two output arrays
    on contiguous arrays */
template <typename argT1,
          typename argT2,
          typename resT1,
          typename resT2,
          typename BinaryOperatorT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
struct BinaryTwoOutputsContigFunctor
{
private:
    const argT1 *in1 = nullptr;
    const argT2 *in2 = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    std::size_t nelems_;

public:
    BinaryTwoOutputsContigFunctor(const argT1 *inp1,
                                  const argT2 *inp2,
                                  resT1 *res1,
                                  resT2 *res2,
                                  const std::size_t n_elems)
        : in1(inp1), in2(inp2), out1(res1), out2(res2), nelems_(n_elems)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
        BinaryOperatorT op{};
        /* Each work-item processes vec_sz elements, contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */
        if constexpr (enable_sg_loadstore &&
                      BinaryOperatorT::supports_sg_loadstore::value &&
                      BinaryOperatorT::supports_vec::value && (vec_sz > 1))
        {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
                sycl::vec<resT1, vec_sz> res1_vec;
                sycl::vec<resT2, vec_sz> res2_vec;

#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in1[offset]);
                    auto in2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in2[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT1, vec_sz> arg1_vec =
                        sub_group_load<vec_sz>(sg, in1_multi_ptr);
                    const sycl::vec<argT2, vec_sz> arg2_vec =
                        sub_group_load<vec_sz>(sg, in2_multi_ptr);
                    res1_vec = op(arg1_vec, arg2_vec, res2_vec);
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in1[k], in2[k], out2[k]);
                }
            }
        }
        else if constexpr (enable_sg_loadstore &&
                           BinaryOperatorT::supports_sg_loadstore::value)
        {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in1[offset]);
                    auto in2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in2[offset]);
                    auto out1_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out1[offset]);
                    auto out2_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out2[offset]);

                    const sycl::vec<argT1, vec_sz> arg1_vec =
                        sub_group_load<vec_sz>(sg, in1_multi_ptr);
                    const sycl::vec<argT2, vec_sz> arg2_vec =
                        sub_group_load<vec_sz>(sg, in2_multi_ptr);

                    sycl::vec<resT1, vec_sz> res1_vec;
                    sycl::vec<resT2, vec_sz> res2_vec;
#pragma unroll
                    for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
                        res1_vec[vec_id] =
                            op(arg1_vec[vec_id], arg2_vec[vec_id],
                               res2_vec[vec_id]);
                    }
                    sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
                    sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out1[k] = op(in1[k], in2[k], out2[k]);
                }
            }
        }
        else {
            const std::size_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::size_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                out1[offset] = op(in1[offset], in2[offset], out2[offset]);
            }
        }
    }
};
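
/* Sketch of the BinaryOperatorT interface the functor above relies on
   (hypothetical "ExampleBinOp", for illustration only):

   template <typename argT1, typename argT2, typename resT1, typename resT2>
   struct ExampleBinOp
   {
       using supports_sg_loadstore = typename std::true_type;
       using supports_vec = typename std::true_type;

       // scalar overload: returns the first result, writes the second
       // result through the reference argument
       resT1 operator()(const argT1 &x1, const argT2 &x2, resT2 &res2) const;

       // vector overload; required only when supports_vec is true
       template <int vec_sz>
       sycl::vec<resT1, vec_sz>
           operator()(const sycl::vec<argT1, vec_sz> &x1,
                      const sycl::vec<argT2, vec_sz> &x2,
                      sycl::vec<resT2, vec_sz> &res2) const;
   };
*/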
/*! @brief Functor for evaluation of a binary function with two output arrays
    on strided data */
template <typename argT1,
          typename argT2,
          typename resT1,
          typename resT2,
          typename FourOffsets_IndexerT,
          typename BinaryOperatorT>
struct BinaryTwoOutputsStridedFunctor
{
private:
    const argT1 *in1 = nullptr;
    const argT2 *in2 = nullptr;
    resT1 *out1 = nullptr;
    resT2 *out2 = nullptr;
    FourOffsets_IndexerT four_offsets_indexer_;

public:
    BinaryTwoOutputsStridedFunctor(const argT1 *inp1_tp,
                                   const argT2 *inp2_tp,
                                   resT1 *res1_tp,
                                   resT2 *res2_tp,
                                   const FourOffsets_IndexerT &inps_res_indexer)
        : in1(inp1_tp), in2(inp2_tp), out1(res1_tp), out2(res2_tp),
          four_offsets_indexer_(inps_res_indexer)
    {
    }

    void operator()(sycl::id<1> wid) const
    {
        const auto &four_offsets_ =
            four_offsets_indexer_(static_cast<ssize_t>(wid.get(0)));

        const auto &inp1_offset = four_offsets_.get_first_offset();
        const auto &inp2_offset = four_offsets_.get_second_offset();
        const auto &out1_offset = four_offsets_.get_third_offset();
        const auto &out2_offset = four_offsets_.get_fourth_offset();

        BinaryOperatorT op{};
        out1[out1_offset] =
            op(in1[inp1_offset], in2[inp2_offset], out2[out2_offset]);
    }
};
template <typename argTy,
          template <typename T>
          class UnaryTwoOutputsType,
          template <typename A,
                    typename R1,
                    typename R2,
                    std::uint8_t vs,
                    std::uint8_t nv,
                    bool enable>
          class UnaryTwoOutputsContigFunctorT,
          template <typename A,
                    typename R1,
                    typename R2,
                    std::uint8_t vs,
                    std::uint8_t nv>
          class kernel_name,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u>
sycl::event
    unary_two_outputs_contig_impl(sycl::queue &exec_q,
                                  std::size_t nelems,
                                  const char *arg_p,
                                  char *res1_p,
                                  char *res2_p,
                                  const std::vector<sycl::event> &depends = {})
{
    static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
    const std::size_t n_work_items_needed = nelems / elems_per_wi;
    const std::size_t lws =
        select_lws(exec_q.get_device(), n_work_items_needed);

    const std::size_t n_groups =
        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    using resTy1 = typename UnaryTwoOutputsType<argTy>::value_type1;
    using resTy2 = typename UnaryTwoOutputsType<argTy>::value_type2;
    using BaseKernelName = kernel_name<argTy, resTy1, resTy2, vec_sz, n_vecs>;

    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
    resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p);
    resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p);

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        if (is_aligned<required_alignment>(arg_p) &&
            is_aligned<required_alignment>(res1_p) &&
            is_aligned<required_alignment>(res2_p))
        {
            static constexpr bool enable_sg_loadstore = true;
            using KernelName = BaseKernelName;
            using Impl =
                UnaryTwoOutputsContigFunctorT<argTy, resTy1, resTy2, vec_sz,
                                              n_vecs, enable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg_tp, res1_tp, res2_tp, nelems));
        }
        else {
            static constexpr bool disable_sg_loadstore = false;
            using KernelName =
                disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
            using Impl =
                UnaryTwoOutputsContigFunctorT<argTy, resTy1, resTy2, vec_sz,
                                              n_vecs, disable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg_tp, res1_tp, res2_tp, nelems));
        }
    });
    return comp_ev;
}
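
/* A minimal instantiation sketch for the function above; "FrexpOutputType",
   "FrexpFunc" and "frexp_contig_kernel" are hypothetical names standing in
   for a concrete two-output function definition:

   template <typename argTy, typename resTy1, typename resTy2,
             std::uint8_t vs, std::uint8_t nv, bool enable>
   using FrexpContigFunctor =
       UnaryTwoOutputsContigFunctor<argTy, resTy1, resTy2,
                                    FrexpFunc<argTy, resTy1, resTy2>,
                                    vs, nv, enable>;

   template <typename A, typename R1, typename R2,
             std::uint8_t vs, std::uint8_t nv>
   class frexp_contig_kernel;

   sycl::event ev =
       unary_two_outputs_contig_impl<argTy, FrexpOutputType,
                                     FrexpContigFunctor, frexp_contig_kernel>(
           q, nelems, arg_p, res1_p, res2_p, depends);
*/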
template <typename argTy,
          template <typename T>
          class UnaryTwoOutputsType,
          template <typename A, typename R1, typename R2, typename I>
          class UnaryTwoOutputsStridedFunctorT,
          template <typename A, typename R1, typename R2, typename I>
          class kernel_name>
sycl::event unary_two_outputs_strided_impl(
    sycl::queue &exec_q,
    std::size_t nelems,
    int nd,
    const ssize_t *shape_and_strides,
    const char *arg_p,
    ssize_t arg_offset,
    char *res1_p,
    ssize_t res1_offset,
    char *res2_p,
    ssize_t res2_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.depends_on(additional_depends);

        using res1Ty = typename UnaryTwoOutputsType<argTy>::value_type1;
        using res2Ty = typename UnaryTwoOutputsType<argTy>::value_type2;
        using IndexerT =
            typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;

        const IndexerT indexer{nd, arg_offset, res1_offset, res2_offset,
                               shape_and_strides};

        const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
        res1Ty *res1_tp = reinterpret_cast<res1Ty *>(res1_p);
        res2Ty *res2_tp = reinterpret_cast<res2Ty *>(res2_p);

        using Impl =
            UnaryTwoOutputsStridedFunctorT<argTy, res1Ty, res2Ty, IndexerT>;

        cgh.parallel_for<kernel_name<argTy, res1Ty, res2Ty, IndexerT>>(
            {nelems}, Impl(arg_tp, res1_tp, res2_tp, indexer));
    });
    return comp_ev;
}
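
/* shape_and_strides is assumed to follow the usual dpctl packing consumed
   by ThreeOffsets_StridedIndexer: nd shape entries followed by nd stride
   entries per array, i.e. [shape, arg_strides, res1_strides, res2_strides]. */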
template <typename argTy1,
          typename argTy2,
          template <typename T1, typename T2>
          class BinaryTwoOutputsType,
          template <typename T1,
                    typename T2,
                    typename T3,
                    typename T4,
                    std::uint8_t vs,
                    std::uint8_t nv,
                    bool enable_sg_loadstore>
          class BinaryTwoOutputsContigFunctorT,
          template <typename T1,
                    typename T2,
                    typename T3,
                    typename T4,
                    std::uint8_t vs,
                    std::uint8_t nv>
          class kernel_name,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u>
sycl::event
    binary_two_outputs_contig_impl(sycl::queue &exec_q,
                                   std::size_t nelems,
                                   const char *arg1_p,
                                   ssize_t arg1_offset,
                                   const char *arg2_p,
                                   ssize_t arg2_offset,
                                   char *res1_p,
                                   ssize_t res1_offset,
                                   char *res2_p,
                                   ssize_t res2_offset,
                                   const std::vector<sycl::event> &depends = {})
{
    const std::size_t n_work_items_needed = nelems / (n_vecs * vec_sz);
    const std::size_t lws =
        select_lws(exec_q.get_device(), n_work_items_needed);

    const std::size_t n_groups =
        ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    using resTy1 = typename BinaryTwoOutputsType<argTy1, argTy2>::value_type1;
    using resTy2 = typename BinaryTwoOutputsType<argTy1, argTy2>::value_type2;
    using BaseKernelName =
        kernel_name<argTy1, argTy2, resTy1, resTy2, vec_sz, n_vecs>;

    const argTy1 *arg1_tp =
        reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
    const argTy2 *arg2_tp =
        reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
    resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p) + res1_offset;
    resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p) + res2_offset;

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        if (is_aligned<required_alignment>(arg1_tp) &&
            is_aligned<required_alignment>(arg2_tp) &&
            is_aligned<required_alignment>(res1_tp) &&
            is_aligned<required_alignment>(res2_tp))
        {
            static constexpr bool enable_sg_loadstore = true;
            using KernelName = BaseKernelName;
            using Impl = BinaryTwoOutputsContigFunctorT<argTy1, argTy2, resTy1,
                                                        resTy2, vec_sz, n_vecs,
                                                        enable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, nelems));
        }
        else {
            static constexpr bool disable_sg_loadstore = false;
            using KernelName =
                disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
            using Impl = BinaryTwoOutputsContigFunctorT<argTy1, argTy2, resTy1,
                                                        resTy2, vec_sz, n_vecs,
                                                        disable_sg_loadstore>;

            cgh.parallel_for<KernelName>(
                sycl::nd_range<1>(gws_range, lws_range),
                Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, nelems));
        }
    });
    return comp_ev;
}
template <typename argTy1,
          typename argTy2,
          template <typename T1, typename T2>
          class BinaryTwoOutputsType,
          template <typename T1,
                    typename T2,
                    typename T3,
                    typename T4,
                    typename IndT>
          class BinaryTwoOutputsStridedFunctorT,
          template <typename T1,
                    typename T2,
                    typename T3,
                    typename T4,
                    typename IndT>
          class kernel_name>
sycl::event binary_two_outputs_strided_impl(
    sycl::queue &exec_q,
    std::size_t nelems,
    int nd,
    const ssize_t *shape_and_strides,
    const char *arg1_p,
    ssize_t arg1_offset,
    const char *arg2_p,
    ssize_t arg2_offset,
    char *res1_p,
    ssize_t res1_offset,
    char *res2_p,
    ssize_t res2_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.depends_on(additional_depends);

        using resTy1 =
            typename BinaryTwoOutputsType<argTy1, argTy2>::value_type1;
        using resTy2 =
            typename BinaryTwoOutputsType<argTy1, argTy2>::value_type2;

        using IndexerT =
            typename dpctl::tensor::offset_utils::FourOffsets_StridedIndexer;

        const IndexerT indexer{nd, arg1_offset, arg2_offset,
                               res1_offset, res2_offset, shape_and_strides};

        const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
        const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
        resTy1 *res1_tp = reinterpret_cast<resTy1 *>(res1_p);
        resTy2 *res2_tp = reinterpret_cast<resTy2 *>(res2_p);

        using Impl = BinaryTwoOutputsStridedFunctorT<argTy1, argTy2, resTy1,
                                                     resTy2, IndexerT>;

        cgh.parallel_for<kernel_name<argTy1, argTy2, resTy1, resTy2, IndexerT>>(
            {nelems}, Impl(arg1_tp, arg2_tp, res1_tp, res2_tp, indexer));
    });
    return comp_ev;
}
typedef sycl::event (*unary_two_outputs_contig_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    const char *,
    char *,
    char *,
    const std::vector<sycl::event> &);

typedef sycl::event (*unary_two_outputs_strided_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    int,
    const ssize_t *,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &,
    const std::vector<sycl::event> &);

typedef sycl::event (*binary_two_outputs_contig_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    const char *,
    ssize_t,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &);

typedef sycl::event (*binary_two_outputs_strided_impl_fn_ptr_t)(
    sycl::queue &,
    std::size_t,
    int,
    const ssize_t *,
    const char *,
    ssize_t,
    const char *,
    ssize_t,
    char *,
    ssize_t,
    char *,
    ssize_t,
    const std::vector<sycl::event> &,
    const std::vector<sycl::event> &);
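
/* These function-pointer types let per-type implementations be stored in
   dispatch vectors. A sketch reusing the hypothetical frexp-like names from
   the earlier usage example ("FrexpOutputType", "FrexpContigFunctor",
   "frexp_contig_kernel"):

   unary_two_outputs_contig_impl_fn_ptr_t fn =
       unary_two_outputs_contig_impl<float, FrexpOutputType,
                                     FrexpContigFunctor, frexp_contig_kernel>;

   sycl::event ev = fn(q, nelems, arg_p, res1_p, res2_p, depends);
*/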
} // namespace dpnp::extensions::py_internal::elementwise_common