DPNP C++ backend kernel library 0.20.0dev6
Data Parallel Extension for NumPy*
elementwise_functions.hpp
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
//   may be used to endorse or promote products derived from this software
//   without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <algorithm>
#include <cstddef>
#include <exception>
#include <iterator>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <sycl/sycl.hpp>

#include "dpctl4pybind11.hpp"

#include "elementwise_functions_type_utils.hpp"
#include "simplify_iteration_space.hpp"

// dpctl tensor headers
#include "kernels/alignment.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/offset_utils.hpp"
#include "utils/output_validation.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include "utils/type_dispatch.hpp"

static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);

namespace dpnp::extensions::py_internal
{
namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::kernels::alignment_utils::is_aligned;
using dpctl::tensor::kernels::alignment_utils::required_alignment;

using type_utils::_result_typeid;

template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
                   const dpctl::tensor::usm_ndarray &dst,
                   sycl::queue &q,
                   const std::vector<sycl::event> &depends,
                   //
                   const output_typesT &output_type_vec,
                   const contig_dispatchT &contig_dispatch_vector,
                   const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int func_output_typeid = output_type_vec[src_typeid];

    // check that types are supported
    if (dst_typeid != func_output_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
    bool both_f_contig = (is_src_f_contig && is_dst_f_contig);

    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev = contig_fn(q, src_nelems, src_data, dst_data, depends);
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space
    // if 1d with strides 1 - input is contig
    // dispatch to strided

    auto const &src_strides = src.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space(nd, shape, src_strides, dst_strides,
                             // output
                             simplified_shape, simplified_src_strides,
                             simplified_dst_strides, src_offset, dst_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst_strides[0] == 1) {
        // Special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst_elem_size = dst.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst_data + dst_elem_size * dst_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
                   dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks),
        strided_fn_ev);
}

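/*
  Usage sketch (added for exposition; not part of the original header).
  py_unary_ufunc expects its three table arguments to be indexable by the
  type lookup id returned by td_ns::usm_ndarray_types(). The function
  pointer signatures below are inferred from the call sites above; the
  "abs"-named tables are hypothetical placeholders:

    typedef sycl::event (*unary_contig_fn_ptr_t)(
        sycl::queue &, std::size_t, const char *, char *,
        const std::vector<sycl::event> &);
    typedef sycl::event (*unary_strided_fn_ptr_t)(
        sycl::queue &, std::size_t, int, const py::ssize_t *, const char *,
        py::ssize_t, char *, py::ssize_t, const std::vector<sycl::event> &,
        const std::vector<sycl::event> &);

    static int abs_output_typeid_vector[td_ns::num_types];
    static unary_contig_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types];
    static unary_strided_fn_ptr_t abs_strided_dispatch_vector[td_ns::num_types];

    // populated once at extension load, e.g. with td_ns::DispatchVectorBuilder
    // over per-type kernel factories

    auto events = py_unary_ufunc(src, dst, exec_q, depends,
                                 abs_output_typeid_vector,
                                 abs_contig_dispatch_vector,
                                 abs_strided_dispatch_vector);
*/
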
template <typename output_typesT>
py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
                                      const output_typesT &output_types)
{
    int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src_typeid = array_types.typenum_to_lookup_id(tn);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    int dst_typeid = _result_typeid(src_typeid, output_types);
    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}

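/*
  Binding sketch (added for exposition; module and table names are
  hypothetical). A typical pybind11 registration pairs the compute entry
  point with its result-type query, which returns a dtype, or None when the
  input type is unsupported:

    m.def("_abs", ...); // forwards to py_unary_ufunc as sketched above
    m.def("_abs_result_type", [](const py::dtype &dtype) {
        return py_unary_ufunc_result_type(dtype, abs_output_typeid_vector);
    });
*/
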
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src,
                               const dpctl::tensor::usm_ndarray &dst1,
                               const dpctl::tensor::usm_ndarray &dst2,
                               sycl::queue &q,
                               const std::vector<sycl::event> &depends,
                               //
                               const output_typesT &output_type_vec,
                               const contig_dispatchT &contig_dispatch_vector,
                               const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst1_typenum = dst1.get_typenum();
    int dst2_typenum = dst2.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);

    std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];

    // check that types are supported
    if (dst1_typeid != func_output_typeids.first ||
        dst2_typeid != func_output_typeids.second) {
        throw py::value_error(
            "One of the destination arrays has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst1, dst2})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst1.get_ndim() || src_nd != dst2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
                       (src_shape[i] == dst2_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
                                                               src_nelems);
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
                                                               src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src, dst1) && !same_logical_tensors(src, dst1)) ||
        (overlap(src, dst2) && !same_logical_tensors(src, dst2)) ||
        (overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2))) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst1_data = dst1.get_data();
    char *dst2_data = dst2.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst1_c_contig = dst1.is_c_contiguous();
    bool is_dst1_f_contig = dst1.is_f_contiguous();

    bool is_dst2_c_contig = dst2.is_c_contiguous();
    bool is_dst2_f_contig = dst2.is_f_contiguous();

    bool all_c_contig =
        (is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);
    bool all_f_contig =
        (is_src_f_contig && is_dst1_f_contig && is_dst2_f_contig);

    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev =
            contig_fn(q, src_nelems, src_data, dst1_data, dst2_data, depends);
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space
    // if 1d with strides 1 - input is contig
    // dispatch to strided

    auto const &src_strides = src.get_strides_vector();
    auto const &dst1_strides = dst1.get_strides_vector();
    auto const &dst2_strides = dst2.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst1_strides;
    shT simplified_dst2_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst1_offset(0);
    py::ssize_t dst2_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space_3(
        nd, shape, src_strides, dst1_strides, dst2_strides,
        // output
        simplified_shape, simplified_src_strides, simplified_dst1_strides,
        simplified_dst2_strides, src_offset, dst1_offset, dst2_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst1_strides[0] == 1 && simplified_dst2_strides[0] == 1) {
        // Special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst1_elem_size = dst1.get_elemsize();
        int dst2_elem_size = dst2.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst1_data + dst1_elem_size * dst1_offset,
                      dst2_data + dst2_elem_size * dst2_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst1_strides, simplified_dst2_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        q, src_nelems, nd, shape_strides, src_data, src_offset, dst1_data,
        dst1_offset, dst2_data, dst2_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, host_tasks),
        strided_fn_ev);
}

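/*
  Sketch for the two-outputs variant (added for exposition). Here each entry
  of the output-type table is a std::pair<int, int>, and kernels write two
  results per element, as a hypothetical modf-style function would. The
  contiguous signature below is inferred from the call sites above; the
  dispatch vectors are declared analogously to the unary sketch:

    typedef sycl::event (*unary_two_outputs_contig_fn_ptr_t)(
        sycl::queue &, std::size_t, const char *, char *, char *,
        const std::vector<sycl::event> &);

    static std::pair<int, int> modf_output_typeid_vector[td_ns::num_types];

    auto events = py_unary_two_outputs_ufunc(
        src, dst_fractional, dst_integral, exec_q, depends,
        modf_output_typeid_vector, modf_contig_dispatch_vector,
        modf_strided_dispatch_vector);
*/
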
template <typename output_typesT>
std::pair<py::object, py::object>
    py_unary_two_outputs_ufunc_result_type(const py::dtype &input_dtype,
                                           const output_typesT &output_types)
{
    int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src_typeid = array_types.typenum_to_lookup_id(tn);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    std::pair<int, int> dst_typeids = _result_typeid(src_typeid, output_types);
    int dst1_typeid = dst_typeids.first;
    int dst2_typeid = dst_typeids.second;

    if (dst1_typeid < 0 || dst2_typeid < 0) {
        auto res = py::none();
        auto py_res = py::cast<py::object>(res);
        return std::make_pair(py_res, py_res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
        auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
        auto dt1 = _dtype_from_typenum(dst1_typenum_t);
        auto dt2 = _dtype_from_typenum(dst2_typenum_t);

        return std::make_pair(py::cast<py::object>(dt1),
                              py::cast<py::object>(dt2));
    }
}

// ======================== Binary functions ===========================

namespace
{
template <class Container, class T>
bool isEqual(Container const &c, std::initializer_list<T> const &l)
{
    return std::equal(std::begin(c), std::end(c), std::begin(l), std::end(l));
}
} // namespace

template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_matrix_row_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event> py_binary_ufunc(
    const dpctl::tensor::usm_ndarray &src1,
    const dpctl::tensor::usm_ndarray &src2,
    const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
    sycl::queue &exec_q,
    const std::vector<sycl::event> &depends,
    //
    const output_typesT &output_type_table,
    const contig_dispatchT &contig_dispatch_table,
    const strided_dispatchT &strided_dispatch_table,
    const contig_matrix_row_dispatchT
        &contig_matrix_row_broadcast_dispatch_table,
    const contig_row_matrix_dispatchT
        &contig_row_matrix_broadcast_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int output_typeid = output_type_table[src1_typeid][src2_typeid];

    if (output_typeid != dst_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) ||
        (overlap(src2, dst) && !same_logical_tensors(src2, dst))) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    bool all_f_contig =
        (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig);

    // dispatch for contiguous inputs
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0,
                                     src2_data, 0, dst_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst_offset(0);

    int nd = dst_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_3(
        nd, shape, src1_strides, src2_strides, dst_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst_strides, src1_offset, src2_offset, dst_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
            isEqual(simplified_src2_strides, unit_stride) &&
            isEqual(simplified_dst_strides, unit_stride)) {
            auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev = contig_fn(exec_q, src_nelems, src1_data,
                                         src1_offset, src2_data, src2_offset,
                                         dst_data, dst_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {src1, src2, dst}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto zero_one_strides =
                std::initializer_list<py::ssize_t>{0, 1};
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_src2_strides, zero_one_strides) &&
                isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
                isEqual(simplified_dst_strides, {simplified_shape[1], one})) {
                auto matrix_row_broadcast_fn =
                    contig_matrix_row_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (matrix_row_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize)) {
                        std::size_t n0 = simplified_shape[0];
                        std::size_t n1 = simplified_shape[1];
                        sycl::event comp_ev = matrix_row_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
            if (isEqual(simplified_src1_strides, one_zero_strides) &&
                isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
                isEqual(simplified_dst_strides, {one, simplified_shape[0]})) {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize)) {
                        std::size_t n0 = simplified_shape[1];
                        std::size_t n1 = simplified_shape[0];
                        sycl::event comp_ev = row_matrix_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
        src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks),
        strided_fn_ev);
}

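/*
  Usage sketch (added for exposition; names are hypothetical). The two extra
  tables let a binary function provide fast kernels for a C-contiguous matrix
  combined with a broadcast row; tables filled with nullptr are valid and
  simply fall through to the strided implementation. The contiguous signature
  below is inferred from the call sites above:

    typedef sycl::event (*binary_contig_fn_ptr_t)(
        sycl::queue &, std::size_t, const char *, py::ssize_t, const char *,
        py::ssize_t, char *, py::ssize_t, const std::vector<sycl::event> &);

    static int add_output_typeid_table[td_ns::num_types][td_ns::num_types];
    static binary_contig_fn_ptr_t
        add_contig_dispatch_table[td_ns::num_types][td_ns::num_types];

    auto events = py_binary_ufunc(
        src1, src2, dst, exec_q, depends, add_output_typeid_table,
        add_contig_dispatch_table, add_strided_dispatch_table,
        add_contig_matrix_row_dispatch_table,  // may hold only nullptr
        add_contig_row_matrix_dispatch_table); // may hold only nullptr
*/
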
template <typename output_typesT>
py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
                                       const py::dtype &input2_dtype,
                                       const output_typesT &output_types_table)
{
    int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
    int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src1_typeid = -1;
    int src2_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src1_typeid = array_types.typenum_to_lookup_id(tn1);
        src2_typeid = array_types.typenum_to_lookup_id(tn2);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
        src2_typeid >= td_ns::num_types) {
        throw std::runtime_error("binary output type lookup failed");
    }
    int dst_typeid = output_types_table[src1_typeid][src2_typeid];

    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}

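/*
  Binding sketch (added for exposition; hypothetical names). The binary
  result-type query takes both input dtypes and, like its unary counterpart,
  returns None for unsupported combinations:

    m.def("_add_result_type",
          [](const py::dtype &dt1, const py::dtype &dt2) {
              return py_binary_ufunc_result_type(dt1, dt2,
                                                 add_output_typeid_table);
          });
*/
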
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src1,
                                const dpctl::tensor::usm_ndarray &src2,
                                const dpctl::tensor::usm_ndarray &dst1,
                                const dpctl::tensor::usm_ndarray &dst2,
                                sycl::queue &exec_q,
                                const std::vector<sycl::event> &depends,
                                //
                                const output_typesT &output_types_table,
                                const contig_dispatchT &contig_dispatch_table,
                                const strided_dispatchT &strided_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst1_typenum = dst1.get_typenum();
    int dst2_typenum = dst2.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);

    std::pair<int, int> output_typeids =
        output_types_table[src1_typeid][src2_typeid];

    if (dst1_typeid != output_typeids.first ||
        dst2_typeid != output_typeids.second) {
        throw py::value_error(
            "One of the destination arrays has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q,
                                             {src1, src2, dst1, dst2})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int src1_nd = src1.get_ndim();
    int src2_nd = src2.get_ndim();
    int dst1_nd = dst1.get_ndim();
    int dst2_nd = dst2.get_ndim();

    if (dst1_nd != src1_nd || dst1_nd != src2_nd || dst1_nd != dst2_nd) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < dst1_nd; ++i) {
        const auto &sh_i = dst1_shape[i];
        src_nelems *= static_cast<std::size_t>(src1_shape[i]);
        shapes_equal =
            shapes_equal && (src1_shape[i] == sh_i && src2_shape[i] == sh_i &&
                             dst2_shape[i] == sh_i);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
                                                               src_nelems);
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
                                                               src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst1) && !same_logical_tensors(src1, dst1)) ||
        (overlap(src1, dst2) && !same_logical_tensors(src1, dst2)) ||
        (overlap(src2, dst1) && !same_logical_tensors(src2, dst1)) ||
        (overlap(src2, dst2) && !same_logical_tensors(src2, dst2)) ||
        (overlap(dst1, dst2))) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst1_data = dst1.get_data();
    char *dst2_data = dst2.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst1_c_contig = dst1.is_c_contiguous();
    bool is_dst1_f_contig = dst1.is_f_contiguous();

    bool is_dst2_c_contig = dst2.is_c_contiguous();
    bool is_dst2_f_contig = dst2.is_f_contiguous();

    bool all_c_contig = (is_src1_c_contig && is_src2_c_contig &&
                         is_dst1_c_contig && is_dst2_c_contig);
    bool all_f_contig = (is_src1_f_contig && is_src2_f_contig &&
                         is_dst1_f_contig && is_dst2_f_contig);

    // dispatch for contiguous inputs
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev =
                contig_fn(exec_q, src_nelems, src1_data, 0, src2_data, 0,
                          dst1_data, 0, dst2_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst1, dst2}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst1_strides = dst1.get_strides_vector();
    auto const &dst2_strides = dst2.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst1_strides;
    shT simplified_dst2_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst1_offset(0);
    py::ssize_t dst2_offset(0);

    int nd = dst1_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_4(
        nd, shape, src1_strides, src2_strides, dst1_strides, dst2_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst1_strides, simplified_dst2_strides, src1_offset,
        src2_offset, dst1_offset, dst2_offset);

    std::vector<sycl::event> host_tasks{};
    static constexpr auto unit_stride = std::initializer_list<py::ssize_t>{1};

    if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
        isEqual(simplified_src2_strides, unit_stride) &&
        isEqual(simplified_dst1_strides, unit_stride) &&
        isEqual(simplified_dst2_strides, unit_stride)) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev =
                contig_fn(exec_q, src_nelems, src1_data, src1_offset, src2_data,
                          src2_offset, dst1_data, dst1_offset, dst2_data,
                          dst2_offset, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst1, dst2}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst1_strides,
        simplified_dst2_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(exec_q, src_nelems, nd, shape_strides, src1_data,
                   src1_offset, src2_data, src2_offset, dst1_data, dst1_offset,
                   dst2_data, dst2_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);
    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(dpctl::utils::keep_args_alive(
                              exec_q, {src1, src2, dst1, dst2}, host_tasks),
                          strided_fn_ev);
}

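/*
  Sketch for the two-outputs binary variant (added for exposition). A
  hypothetical divmod-style function would fill a table of
  std::pair<int, int> entries and provide kernels that write quotient and
  remainder per element; the dispatch tables are declared analogously to
  the single-output binary sketch:

    static std::pair<int, int>
        divmod_output_typeid_table[td_ns::num_types][td_ns::num_types];

    auto events = py_binary_two_outputs_ufunc(
        src1, src2, dst_quotient, dst_remainder, exec_q, depends,
        divmod_output_typeid_table, divmod_contig_dispatch_table,
        divmod_strided_dispatch_table);
*/
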
template <typename output_typesT>
std::pair<py::object, py::object> py_binary_two_outputs_ufunc_result_type(
    const py::dtype &input1_dtype,
    const py::dtype &input2_dtype,
    const output_typesT &output_types_table)
{
    int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
    int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src1_typeid = -1;
    int src2_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src1_typeid = array_types.typenum_to_lookup_id(tn1);
        src2_typeid = array_types.typenum_to_lookup_id(tn2);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
        src2_typeid >= td_ns::num_types) {
        throw std::runtime_error("binary output type lookup failed");
    }
    std::pair<int, int> dst_typeids =
        output_types_table[src1_typeid][src2_typeid];
    int dst1_typeid = dst_typeids.first;
    int dst2_typeid = dst_typeids.second;

    if (dst1_typeid < 0 || dst2_typeid < 0) {
        auto res = py::none();
        auto py_res = py::cast<py::object>(res);
        return std::make_pair(py_res, py_res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
        auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
        auto dt1 = _dtype_from_typenum(dst1_typenum_t);
        auto dt2 = _dtype_from_typenum(dst2_typenum_t);

        return std::make_pair(py::cast<py::object>(dt1),
                              py::cast<py::object>(dt2));
    }
}

// ==================== Inplace binary functions =======================

template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
                            const dpctl::tensor::usm_ndarray &rhs,
                            sycl::queue &exec_q,
                            const std::vector<sycl::event> &depends,
                            //
                            const output_typesT &output_type_table,
                            const contig_dispatchT &contig_dispatch_table,
                            const strided_dispatchT &strided_dispatch_table,
                            const contig_row_matrix_dispatchT
                                &contig_row_matrix_broadcast_dispatch_table)
{
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs);

    // check type_nums
    int rhs_typenum = rhs.get_typenum();
    int lhs_typenum = lhs.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
    int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum);

    int output_typeid = output_type_table[rhs_typeid][lhs_typeid];

    if (output_typeid != lhs_typeid) {
        throw py::value_error(
            "Left-hand side array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {rhs, lhs})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int lhs_nd = lhs.get_ndim();
    if (lhs_nd != rhs.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *rhs_shape = rhs.get_shape_raw();
    const py::ssize_t *lhs_shape = lhs.get_shape_raw();
    bool shapes_equal(true);
    std::size_t rhs_nelems(1);

    for (int i = 0; i < lhs_nd; ++i) {
        rhs_nelems *= static_cast<std::size_t>(rhs_shape[i]);
        shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (rhs_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems);

    // check memory overlap
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *rhs_data = rhs.get_data();
    char *lhs_data = lhs.get_data();

    // handle contiguous inputs
    bool is_rhs_c_contig = rhs.is_c_contiguous();
    bool is_rhs_f_contig = rhs.is_f_contiguous();

    bool is_lhs_c_contig = lhs.is_c_contiguous();
    bool is_lhs_f_contig = lhs.is_f_contiguous();

    bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig);
    bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig);

    // dispatch for contiguous inputs
    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
                                     0, depends);
            sycl::event ht_ev =
                dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &rhs_strides = rhs.get_strides_vector();
    auto const &lhs_strides = lhs.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_rhs_strides;
    shT simplified_lhs_strides;
    py::ssize_t rhs_offset(0);
    py::ssize_t lhs_offset(0);

    int nd = lhs_nd;
    const py::ssize_t *shape = rhs_shape;

    simplify_iteration_space(nd, shape, rhs_strides, lhs_strides,
                             // outputs
                             simplified_shape, simplified_rhs_strides,
                             simplified_lhs_strides, rhs_offset, lhs_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) &&
            isEqual(simplified_lhs_strides, unit_stride)) {
            auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev =
                    contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset,
                              lhs_data, lhs_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {rhs, lhs}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_rhs_strides, one_zero_strides) &&
                isEqual(simplified_lhs_strides, {one, simplified_shape[0]})) {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                              [lhs_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    std::size_t n0 = simplified_shape[1];
                    std::size_t n1 = simplified_shape[0];
                    sycl::event comp_ev = row_matrix_broadcast_fn(
                        exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                        lhs_data, lhs_offset, depends);

                    return std::make_pair(dpctl::utils::keep_args_alive(
                                              exec_q, {lhs, rhs}, host_tasks),
                                          comp_ev);
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for rhs_typeid=" +
            std::to_string(rhs_typeid) +
            " and lhs_typeid=" + std::to_string(lhs_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
        simplified_lhs_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
                   lhs_data, lhs_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
        strided_fn_ev);
}
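
/*
  Usage sketch (added for exposition; hypothetical names). Note that the
  inplace tables are indexed [rhs_typeid][lhs_typeid], and that the result
  type of the operation must match the left-hand side, which is validated
  above:

    auto events = py_binary_inplace_ufunc(
        lhs, rhs, exec_q, depends, add_inplace_output_typeid_table,
        add_inplace_contig_dispatch_table, add_inplace_strided_dispatch_table,
        add_inplace_row_matrix_dispatch_table);
*/
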
} // namespace dpnp::extensions::py_internal