DPNP C++ backend kernel library 0.20.0dev6
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
elementwise_functions.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <cstddef>
32#include <exception>
33#include <iterator>
34#include <stdexcept>
35#include <utility>
36#include <vector>
37
38#include <pybind11/numpy.h>
39#include <pybind11/pybind11.h>
40
41#include <sycl/sycl.hpp>
42
43#include "dpnp4pybind11.hpp"
44
45#include "elementwise_functions_type_utils.hpp"
46#include "simplify_iteration_space.hpp"
47
48// dpnp tensor headers
49#include "kernels/alignment.hpp"
50#include "utils/memory_overlap.hpp"
51#include "utils/offset_utils.hpp"
52#include "utils/output_validation.hpp"
53#include "utils/sycl_alloc_utils.hpp"
54#include "utils/type_dispatch.hpp"
55
56static_assert(std::is_same_v<py::ssize_t, dpnp::tensor::ssize_t>);
57
58namespace dpnp::extensions::py_internal
59{
60namespace py = pybind11;
61namespace td_ns = dpnp::tensor::type_dispatch;
62
63using dpnp::tensor::kernels::alignment_utils::is_aligned;
64using dpnp::tensor::kernels::alignment_utils::required_alignment;
65
66using type_utils::_result_typeid;
67
69template <typename output_typesT,
70 typename contig_dispatchT,
71 typename strided_dispatchT>
72std::pair<sycl::event, sycl::event>
73 py_unary_ufunc(const dpnp::tensor::usm_ndarray &src,
75 sycl::queue &q,
76 const std::vector<sycl::event> &depends,
77 //
78 const output_typesT &output_type_vec,
79 const contig_dispatchT &contig_dispatch_vector,
80 const strided_dispatchT &strided_dispatch_vector)
81{
82 int src_typenum = src.get_typenum();
83 int dst_typenum = dst.get_typenum();
84
85 const auto &array_types = td_ns::usm_ndarray_types();
86 int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
87 int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
88
89 int func_output_typeid = output_type_vec[src_typeid];
90
91 // check that types are supported
92 if (dst_typeid != func_output_typeid) {
93 throw py::value_error(
94 "Destination array has unexpected elemental data type.");
95 }
96
97 // check that queues are compatible
98 if (!dpnp::utils::queues_are_compatible(q, {src, dst})) {
99 throw py::value_error(
100 "Execution queue is not compatible with allocation queues");
101 }
102
103 dpnp::tensor::validation::CheckWritable::throw_if_not_writable(dst);
104
105 // check that dimensions are the same
106 int src_nd = src.get_ndim();
107 if (src_nd != dst.get_ndim()) {
108 throw py::value_error("Array dimensions are not the same.");
109 }
110
111 // check that shapes are the same
112 const py::ssize_t *src_shape = src.get_shape_raw();
113 const py::ssize_t *dst_shape = dst.get_shape_raw();
114 bool shapes_equal(true);
115 std::size_t src_nelems(1);
116
117 for (int i = 0; i < src_nd; ++i) {
118 src_nelems *= static_cast<std::size_t>(src_shape[i]);
119 shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
120 }
121 if (!shapes_equal) {
122 throw py::value_error("Array shapes are not the same.");
123 }
124
125 // if nelems is zero, return
126 if (src_nelems == 0) {
127 return std::make_pair(sycl::event(), sycl::event());
128 }
129
130 dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
131
132 // check memory overlap
133 auto const &overlap = dpnp::tensor::overlap::MemoryOverlap();
134 auto const &same_logical_tensors =
135 dpnp::tensor::overlap::SameLogicalTensors();
136 if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
137 throw py::value_error("Arrays index overlapping segments of memory");
138 }
139
140 const char *src_data = src.get_data();
141 char *dst_data = dst.get_data();
142
143 // handle contiguous inputs
144 bool is_src_c_contig = src.is_c_contiguous();
145 bool is_src_f_contig = src.is_f_contiguous();
146
147 bool is_dst_c_contig = dst.is_c_contiguous();
148 bool is_dst_f_contig = dst.is_f_contiguous();
149
150 bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
151 bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
152
153 if (both_c_contig || both_f_contig) {
154 auto contig_fn = contig_dispatch_vector[src_typeid];
155
156 if (contig_fn == nullptr) {
157 throw std::runtime_error(
158 "Contiguous implementation is missing for src_typeid=" +
159 std::to_string(src_typeid));
160 }
161
162 auto comp_ev = contig_fn(q, src_nelems, src_data, dst_data, depends);
163 sycl::event ht_ev =
164 dpnp::utils::keep_args_alive(q, {src, dst}, {comp_ev});
165
166 return std::make_pair(ht_ev, comp_ev);
167 }
168
169 // simplify iteration space
170 // if 1d with strides 1 - input is contig
171 // dispatch to strided
172
173 auto const &src_strides = src.get_strides_vector();
174 auto const &dst_strides = dst.get_strides_vector();
175
176 using shT = std::vector<py::ssize_t>;
177 shT simplified_shape;
178 shT simplified_src_strides;
179 shT simplified_dst_strides;
180 py::ssize_t src_offset(0);
181 py::ssize_t dst_offset(0);
182
183 int nd = src_nd;
184 const py::ssize_t *shape = src_shape;
185
186 simplify_iteration_space(nd, shape, src_strides, dst_strides,
187 // output
188 simplified_shape, simplified_src_strides,
189 simplified_dst_strides, src_offset, dst_offset);
190
191 if (nd == 1 && simplified_src_strides[0] == 1 &&
192 simplified_dst_strides[0] == 1) {
193 // Special case of contiguous data
194 auto contig_fn = contig_dispatch_vector[src_typeid];
195
196 if (contig_fn == nullptr) {
197 throw std::runtime_error(
198 "Contiguous implementation is missing for src_typeid=" +
199 std::to_string(src_typeid));
200 }
201
202 int src_elem_size = src.get_elemsize();
203 int dst_elem_size = dst.get_elemsize();
204 auto comp_ev =
205 contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
206 dst_data + dst_elem_size * dst_offset, depends);
207
208 sycl::event ht_ev =
209 dpnp::utils::keep_args_alive(q, {src, dst}, {comp_ev});
210
211 return std::make_pair(ht_ev, comp_ev);
212 }
213
214 // Strided implementation
215 auto strided_fn = strided_dispatch_vector[src_typeid];
216
217 if (strided_fn == nullptr) {
218 throw std::runtime_error(
219 "Strided implementation is missing for src_typeid=" +
220 std::to_string(src_typeid));
221 }
222
223 using dpnp::tensor::offset_utils::device_allocate_and_pack;
224
225 std::vector<sycl::event> host_tasks{};
226 host_tasks.reserve(2);
227
228 auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
229 q, host_tasks, simplified_shape, simplified_src_strides,
230 simplified_dst_strides);
231 auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
232 const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
233 const py::ssize_t *shape_strides = shape_strides_owner.get();
234
235 sycl::event strided_fn_ev =
236 strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
237 dst_data, dst_offset, depends, {copy_shape_ev});
238
239 // async free of shape_strides temporary
240 sycl::event tmp_cleanup_ev = dpnp::tensor::alloc_utils::async_smart_free(
241 q, {strided_fn_ev}, shape_strides_owner);
242
243 host_tasks.push_back(tmp_cleanup_ev);
244
245 return std::make_pair(
246 dpnp::utils::keep_args_alive(q, {src, dst}, host_tasks), strided_fn_ev);
247}
248
251template <typename output_typesT>
252py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
253 const output_typesT &output_types)
254{
255 int tn = input_dtype.num(); // NumPy type numbers are the same as in dpnp
256 int src_typeid = -1;
257
258 auto array_types = td_ns::usm_ndarray_types();
259
260 try {
261 src_typeid = array_types.typenum_to_lookup_id(tn);
262 } catch (const std::exception &e) {
263 throw py::value_error(e.what());
264 }
265
266 int dst_typeid = _result_typeid(src_typeid, output_types);
267 if (dst_typeid < 0) {
268 auto res = py::none();
269 return py::cast<py::object>(res);
270 }
271 else {
272 using type_utils::_dtype_from_typenum;
273
274 auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
275 auto dt = _dtype_from_typenum(dst_typenum_t);
276
277 return py::cast<py::object>(dt);
278 }
279}
280
/// @brief Dispatches a unary elementwise operation with two outputs,
/// (dst1, dst2) = op(src), to a type- and layout-specialized kernel.
///
/// Validates both destination types against @p output_type_vec, checks
/// queue compatibility, writability, equal dimensionality and shapes,
/// ample destination memory and absence of memory overlap; then runs
/// either a contiguous kernel (all three arrays C-contiguous, all
/// F-contiguous, or 1-d with unit strides after simplification) or a
/// strided kernel with packed shape/stride metadata allocated on the
/// device.
///
/// @return Pair of (host-task event keeping arguments alive, computation
///         event).
/// @throws py::value_error on type/queue/shape/overlap validation failure.
/// @throws std::runtime_error when no kernel exists for the source type id.
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_two_outputs_ufunc(const dpnp::tensor::usm_ndarray &src,
                               const dpnp::tensor::usm_ndarray &dst1,
                               const dpnp::tensor::usm_ndarray &dst2,
                               sycl::queue &q,
                               const std::vector<sycl::event> &depends,
                               //
                               const output_typesT &output_type_vec,
                               const contig_dispatchT &contig_dispatch_vector,
                               const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst1_typenum = dst1.get_typenum();
    int dst2_typenum = dst2.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);

    // expected (first, second) output type ids for this input type
    std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];

    // check that types are supported
    if (dst1_typeid != func_output_typeids.first ||
        dst2_typeid != func_output_typeids.second) {
        throw py::value_error(
            "One of destination arrays has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpnp::utils::queues_are_compatible(q, {src, dst1, dst2})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpnp::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
    dpnp::tensor::validation::CheckWritable::throw_if_not_writable(dst2);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst1.get_ndim() || src_nd != dst2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same, accumulating the element count
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
                       (src_shape[i] == dst2_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return empty events - nothing to compute
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(dst1, src_nelems);
    dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(dst2, src_nelems);

    // check memory overlap; overlapping pairs are allowed only when they
    // are views of exactly the same logical tensor
    auto const &overlap = dpnp::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpnp::tensor::overlap::SameLogicalTensors();
    if ((overlap(src, dst1) && !same_logical_tensors(src, dst1)) ||
        (overlap(src, dst2) && !same_logical_tensors(src, dst2)) ||
        (overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2))) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst1_data = dst1.get_data();
    char *dst2_data = dst2.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst1_c_contig = dst1.is_c_contiguous();
    bool is_dst1_f_contig = dst1.is_f_contiguous();

    bool is_dst2_c_contig = dst2.is_c_contiguous();
    bool is_dst2_f_contig = dst2.is_f_contiguous();

    bool all_c_contig =
        (is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);
    bool all_f_contig =
        (is_src_f_contig && is_dst1_f_contig && is_dst2_f_contig);

    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev =
            contig_fn(q, src_nelems, src_data, dst1_data, dst2_data, depends);
        sycl::event ht_ev =
            dpnp::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space
    // if 1d with strides 1 - input is contig
    // dispatch to strided

    auto const &src_strides = src.get_strides_vector();
    auto const &dst1_strides = dst1.get_strides_vector();
    auto const &dst2_strides = dst2.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst1_strides;
    shT simplified_dst2_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst1_offset(0);
    py::ssize_t dst2_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space_3(
        nd, shape, src_strides, dst1_strides, dst2_strides,
        // output
        simplified_shape, simplified_src_strides, simplified_dst1_strides,
        simplified_dst2_strides, src_offset, dst1_offset, dst2_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst1_strides[0] == 1 && simplified_dst2_strides[0] == 1) {
        // Special case of contiguous data: use the contiguous kernel with
        // pointers adjusted by the offsets produced by simplification
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst1_elem_size = dst1.get_elemsize();
        int dst2_elem_size = dst2.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst1_data + dst1_elem_size * dst1_offset,
                      dst2_data + dst2_elem_size * dst2_offset, depends);

        sycl::event ht_ev =
            dpnp::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpnp::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    // pack shape and all strides into one device allocation; the returned
    // tuple is (owner pointer, allocation size, copy event)
    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst1_strides, simplified_dst2_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        q, src_nelems, nd, shape_strides, src_data, src_offset, dst1_data,
        dst1_offset, dst2_data, dst2_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpnp::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpnp::utils::keep_args_alive(q, {src, dst1, dst2}, host_tasks),
        strided_fn_ev);
}
488
493template <typename output_typesT>
494std::pair<py::object, py::object>
495 py_unary_two_outputs_ufunc_result_type(const py::dtype &input_dtype,
496 const output_typesT &output_types)
497{
498 int tn = input_dtype.num(); // NumPy type numbers are the same as in dpnp
499 int src_typeid = -1;
500
501 auto array_types = td_ns::usm_ndarray_types();
502
503 try {
504 src_typeid = array_types.typenum_to_lookup_id(tn);
505 } catch (const std::exception &e) {
506 throw py::value_error(e.what());
507 }
508
509 std::pair<int, int> dst_typeids = _result_typeid(src_typeid, output_types);
510 int dst1_typeid = dst_typeids.first;
511 int dst2_typeid = dst_typeids.second;
512
513 if (dst1_typeid < 0 || dst2_typeid < 0) {
514 auto res = py::none();
515 auto py_res = py::cast<py::object>(res);
516 return std::make_pair(py_res, py_res);
517 }
518 else {
519 using type_utils::_dtype_from_typenum;
520
521 auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
522 auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
523 auto dt1 = _dtype_from_typenum(dst1_typenum_t);
524 auto dt2 = _dtype_from_typenum(dst2_typenum_t);
525
526 return std::make_pair(py::cast<py::object>(dt1),
527 py::cast<py::object>(dt2));
528 }
529}
530
531// ======================== Binary functions ===========================
532
namespace
{
/// Returns true when container @p c holds exactly the values of the
/// initializer list @p l: same length and element-wise equality.
template <class Container, class T>
bool isEqual(Container const &c, std::initializer_list<T> const &l)
{
    auto c_it = std::begin(c);
    auto c_end = std::end(c);
    auto l_it = std::begin(l);
    auto l_end = std::end(l);

    // walk both sequences in lock-step, comparing element-wise
    for (; c_it != c_end && l_it != l_end; ++c_it, ++l_it) {
        if (!(*c_it == *l_it)) {
            return false;
        }
    }
    // equal only if both sequences were exhausted together
    return (c_it == c_end) && (l_it == l_end);
}
} // namespace
541
/// @brief Dispatches a binary elementwise operation,
/// dst = op(src1, src2), to a type- and layout-specialized kernel.
///
/// Broadcasting is assumed to have been performed by the caller, so all
/// three arrays must have identical dimensionality and shape. After
/// validation (types, queues, writability, shapes, destination memory,
/// overlap), dispatch proceeds in order of decreasing specialization:
/// 1. all-C-contiguous or all-F-contiguous  -> contiguous kernel;
/// 2. simplified 1-d with unit strides       -> contiguous kernel;
/// 3. simplified 2-d matrix (+) broadcast row (C- or F-layout, aligned
///    data) -> matrix/row or row/matrix broadcast kernel;
/// 4. otherwise -> strided kernel with packed shape/stride metadata.
/// Steps 1-3 silently fall through to the next case when the respective
/// dispatch-table entry is nullptr; only a missing strided kernel throws.
///
/// @return Pair of (host-task event keeping arguments alive, computation
///         event).
/// @throws py::value_error on type/queue/shape/overlap validation failure.
/// @throws std::runtime_error when no strided kernel exists for the type pair.
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_matrix_row_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event> py_binary_ufunc(
    const dpnp::tensor::usm_ndarray &src1,
    const dpnp::tensor::usm_ndarray &src2,
    const dpnp::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
    sycl::queue &exec_q,
    const std::vector<sycl::event> &depends,
    //
    const output_typesT &output_type_table,
    const contig_dispatchT &contig_dispatch_table,
    const strided_dispatchT &strided_dispatch_table,
    const contig_matrix_row_dispatchT
        &contig_matrix_row_broadcast_dispatch_table,
    const contig_row_matrix_dispatchT
        &contig_row_matrix_broadcast_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int output_typeid = output_type_table[src1_typeid][src2_typeid];

    if (output_typeid != dst_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpnp::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpnp::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same, accumulating the element count
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return empty events - nothing to compute
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap; a source may overlap the destination only
    // when it is a view of the same logical tensor (in-place operation)
    auto const &overlap = dpnp::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpnp::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) ||
        (overlap(src2, dst) && !same_logical_tensors(src2, dst))) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    bool all_f_contig =
        (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig);

    // dispatch for contiguous inputs; falls through to the strided path
    // below when no contiguous kernel is registered for this type pair
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0,
                                     src2_data, 0, dst_data, 0, depends);
            sycl::event ht_ev = dpnp::utils::keep_args_alive(
                exec_q, {src1, src2, dst}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst_offset(0);

    int nd = dst_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_3(
        nd, shape, src1_strides, src2_strides, dst_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst_strides, src1_offset, src2_offset, dst_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        // 1-d with unit strides after simplification: effectively
        // contiguous; use the contiguous kernel with offsets
        if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
            isEqual(simplified_src2_strides, unit_stride) &&
            isEqual(simplified_dst_strides, unit_stride)) {
            auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev = contig_fn(exec_q, src_nelems, src1_data,
                                         src1_offset, src2_data, src2_offset,
                                         dst_data, dst_offset, depends);
                sycl::event ht_ev = dpnp::utils::keep_args_alive(
                    exec_q, {src1, src2, dst}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            // stride patterns of a broadcast row: {0,1} means the same
            // row is re-read for every matrix row (C layout); {1,0} is
            // the analogous pattern in F layout
            static constexpr auto zero_one_strides =
                std::initializer_list<py::ssize_t>{0, 1};
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_src2_strides, zero_one_strides) &&
                isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
                isEqual(simplified_dst_strides, {simplified_shape[1], one})) {
                auto matrix_row_broadcast_fn =
                    contig_matrix_row_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (matrix_row_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    // the specialized kernel requires aligned base
                    // pointers; otherwise fall through to strided
                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize)) {
                        std::size_t n0 = simplified_shape[0];
                        std::size_t n1 = simplified_shape[1];
                        sycl::event comp_ev = matrix_row_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpnp::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
            // special case of a row broadcast against an F-layout matrix
            if (isEqual(simplified_src1_strides, one_zero_strides) &&
                isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
                isEqual(simplified_dst_strides, {one, simplified_shape[0]})) {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (row_matrix_broadcast_fn != nullptr) {

                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    // aligned base pointers required here as well
                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize)) {
                        std::size_t n0 = simplified_shape[1];
                        std::size_t n1 = simplified_shape[0];
                        sycl::event comp_ev = row_matrix_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpnp::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    // pack shape and all strides into one device allocation; the returned
    // tuple is (owner pointer, allocation size, copy event)
    using dpnp::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
        src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpnp::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpnp::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks),
        strided_fn_ev);
}
809
811template <typename output_typesT>
812py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
813 const py::dtype &input2_dtype,
814 const output_typesT &output_types_table)
815{
816 int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpnp
817 int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpnp
818 int src1_typeid = -1;
819 int src2_typeid = -1;
820
821 auto array_types = td_ns::usm_ndarray_types();
822
823 try {
824 src1_typeid = array_types.typenum_to_lookup_id(tn1);
825 src2_typeid = array_types.typenum_to_lookup_id(tn2);
826 } catch (const std::exception &e) {
827 throw py::value_error(e.what());
828 }
829
830 if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
831 src2_typeid >= td_ns::num_types) {
832 throw std::runtime_error("binary output type lookup failed");
833 }
834 int dst_typeid = output_types_table[src1_typeid][src2_typeid];
835
836 if (dst_typeid < 0) {
837 auto res = py::none();
838 return py::cast<py::object>(res);
839 }
840 else {
841 using type_utils::_dtype_from_typenum;
842
843 auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
844 auto dt = _dtype_from_typenum(dst_typenum_t);
845
846 return py::cast<py::object>(dt);
847 }
848}
849
852template <typename output_typesT,
853 typename contig_dispatchT,
854 typename strided_dispatchT>
855std::pair<sycl::event, sycl::event>
856 py_binary_two_outputs_ufunc(const dpnp::tensor::usm_ndarray &src1,
857 const dpnp::tensor::usm_ndarray &src2,
858 const dpnp::tensor::usm_ndarray &dst1,
859 const dpnp::tensor::usm_ndarray &dst2,
860 sycl::queue &exec_q,
861 const std::vector<sycl::event> &depends,
862 //
863 const output_typesT &output_types_table,
864 const contig_dispatchT &contig_dispatch_table,
865 const strided_dispatchT &strided_dispatch_table)
866{
867 // check type_nums
868 int src1_typenum = src1.get_typenum();
869 int src2_typenum = src2.get_typenum();
870 int dst1_typenum = dst1.get_typenum();
871 int dst2_typenum = dst2.get_typenum();
872
873 auto array_types = td_ns::usm_ndarray_types();
874 int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
875 int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
876 int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
877 int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);
878
879 std::pair<int, int> output_typeids =
880 output_types_table[src1_typeid][src2_typeid];
881
882 if (dst1_typeid != output_typeids.first ||
883 dst2_typeid != output_typeids.second) {
884 throw py::value_error(
885 "One of destination arrays has unexpected elemental data type.");
886 }
887
888 // check that queues are compatible
889 if (!dpnp::utils::queues_are_compatible(exec_q, {src1, src2, dst1, dst2})) {
890 throw py::value_error(
891 "Execution queue is not compatible with allocation queues");
892 }
893
894 dpnp::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
895 dpnp::tensor::validation::CheckWritable::throw_if_not_writable(dst2);
896
897 // check shapes, broadcasting is assumed done by caller
898 // check that dimensions are the same
899 int src1_nd = src1.get_ndim();
900 int src2_nd = src2.get_ndim();
901 int dst1_nd = dst1.get_ndim();
902 int dst2_nd = dst2.get_ndim();
903
904 if (dst1_nd != src1_nd || dst1_nd != src2_nd || dst1_nd != dst2_nd) {
905 throw py::value_error("Array dimensions are not the same.");
906 }
907
908 // check that shapes are the same
909 const py::ssize_t *src1_shape = src1.get_shape_raw();
910 const py::ssize_t *src2_shape = src2.get_shape_raw();
911 const py::ssize_t *dst1_shape = dst1.get_shape_raw();
912 const py::ssize_t *dst2_shape = dst2.get_shape_raw();
913 bool shapes_equal(true);
914 std::size_t src_nelems(1);
915
916 for (int i = 0; i < dst1_nd; ++i) {
917 const auto &sh_i = dst1_shape[i];
918 src_nelems *= static_cast<std::size_t>(src1_shape[i]);
919 shapes_equal =
920 shapes_equal && (src1_shape[i] == sh_i && src2_shape[i] == sh_i &&
921 dst2_shape[i] == sh_i);
922 }
923 if (!shapes_equal) {
924 throw py::value_error("Array shapes are not the same.");
925 }
926
927 // if nelems is zero, return
928 if (src_nelems == 0) {
929 return std::make_pair(sycl::event(), sycl::event());
930 }
931
932 dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(dst1, src_nelems);
933 dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(dst2, src_nelems);
934
935 // check memory overlap
936 auto const &overlap = dpnp::tensor::overlap::MemoryOverlap();
937 auto const &same_logical_tensors =
938 dpnp::tensor::overlap::SameLogicalTensors();
939 if ((overlap(src1, dst1) && !same_logical_tensors(src1, dst1)) ||
940 (overlap(src1, dst2) && !same_logical_tensors(src1, dst2)) ||
941 (overlap(src2, dst1) && !same_logical_tensors(src2, dst1)) ||
942 (overlap(src2, dst2) && !same_logical_tensors(src2, dst2)) ||
943 (overlap(dst1, dst2))) {
944 throw py::value_error("Arrays index overlapping segments of memory");
945 }
946
947 const char *src1_data = src1.get_data();
948 const char *src2_data = src2.get_data();
949 char *dst1_data = dst1.get_data();
950 char *dst2_data = dst2.get_data();
951
952 // handle contiguous inputs
953 bool is_src1_c_contig = src1.is_c_contiguous();
954 bool is_src1_f_contig = src1.is_f_contiguous();
955
956 bool is_src2_c_contig = src2.is_c_contiguous();
957 bool is_src2_f_contig = src2.is_f_contiguous();
958
959 bool is_dst1_c_contig = dst1.is_c_contiguous();
960 bool is_dst1_f_contig = dst1.is_f_contiguous();
961
962 bool is_dst2_c_contig = dst2.is_c_contiguous();
963 bool is_dst2_f_contig = dst2.is_f_contiguous();
964
965 bool all_c_contig = (is_src1_c_contig && is_src2_c_contig &&
966 is_dst1_c_contig && is_dst2_c_contig);
967 bool all_f_contig = (is_src1_f_contig && is_src2_f_contig &&
968 is_dst1_f_contig && is_dst2_f_contig);
969
970 // dispatch for contiguous inputs
971 if (all_c_contig || all_f_contig) {
972 auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];
973
974 if (contig_fn != nullptr) {
975 auto comp_ev =
976 contig_fn(exec_q, src_nelems, src1_data, 0, src2_data, 0,
977 dst1_data, 0, dst2_data, 0, depends);
978 sycl::event ht_ev = dpnp::utils::keep_args_alive(
979 exec_q, {src1, src2, dst1, dst2}, {comp_ev});
980
981 return std::make_pair(ht_ev, comp_ev);
982 }
983 }
984
985 // simplify strides
986 auto const &src1_strides = src1.get_strides_vector();
987 auto const &src2_strides = src2.get_strides_vector();
988 auto const &dst1_strides = dst1.get_strides_vector();
989 auto const &dst2_strides = dst2.get_strides_vector();
990
991 using shT = std::vector<py::ssize_t>;
992 shT simplified_shape;
993 shT simplified_src1_strides;
994 shT simplified_src2_strides;
995 shT simplified_dst1_strides;
996 shT simplified_dst2_strides;
997 py::ssize_t src1_offset(0);
998 py::ssize_t src2_offset(0);
999 py::ssize_t dst1_offset(0);
1000 py::ssize_t dst2_offset(0);
1001
1002 int nd = dst1_nd;
1003 const py::ssize_t *shape = src1_shape;
1004
1005 simplify_iteration_space_4(
1006 nd, shape, src1_strides, src2_strides, dst1_strides, dst2_strides,
1007 // outputs
1008 simplified_shape, simplified_src1_strides, simplified_src2_strides,
1009 simplified_dst1_strides, simplified_dst2_strides, src1_offset,
1010 src2_offset, dst1_offset, dst2_offset);
1011
1012 std::vector<sycl::event> host_tasks{};
1013 static constexpr auto unit_stride = std::initializer_list<py::ssize_t>{1};
1014
1015 if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
1016 isEqual(simplified_src2_strides, unit_stride) &&
1017 isEqual(simplified_dst1_strides, unit_stride) &&
1018 isEqual(simplified_dst2_strides, unit_stride)) {
1019 auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];
1020
1021 if (contig_fn != nullptr) {
1022 auto comp_ev =
1023 contig_fn(exec_q, src_nelems, src1_data, src1_offset, src2_data,
1024 src2_offset, dst1_data, dst1_offset, dst2_data,
1025 dst2_offset, depends);
1026 sycl::event ht_ev = dpnp::utils::keep_args_alive(
1027 exec_q, {src1, src2, dst1, dst2}, {comp_ev});
1028
1029 return std::make_pair(ht_ev, comp_ev);
1030 }
1031 }
1032
1033 // dispatch to strided code
1034 auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];
1035
1036 if (strided_fn == nullptr) {
1037 throw std::runtime_error(
1038 "Strided implementation is missing for src1_typeid=" +
1039 std::to_string(src1_typeid) +
1040 " and src2_typeid=" + std::to_string(src2_typeid));
1041 }
1042
1043 using dpnp::tensor::offset_utils::device_allocate_and_pack;
1044 auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
1045 exec_q, host_tasks, simplified_shape, simplified_src1_strides,
1046 simplified_src2_strides, simplified_dst1_strides,
1047 simplified_dst2_strides);
1048 auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
1049 auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);
1050
1051 const py::ssize_t *shape_strides = shape_strides_owner.get();
1052
1053 sycl::event strided_fn_ev =
1054 strided_fn(exec_q, src_nelems, nd, shape_strides, src1_data,
1055 src1_offset, src2_data, src2_offset, dst1_data, dst1_offset,
1056 dst2_data, dst2_offset, depends, {copy_shape_ev});
1057
1058 // async free of shape_strides temporary
1059 sycl::event tmp_cleanup_ev = dpnp::tensor::alloc_utils::async_smart_free(
1060 exec_q, {strided_fn_ev}, shape_strides_owner);
1061 host_tasks.push_back(tmp_cleanup_ev);
1062
1063 return std::make_pair(dpnp::utils::keep_args_alive(
1064 exec_q, {src1, src2, dst1, dst2}, host_tasks),
1065 strided_fn_ev);
1066}
1067
1072template <typename output_typesT>
1073std::pair<py::object, py::object> py_binary_two_outputs_ufunc_result_type(
1074 const py::dtype &input1_dtype,
1075 const py::dtype &input2_dtype,
1076 const output_typesT &output_types_table)
1077{
1078 int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpnp
1079 int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpnp
1080 int src1_typeid = -1;
1081 int src2_typeid = -1;
1082
1083 auto array_types = td_ns::usm_ndarray_types();
1084
1085 try {
1086 src1_typeid = array_types.typenum_to_lookup_id(tn1);
1087 src2_typeid = array_types.typenum_to_lookup_id(tn2);
1088 } catch (const std::exception &e) {
1089 throw py::value_error(e.what());
1090 }
1091
1092 if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
1093 src2_typeid >= td_ns::num_types) {
1094 throw std::runtime_error("binary output type lookup failed");
1095 }
1096 std::pair<int, int> dst_typeids =
1097 output_types_table[src1_typeid][src2_typeid];
1098 int dst1_typeid = dst_typeids.first;
1099 int dst2_typeid = dst_typeids.second;
1100
1101 if (dst1_typeid < 0 || dst2_typeid < 0) {
1102 auto res = py::none();
1103 auto py_res = py::cast<py::object>(res);
1104 return std::make_pair(py_res, py_res);
1105 }
1106 else {
1107 using type_utils::_dtype_from_typenum;
1108
1109 auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
1110 auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
1111 auto dt1 = _dtype_from_typenum(dst1_typenum_t);
1112 auto dt2 = _dtype_from_typenum(dst2_typenum_t);
1113
1114 return std::make_pair(py::cast<py::object>(dt1),
1115 py::cast<py::object>(dt2));
1116 }
1117}
1118
1119// ==================== Inplace binary functions =======================
1120
/**
 * @brief Machinery for an in-place binary ufunc: lhs <- op(rhs, lhs).
 *
 * The caller is responsible for broadcasting rhs and lhs to a common shape
 * beforehand. Dispatch tries progressively more general kernels:
 *  1. fully C- or F-contiguous kernel,
 *  2. effectively-contiguous kernel after stride simplification (1-d,
 *     unit strides),
 *  3. special row-over-C-contiguous-matrix broadcast kernel (2-d case),
 *  4. generic strided kernel.
 *
 * @param lhs  Array updated in place; must be writable.
 * @param rhs  Second operand.
 * @param exec_q  SYCL queue to submit kernels to.
 * @param depends  Events the computation must wait for.
 * @param output_type_table  [rhs_typeid][lhs_typeid] -> expected output
 *        lookup id; must equal lhs's id for an in-place op.
 * @param contig_dispatch_table  Kernels for contiguous data (may hold
 *        nullptr entries).
 * @param strided_dispatch_table  Generic strided kernels (required).
 * @param contig_row_matrix_broadcast_dispatch_table  Kernels for the
 *        row-broadcast-over-matrix special case (may hold nullptr entries).
 * @return Pair of events: first keeps the argument arrays alive until
 *         completion, second is the computation event.
 * @throws py::value_error on type/queue/shape/overlap validation failure.
 * @throws std::runtime_error when no strided kernel exists for the types.
 */
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_inplace_ufunc(const dpnp::tensor::usm_ndarray &lhs,
                            const dpnp::tensor::usm_ndarray &rhs,
                            sycl::queue &exec_q,
                            const std::vector<sycl::event> &depends,
                            //
                            const output_typesT &output_type_table,
                            const contig_dispatchT &contig_dispatch_table,
                            const strided_dispatchT &strided_dispatch_table,
                            const contig_row_matrix_dispatchT
                                &contig_row_matrix_broadcast_dispatch_table)
{
    // lhs is written to in place, so it must be writable
    dpnp::tensor::validation::CheckWritable::throw_if_not_writable(lhs);

    // check type_nums
    int rhs_typenum = rhs.get_typenum();
    int lhs_typenum = lhs.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
    int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum);

    // an in-place op must produce exactly the lhs data type
    int output_typeid = output_type_table[rhs_typeid][lhs_typeid];

    if (output_typeid != lhs_typeid) {
        throw py::value_error(
            "Left-hand side array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpnp::utils::queues_are_compatible(exec_q, {rhs, lhs})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int lhs_nd = lhs.get_ndim();
    if (lhs_nd != rhs.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same (accumulating element count on the way)
    const py::ssize_t *rhs_shape = rhs.get_shape_raw();
    const py::ssize_t *lhs_shape = lhs.get_shape_raw();
    bool shapes_equal(true);
    std::size_t rhs_nelems(1);

    for (int i = 0; i < lhs_nd; ++i) {
        rhs_nelems *= static_cast<std::size_t>(rhs_shape[i]);
        shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return trivially-complete events without submitting
    if (rhs_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    // lhs allocation must be large enough to hold the result
    dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems);

    // check memory overlap; overlapping memory is only tolerated when rhs
    // and lhs are views of the same logical tensor
    auto const &same_logical_tensors =
        dpnp::tensor::overlap::SameLogicalTensors();
    auto const &overlap = dpnp::tensor::overlap::MemoryOverlap();
    if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }
    // raw pointers to the underlying USM allocations
    const char *rhs_data = rhs.get_data();
    char *lhs_data = lhs.get_data();

    // handle contiguous inputs
    bool is_rhs_c_contig = rhs.is_c_contiguous();
    bool is_rhs_f_contig = rhs.is_f_contiguous();

    bool is_lhs_c_contig = lhs.is_c_contiguous();
    bool is_lhs_f_contig = lhs.is_f_contiguous();

    bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig);
    bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig);

    // dispatch for contiguous inputs; if no contiguous kernel is registered
    // for this type pair, fall through to the strided machinery below
    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
                                     0, depends);
            sycl::event ht_ev =
                dpnp::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides: collapse adjacent dimensions where the layouts
    // allow; may reduce nd and produce per-array base offsets
    auto const &rhs_strides = rhs.get_strides_vector();
    auto const &lhs_strides = lhs.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_rhs_strides;
    shT simplified_lhs_strides;
    py::ssize_t rhs_offset(0);
    py::ssize_t lhs_offset(0);

    int nd = lhs_nd;
    const py::ssize_t *shape = rhs_shape;

    simplify_iteration_space(nd, shape, rhs_strides, lhs_strides,
                             // outputs
                             simplified_shape, simplified_rhs_strides,
                             simplified_lhs_strides, rhs_offset, lhs_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        // after simplification the data may have turned out to be
        // effectively contiguous: 1-d with unit strides on both arrays
        if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) &&
            isEqual(simplified_lhs_strides, unit_stride)) {
            auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev =
                    contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset,
                              lhs_data, lhs_offset, depends);
                sycl::event ht_ev =
                    dpnp::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_rhs_strides, one_zero_strides) &&
                isEqual(simplified_lhs_strides, {one, simplified_shape[0]})) {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                              [lhs_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    std::size_t n0 = simplified_shape[1];
                    std::size_t n1 = simplified_shape[0];
                    sycl::event comp_ev = row_matrix_broadcast_fn(
                        exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                        lhs_data, lhs_offset, depends);

                    return std::make_pair(dpnp::utils::keep_args_alive(
                                              exec_q, {lhs, rhs}, host_tasks),
                                          comp_ev);
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for rhs_typeid=" +
            std::to_string(rhs_typeid) +
            " and lhs_typeid=" + std::to_string(lhs_typeid));
    }

    // pack [shape, rhs_strides, lhs_strides] into one device allocation;
    // the copy event gates the strided kernel
    using dpnp::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
        simplified_lhs_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
                   lhs_data, lhs_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary once the kernel is done
    sycl::event tmp_cleanup_ev = dpnp::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpnp::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
        strided_fn_ev);
}
1319} // namespace dpnp::extensions::py_internal