// DPNP C++ backend kernel library 0.20.0dev3
// Data Parallel Extension for NumPy*
// elementwise_functions.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
#include <algorithm>
#include <cstddef>
#include <exception>
#include <initializer_list>
#include <iterator>
#include <stdexcept>
#include <utility>
#include <vector>
37
38#include <pybind11/numpy.h>
39#include <pybind11/pybind11.h>
40
41#include <sycl/sycl.hpp>
42
43#include "dpnp4pybind11.hpp"
44
45#include "elementwise_functions_type_utils.hpp"
46#include "simplify_iteration_space.hpp"
47
48// dpctl tensor headers
49#include "kernels/alignment.hpp"
50#include "utils/memory_overlap.hpp"
51#include "utils/offset_utils.hpp"
52#include "utils/output_validation.hpp"
53#include "utils/sycl_alloc_utils.hpp"
54#include "utils/type_dispatch.hpp"
55
56static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);
57
58namespace dpnp::extensions::py_internal
59{
60namespace py = pybind11;
61namespace td_ns = dpctl::tensor::type_dispatch;
62
63using dpctl::tensor::kernels::alignment_utils::is_aligned;
64using dpctl::tensor::kernels::alignment_utils::required_alignment;
65
66using type_utils::_result_typeid;
67
69template <typename output_typesT,
70 typename contig_dispatchT,
71 typename strided_dispatchT>
72std::pair<sycl::event, sycl::event>
73 py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
75 sycl::queue &q,
76 const std::vector<sycl::event> &depends,
77 //
78 const output_typesT &output_type_vec,
79 const contig_dispatchT &contig_dispatch_vector,
80 const strided_dispatchT &strided_dispatch_vector)
81{
82 int src_typenum = src.get_typenum();
83 int dst_typenum = dst.get_typenum();
84
85 const auto &array_types = td_ns::usm_ndarray_types();
86 int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
87 int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
88
89 int func_output_typeid = output_type_vec[src_typeid];
90
91 // check that types are supported
92 if (dst_typeid != func_output_typeid) {
93 throw py::value_error(
94 "Destination array has unexpected elemental data type.");
95 }
96
97 // check that queues are compatible
98 if (!dpctl::utils::queues_are_compatible(q, {src, dst})) {
99 throw py::value_error(
100 "Execution queue is not compatible with allocation queues");
101 }
102
103 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
104
105 // check that dimensions are the same
106 int src_nd = src.get_ndim();
107 if (src_nd != dst.get_ndim()) {
108 throw py::value_error("Array dimensions are not the same.");
109 }
110
111 // check that shapes are the same
112 const py::ssize_t *src_shape = src.get_shape_raw();
113 const py::ssize_t *dst_shape = dst.get_shape_raw();
114 bool shapes_equal(true);
115 std::size_t src_nelems(1);
116
117 for (int i = 0; i < src_nd; ++i) {
118 src_nelems *= static_cast<std::size_t>(src_shape[i]);
119 shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
120 }
121 if (!shapes_equal) {
122 throw py::value_error("Array shapes are not the same.");
123 }
124
125 // if nelems is zero, return
126 if (src_nelems == 0) {
127 return std::make_pair(sycl::event(), sycl::event());
128 }
129
130 dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
131
132 // check memory overlap
133 auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
134 auto const &same_logical_tensors =
135 dpctl::tensor::overlap::SameLogicalTensors();
136 if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
137 throw py::value_error("Arrays index overlapping segments of memory");
138 }
139
140 const char *src_data = src.get_data();
141 char *dst_data = dst.get_data();
142
143 // handle contiguous inputs
144 bool is_src_c_contig = src.is_c_contiguous();
145 bool is_src_f_contig = src.is_f_contiguous();
146
147 bool is_dst_c_contig = dst.is_c_contiguous();
148 bool is_dst_f_contig = dst.is_f_contiguous();
149
150 bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
151 bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
152
153 if (both_c_contig || both_f_contig) {
154 auto contig_fn = contig_dispatch_vector[src_typeid];
155
156 if (contig_fn == nullptr) {
157 throw std::runtime_error(
158 "Contiguous implementation is missing for src_typeid=" +
159 std::to_string(src_typeid));
160 }
161
162 auto comp_ev = contig_fn(q, src_nelems, src_data, dst_data, depends);
163 sycl::event ht_ev =
164 dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});
165
166 return std::make_pair(ht_ev, comp_ev);
167 }
168
169 // simplify iteration space
170 // if 1d with strides 1 - input is contig
171 // dispatch to strided
172
173 auto const &src_strides = src.get_strides_vector();
174 auto const &dst_strides = dst.get_strides_vector();
175
176 using shT = std::vector<py::ssize_t>;
177 shT simplified_shape;
178 shT simplified_src_strides;
179 shT simplified_dst_strides;
180 py::ssize_t src_offset(0);
181 py::ssize_t dst_offset(0);
182
183 int nd = src_nd;
184 const py::ssize_t *shape = src_shape;
185
186 simplify_iteration_space(nd, shape, src_strides, dst_strides,
187 // output
188 simplified_shape, simplified_src_strides,
189 simplified_dst_strides, src_offset, dst_offset);
190
191 if (nd == 1 && simplified_src_strides[0] == 1 &&
192 simplified_dst_strides[0] == 1) {
193 // Special case of contiguous data
194 auto contig_fn = contig_dispatch_vector[src_typeid];
195
196 if (contig_fn == nullptr) {
197 throw std::runtime_error(
198 "Contiguous implementation is missing for src_typeid=" +
199 std::to_string(src_typeid));
200 }
201
202 int src_elem_size = src.get_elemsize();
203 int dst_elem_size = dst.get_elemsize();
204 auto comp_ev =
205 contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
206 dst_data + dst_elem_size * dst_offset, depends);
207
208 sycl::event ht_ev =
209 dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});
210
211 return std::make_pair(ht_ev, comp_ev);
212 }
213
214 // Strided implementation
215 auto strided_fn = strided_dispatch_vector[src_typeid];
216
217 if (strided_fn == nullptr) {
218 throw std::runtime_error(
219 "Strided implementation is missing for src_typeid=" +
220 std::to_string(src_typeid));
221 }
222
223 using dpctl::tensor::offset_utils::device_allocate_and_pack;
224
225 std::vector<sycl::event> host_tasks{};
226 host_tasks.reserve(2);
227
228 auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
229 q, host_tasks, simplified_shape, simplified_src_strides,
230 simplified_dst_strides);
231 auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
232 const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
233 const py::ssize_t *shape_strides = shape_strides_owner.get();
234
235 sycl::event strided_fn_ev =
236 strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
237 dst_data, dst_offset, depends, {copy_shape_ev});
238
239 // async free of shape_strides temporary
240 sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
241 q, {strided_fn_ev}, shape_strides_owner);
242
243 host_tasks.push_back(tmp_cleanup_ev);
244
245 return std::make_pair(
246 dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks),
247 strided_fn_ev);
248}
249
252template <typename output_typesT>
253py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
254 const output_typesT &output_types)
255{
256 int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
257 int src_typeid = -1;
258
259 auto array_types = td_ns::usm_ndarray_types();
260
261 try {
262 src_typeid = array_types.typenum_to_lookup_id(tn);
263 } catch (const std::exception &e) {
264 throw py::value_error(e.what());
265 }
266
267 int dst_typeid = _result_typeid(src_typeid, output_types);
268 if (dst_typeid < 0) {
269 auto res = py::none();
270 return py::cast<py::object>(res);
271 }
272 else {
273 using type_utils::_dtype_from_typenum;
274
275 auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
276 auto dt = _dtype_from_typenum(dst_typenum_t);
277
278 return py::cast<py::object>(dt);
279 }
280}
281
/**
 * @brief Python-exposed unary elementwise operation producing two outputs,
 * (dst1, dst2) = op(src), over USM ndarrays.
 *
 * Validates both destination type ids against @p output_type_vec, checks
 * queue compatibility, writability, matching dimensionality/shape and
 * pairwise memory overlap; then dispatches to a contiguous kernel when all
 * three arrays are C- or F-contiguous (or the simplified iteration space is
 * 1-d with unit strides), falling back to a strided kernel otherwise.
 *
 * @param src Input array.
 * @param dst1 First output array.
 * @param dst2 Second output array.
 * @param q Execution queue the kernels are submitted to.
 * @param depends Events the computation must wait for.
 * @param output_type_vec Per-source-typeid table of (dst1, dst2) type ids.
 * @param contig_dispatch_vector Per-typeid contiguous kernel functions.
 * @param strided_dispatch_vector Per-typeid strided kernel functions.
 * @return Pair of (host-task event keeping arrays alive, computation event).
 * @throws py::value_error on type/queue/shape/overlap validation failure.
 * @throws std::runtime_error when no kernel exists for the source type id.
 */
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src,
                               const dpctl::tensor::usm_ndarray &dst1,
                               const dpctl::tensor::usm_ndarray &dst2,
                               sycl::queue &q,
                               const std::vector<sycl::event> &depends,
                               //
                               const output_typesT &output_type_vec,
                               const contig_dispatchT &contig_dispatch_vector,
                               const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst1_typenum = dst1.get_typenum();
    int dst2_typenum = dst2.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);

    std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];

    // check that types are supported
    if (dst1_typeid != func_output_typeids.first ||
        dst2_typeid != func_output_typeids.second)
    {
        throw py::value_error(
            "One of destination arrays has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst1, dst2})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst1.get_ndim() || src_nd != dst2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same; accumulate element count as we go
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
                       (src_shape[i] == dst2_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return empty events — nothing to compute
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
                                                               src_nelems);
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
                                                               src_nelems);

    // check memory overlap, pairwise over all three arrays
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src, dst1) && !same_logical_tensors(src, dst1)) ||
        (overlap(src, dst2) && !same_logical_tensors(src, dst2)) ||
        (overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2)))
    {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst1_data = dst1.get_data();
    char *dst2_data = dst2.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst1_c_contig = dst1.is_c_contiguous();
    bool is_dst1_f_contig = dst1.is_f_contiguous();

    bool is_dst2_c_contig = dst2.is_c_contiguous();
    bool is_dst2_f_contig = dst2.is_f_contiguous();

    bool all_c_contig =
        (is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);
    bool all_f_contig =
        (is_src_f_contig && is_dst1_f_contig && is_dst2_f_contig);

    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev =
            contig_fn(q, src_nelems, src_data, dst1_data, dst2_data, depends);
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space
    // if 1d with strides 1 - input is contig
    // dispatch to strided

    auto const &src_strides = src.get_strides_vector();
    auto const &dst1_strides = dst1.get_strides_vector();
    auto const &dst2_strides = dst2.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst1_strides;
    shT simplified_dst2_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst1_offset(0);
    py::ssize_t dst2_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space_3(
        nd, shape, src_strides, dst1_strides, dst2_strides,
        // output
        simplified_shape, simplified_src_strides, simplified_dst1_strides,
        simplified_dst2_strides, src_offset, dst1_offset, dst2_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst1_strides[0] == 1 && simplified_dst2_strides[0] == 1)
    {
        // Special case of contiguous data (possibly at nonzero offsets)
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst1_elem_size = dst1.get_elemsize();
        int dst2_elem_size = dst2.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst1_data + dst1_elem_size * dst1_offset,
                      dst2_data + dst2_elem_size * dst2_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    // pack shape and strides into a single device allocation
    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst1_strides, simplified_dst2_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        q, src_nelems, nd, shape_strides, src_data, src_offset, dst1_data,
        dst1_offset, dst2_data, dst2_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, host_tasks),
        strided_fn_ev);
}
494
499template <typename output_typesT>
500std::pair<py::object, py::object>
501 py_unary_two_outputs_ufunc_result_type(const py::dtype &input_dtype,
502 const output_typesT &output_types)
503{
504 int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
505 int src_typeid = -1;
506
507 auto array_types = td_ns::usm_ndarray_types();
508
509 try {
510 src_typeid = array_types.typenum_to_lookup_id(tn);
511 } catch (const std::exception &e) {
512 throw py::value_error(e.what());
513 }
514
515 std::pair<int, int> dst_typeids = _result_typeid(src_typeid, output_types);
516 int dst1_typeid = dst_typeids.first;
517 int dst2_typeid = dst_typeids.second;
518
519 if (dst1_typeid < 0 || dst2_typeid < 0) {
520 auto res = py::none();
521 auto py_res = py::cast<py::object>(res);
522 return std::make_pair(py_res, py_res);
523 }
524 else {
525 using type_utils::_dtype_from_typenum;
526
527 auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
528 auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
529 auto dt1 = _dtype_from_typenum(dst1_typenum_t);
530 auto dt2 = _dtype_from_typenum(dst2_typenum_t);
531
532 return std::make_pair(py::cast<py::object>(dt1),
533 py::cast<py::object>(dt2));
534 }
535}
536
537// ======================== Binary functions ===========================
538
namespace
{
/// @brief Compares a container's elements against an initializer list.
///
/// Returns true only when @p c and @p l have the same length and equal
/// elements in the same order (the four-iterator std::equal overload also
/// compares range lengths).
///
/// NOTE(review): std::equal lives in <algorithm>, which this header did not
/// include directly; the include block above now lists it explicitly.
template <class Container, class T>
bool isEqual(Container const &c, std::initializer_list<T> const &l)
{
    return std::equal(std::begin(c), std::end(c), std::begin(l), std::end(l));
}
} // namespace
547
/**
 * @brief Python-exposed binary elementwise operation,
 * dst = op(src1, src2), over USM ndarrays.
 *
 * Broadcasting is assumed to have been performed by the caller, so all three
 * arrays must already share the same dimensionality and shape. After
 * validation the function dispatches, in order of preference, to:
 * a contiguous kernel (all arrays C- or F-contiguous, or the simplified
 * iteration space is 1-d with unit strides), a matrix-row / row-matrix
 * broadcast kernel (2-d special cases with suitably aligned data), or the
 * general strided kernel.
 *
 * @param src1 First input array.
 * @param src2 Second input array.
 * @param dst Output array; dst = op(src1, src2), elementwise.
 * @param exec_q Execution queue the kernels are submitted to.
 * @param depends Events the computation must wait for.
 * @param output_type_table 2-d table mapping (src1, src2) type ids to the
 *        expected output type id.
 * @param contig_dispatch_table 2-d table of contiguous kernel functions.
 * @param strided_dispatch_table 2-d table of strided kernel functions.
 * @param contig_matrix_row_broadcast_dispatch_table 2-d table of kernels for
 *        a C-contiguous matrix combined with a contiguous row.
 * @param contig_row_matrix_broadcast_dispatch_table 2-d table of kernels for
 *        a contiguous row combined with an F-ordered matrix.
 * @return Pair of (host-task event keeping arrays alive, computation event).
 * @throws py::value_error on type/queue/shape/overlap validation failure.
 * @throws std::runtime_error when no strided kernel exists for the type pair.
 */
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_matrix_row_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event> py_binary_ufunc(
    const dpctl::tensor::usm_ndarray &src1,
    const dpctl::tensor::usm_ndarray &src2,
    const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
    sycl::queue &exec_q,
    const std::vector<sycl::event> &depends,
    //
    const output_typesT &output_type_table,
    const contig_dispatchT &contig_dispatch_table,
    const strided_dispatchT &strided_dispatch_table,
    const contig_matrix_row_dispatchT
        &contig_matrix_row_broadcast_dispatch_table,
    const contig_row_matrix_dispatchT
        &contig_row_matrix_broadcast_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int output_typeid = output_type_table[src1_typeid][src2_typeid];

    if (output_typeid != dst_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same; accumulate element count as we go
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return empty events — nothing to compute
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) ||
        (overlap(src2, dst) && !same_logical_tensors(src2, dst)))
    {
        throw py::value_error("Arrays index overlapping segments of memory");
    }
    // check memory overlap
    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    bool all_f_contig =
        (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig);

    // dispatch for contiguous inputs
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        // nullptr table entry falls through to the strided path below
        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0,
                                     src2_data, 0, dst_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst_offset(0);

    int nd = dst_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_3(
        nd, shape, src1_strides, src2_strides, dst_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst_strides, src1_offset, src2_offset, dst_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        // 1-d with unit strides: use the contiguous kernel at the offsets
        if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
            isEqual(simplified_src2_strides, unit_stride) &&
            isEqual(simplified_dst_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev = contig_fn(exec_q, src_nelems, src1_data,
                                         src1_offset, src2_data, src2_offset,
                                         dst_data, dst_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {src1, src2, dst}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto zero_one_strides =
                std::initializer_list<py::ssize_t>{0, 1};
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_src2_strides, zero_one_strides) &&
                isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
                isEqual(simplified_dst_strides, {simplified_shape[1], one}))
            {
                auto matrix_row_broadcast_fn =
                    contig_matrix_row_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (matrix_row_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    // specialized kernel requires aligned base pointers
                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        std::size_t n0 = simplified_shape[0];
                        std::size_t n1 = simplified_shape[1];
                        sycl::event comp_ev = matrix_row_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
            // mirrored special case: row combined with an F-ordered matrix
            if (isEqual(simplified_src1_strides, one_zero_strides) &&
                isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
                isEqual(simplified_dst_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (row_matrix_broadcast_fn != nullptr) {

                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    // specialized kernel requires aligned base pointers
                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        std::size_t n0 = simplified_shape[1];
                        std::size_t n1 = simplified_shape[0];
                        sycl::event comp_ev = row_matrix_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    // pack shape and strides into a single device allocation
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
        src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks),
        strided_fn_ev);
}
821
823template <typename output_typesT>
824py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
825 const py::dtype &input2_dtype,
826 const output_typesT &output_types_table)
827{
828 int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
829 int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
830 int src1_typeid = -1;
831 int src2_typeid = -1;
832
833 auto array_types = td_ns::usm_ndarray_types();
834
835 try {
836 src1_typeid = array_types.typenum_to_lookup_id(tn1);
837 src2_typeid = array_types.typenum_to_lookup_id(tn2);
838 } catch (const std::exception &e) {
839 throw py::value_error(e.what());
840 }
841
842 if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
843 src2_typeid >= td_ns::num_types)
844 {
845 throw std::runtime_error("binary output type lookup failed");
846 }
847 int dst_typeid = output_types_table[src1_typeid][src2_typeid];
848
849 if (dst_typeid < 0) {
850 auto res = py::none();
851 return py::cast<py::object>(res);
852 }
853 else {
854 using type_utils::_dtype_from_typenum;
855
856 auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
857 auto dt = _dtype_from_typenum(dst_typenum_t);
858
859 return py::cast<py::object>(dt);
860 }
861}
862
865template <typename output_typesT,
866 typename contig_dispatchT,
867 typename strided_dispatchT>
868std::pair<sycl::event, sycl::event>
869 py_binary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src1,
870 const dpctl::tensor::usm_ndarray &src2,
871 const dpctl::tensor::usm_ndarray &dst1,
872 const dpctl::tensor::usm_ndarray &dst2,
873 sycl::queue &exec_q,
874 const std::vector<sycl::event> &depends,
875 //
876 const output_typesT &output_types_table,
877 const contig_dispatchT &contig_dispatch_table,
878 const strided_dispatchT &strided_dispatch_table)
879{
880 // check type_nums
881 int src1_typenum = src1.get_typenum();
882 int src2_typenum = src2.get_typenum();
883 int dst1_typenum = dst1.get_typenum();
884 int dst2_typenum = dst2.get_typenum();
885
886 auto array_types = td_ns::usm_ndarray_types();
887 int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
888 int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
889 int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
890 int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);
891
892 std::pair<int, int> output_typeids =
893 output_types_table[src1_typeid][src2_typeid];
894
895 if (dst1_typeid != output_typeids.first ||
896 dst2_typeid != output_typeids.second) {
897 throw py::value_error(
898 "One of destination arrays has unexpected elemental data type.");
899 }
900
901 // check that queues are compatible
902 if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst1, dst2}))
903 {
904 throw py::value_error(
905 "Execution queue is not compatible with allocation queues");
906 }
907
908 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
909 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);
910
911 // check shapes, broadcasting is assumed done by caller
912 // check that dimensions are the same
913 int src1_nd = src1.get_ndim();
914 int src2_nd = src2.get_ndim();
915 int dst1_nd = dst1.get_ndim();
916 int dst2_nd = dst2.get_ndim();
917
918 if (dst1_nd != src1_nd || dst1_nd != src2_nd || dst1_nd != dst2_nd) {
919 throw py::value_error("Array dimensions are not the same.");
920 }
921
922 // check that shapes are the same
923 const py::ssize_t *src1_shape = src1.get_shape_raw();
924 const py::ssize_t *src2_shape = src2.get_shape_raw();
925 const py::ssize_t *dst1_shape = dst1.get_shape_raw();
926 const py::ssize_t *dst2_shape = dst2.get_shape_raw();
927 bool shapes_equal(true);
928 std::size_t src_nelems(1);
929
930 for (int i = 0; i < dst1_nd; ++i) {
931 const auto &sh_i = dst1_shape[i];
932 src_nelems *= static_cast<std::size_t>(src1_shape[i]);
933 shapes_equal =
934 shapes_equal && (src1_shape[i] == sh_i && src2_shape[i] == sh_i &&
935 dst2_shape[i] == sh_i);
936 }
937 if (!shapes_equal) {
938 throw py::value_error("Array shapes are not the same.");
939 }
940
941 // if nelems is zero, return
942 if (src_nelems == 0) {
943 return std::make_pair(sycl::event(), sycl::event());
944 }
945
946 dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
947 src_nelems);
948 dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
949 src_nelems);
950
951 // check memory overlap
952 auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
953 auto const &same_logical_tensors =
954 dpctl::tensor::overlap::SameLogicalTensors();
955 if ((overlap(src1, dst1) && !same_logical_tensors(src1, dst1)) ||
956 (overlap(src1, dst2) && !same_logical_tensors(src1, dst2)) ||
957 (overlap(src2, dst1) && !same_logical_tensors(src2, dst1)) ||
958 (overlap(src2, dst2) && !same_logical_tensors(src2, dst2)) ||
959 (overlap(dst1, dst2)))
960 {
961 throw py::value_error("Arrays index overlapping segments of memory");
962 }
963
964 const char *src1_data = src1.get_data();
965 const char *src2_data = src2.get_data();
966 char *dst1_data = dst1.get_data();
967 char *dst2_data = dst2.get_data();
968
969 // handle contiguous inputs
970 bool is_src1_c_contig = src1.is_c_contiguous();
971 bool is_src1_f_contig = src1.is_f_contiguous();
972
973 bool is_src2_c_contig = src2.is_c_contiguous();
974 bool is_src2_f_contig = src2.is_f_contiguous();
975
976 bool is_dst1_c_contig = dst1.is_c_contiguous();
977 bool is_dst1_f_contig = dst1.is_f_contiguous();
978
979 bool is_dst2_c_contig = dst2.is_c_contiguous();
980 bool is_dst2_f_contig = dst2.is_f_contiguous();
981
982 bool all_c_contig = (is_src1_c_contig && is_src2_c_contig &&
983 is_dst1_c_contig && is_dst2_c_contig);
984 bool all_f_contig = (is_src1_f_contig && is_src2_f_contig &&
985 is_dst1_f_contig && is_dst2_f_contig);
986
987 // dispatch for contiguous inputs
988 if (all_c_contig || all_f_contig) {
989 auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];
990
991 if (contig_fn != nullptr) {
992 auto comp_ev =
993 contig_fn(exec_q, src_nelems, src1_data, 0, src2_data, 0,
994 dst1_data, 0, dst2_data, 0, depends);
995 sycl::event ht_ev = dpctl::utils::keep_args_alive(
996 exec_q, {src1, src2, dst1, dst2}, {comp_ev});
997
998 return std::make_pair(ht_ev, comp_ev);
999 }
1000 }
1001
1002 // simplify strides
1003 auto const &src1_strides = src1.get_strides_vector();
1004 auto const &src2_strides = src2.get_strides_vector();
1005 auto const &dst1_strides = dst1.get_strides_vector();
1006 auto const &dst2_strides = dst2.get_strides_vector();
1007
1008 using shT = std::vector<py::ssize_t>;
1009 shT simplified_shape;
1010 shT simplified_src1_strides;
1011 shT simplified_src2_strides;
1012 shT simplified_dst1_strides;
1013 shT simplified_dst2_strides;
1014 py::ssize_t src1_offset(0);
1015 py::ssize_t src2_offset(0);
1016 py::ssize_t dst1_offset(0);
1017 py::ssize_t dst2_offset(0);
1018
1019 int nd = dst1_nd;
1020 const py::ssize_t *shape = src1_shape;
1021
1022 simplify_iteration_space_4(
1023 nd, shape, src1_strides, src2_strides, dst1_strides, dst2_strides,
1024 // outputs
1025 simplified_shape, simplified_src1_strides, simplified_src2_strides,
1026 simplified_dst1_strides, simplified_dst2_strides, src1_offset,
1027 src2_offset, dst1_offset, dst2_offset);
1028
1029 std::vector<sycl::event> host_tasks{};
1030 static constexpr auto unit_stride = std::initializer_list<py::ssize_t>{1};
1031
1032 if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
1033 isEqual(simplified_src2_strides, unit_stride) &&
1034 isEqual(simplified_dst1_strides, unit_stride) &&
1035 isEqual(simplified_dst2_strides, unit_stride))
1036 {
1037 auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];
1038
1039 if (contig_fn != nullptr) {
1040 auto comp_ev =
1041 contig_fn(exec_q, src_nelems, src1_data, src1_offset, src2_data,
1042 src2_offset, dst1_data, dst1_offset, dst2_data,
1043 dst2_offset, depends);
1044 sycl::event ht_ev = dpctl::utils::keep_args_alive(
1045 exec_q, {src1, src2, dst1, dst2}, {comp_ev});
1046
1047 return std::make_pair(ht_ev, comp_ev);
1048 }
1049 }
1050
1051 // dispatch to strided code
1052 auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];
1053
1054 if (strided_fn == nullptr) {
1055 throw std::runtime_error(
1056 "Strided implementation is missing for src1_typeid=" +
1057 std::to_string(src1_typeid) +
1058 " and src2_typeid=" + std::to_string(src2_typeid));
1059 }
1060
1061 using dpctl::tensor::offset_utils::device_allocate_and_pack;
1062 auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
1063 exec_q, host_tasks, simplified_shape, simplified_src1_strides,
1064 simplified_src2_strides, simplified_dst1_strides,
1065 simplified_dst2_strides);
1066 auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
1067 auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);
1068
1069 const py::ssize_t *shape_strides = shape_strides_owner.get();
1070
1071 sycl::event strided_fn_ev =
1072 strided_fn(exec_q, src_nelems, nd, shape_strides, src1_data,
1073 src1_offset, src2_data, src2_offset, dst1_data, dst1_offset,
1074 dst2_data, dst2_offset, depends, {copy_shape_ev});
1075
1076 // async free of shape_strides temporary
1077 sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
1078 exec_q, {strided_fn_ev}, shape_strides_owner);
1079 host_tasks.push_back(tmp_cleanup_ev);
1080
1081 return std::make_pair(dpctl::utils::keep_args_alive(
1082 exec_q, {src1, src2, dst1, dst2}, host_tasks),
1083 strided_fn_ev);
1084}
1085
1090template <typename output_typesT>
1091std::pair<py::object, py::object> py_binary_two_outputs_ufunc_result_type(
1092 const py::dtype &input1_dtype,
1093 const py::dtype &input2_dtype,
1094 const output_typesT &output_types_table)
1095{
1096 int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
1097 int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
1098 int src1_typeid = -1;
1099 int src2_typeid = -1;
1100
1101 auto array_types = td_ns::usm_ndarray_types();
1102
1103 try {
1104 src1_typeid = array_types.typenum_to_lookup_id(tn1);
1105 src2_typeid = array_types.typenum_to_lookup_id(tn2);
1106 } catch (const std::exception &e) {
1107 throw py::value_error(e.what());
1108 }
1109
1110 if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
1111 src2_typeid >= td_ns::num_types)
1112 {
1113 throw std::runtime_error("binary output type lookup failed");
1114 }
1115 std::pair<int, int> dst_typeids =
1116 output_types_table[src1_typeid][src2_typeid];
1117 int dst1_typeid = dst_typeids.first;
1118 int dst2_typeid = dst_typeids.second;
1119
1120 if (dst1_typeid < 0 || dst2_typeid < 0) {
1121 auto res = py::none();
1122 auto py_res = py::cast<py::object>(res);
1123 return std::make_pair(py_res, py_res);
1124 }
1125 else {
1126 using type_utils::_dtype_from_typenum;
1127
1128 auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
1129 auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
1130 auto dt1 = _dtype_from_typenum(dst1_typenum_t);
1131 auto dt2 = _dtype_from_typenum(dst2_typenum_t);
1132
1133 return std::make_pair(py::cast<py::object>(dt1),
1134 py::cast<py::object>(dt2));
1135 }
1136}
1137
1138// ==================== Inplace binary functions =======================
1139
/**
 * @brief Applies a binary elementwise operation in place, writing the
 * result into @p lhs.
 *
 * Validates dtypes, queue compatibility, shapes (broadcasting is assumed
 * already done by the caller), writability, and memory overlap, then
 * dispatches to a contiguous, row/matrix-broadcast, or strided kernel
 * selected from the supplied tables.
 *
 * @param lhs  Array updated in place; must be writable and have the dtype
 *     that @p output_type_table prescribes for the (rhs, lhs) dtype pair.
 * @param rhs  Read-only second operand; must match @p lhs in rank and shape.
 * @param exec_q  SYCL queue the kernels are submitted to.
 * @param depends  Events the computation must wait for.
 * @param output_type_table  Maps (rhs, lhs) lookup ids to the expected
 *     result type id.
 * @param contig_dispatch_table  Kernels for both-contiguous operands.
 * @param strided_dispatch_table  Kernels for the generic strided case.
 * @param contig_row_matrix_broadcast_dispatch_table  Kernels for the special
 *     case of a contiguous row broadcast against a contiguous matrix.
 * @return Pair (host-task event keeping arguments alive, computation event).
 * @throws py::value_error  on dtype/queue/shape/overlap validation failure.
 * @throws std::runtime_error  if no strided implementation is available for
 *     the given type pair.
 */
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
                            const dpctl::tensor::usm_ndarray &rhs,
                            sycl::queue &exec_q,
                            const std::vector<sycl::event> &depends,
                            //
                            const output_typesT &output_type_table,
                            const contig_dispatchT &contig_dispatch_table,
                            const strided_dispatchT &strided_dispatch_table,
                            const contig_row_matrix_dispatchT
                                &contig_row_matrix_broadcast_dispatch_table)
{
    // lhs is overwritten, so it must be writable
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs);

    // check type_nums
    int rhs_typenum = rhs.get_typenum();
    int lhs_typenum = lhs.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
    int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum);

    int output_typeid = output_type_table[rhs_typeid][lhs_typeid];

    // for an in-place operation the result type must be lhs's own type
    if (output_typeid != lhs_typeid) {
        throw py::value_error(
            "Left-hand side array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {rhs, lhs})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int lhs_nd = lhs.get_ndim();
    if (lhs_nd != rhs.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *rhs_shape = rhs.get_shape_raw();
    const py::ssize_t *lhs_shape = lhs.get_shape_raw();
    bool shapes_equal(true);
    std::size_t rhs_nelems(1);

    // accumulate the element count while comparing shapes dimension-wise
    for (int i = 0; i < lhs_nd; ++i) {
        rhs_nelems *= static_cast<std::size_t>(rhs_shape[i]);
        shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (rhs_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    // lhs must have room for rhs_nelems elements
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems);

    // check memory overlap: partial overlap is an error, but operating on
    // the very same logical tensor (lhs aliasing rhs exactly) is allowed
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }
    // raw USM data pointers
    const char *rhs_data = rhs.get_data();
    char *lhs_data = lhs.get_data();

    // handle contiguous inputs
    bool is_rhs_c_contig = rhs.is_c_contiguous();
    bool is_rhs_f_contig = rhs.is_f_contiguous();

    bool is_lhs_c_contig = lhs.is_c_contiguous();
    bool is_lhs_f_contig = lhs.is_f_contiguous();

    bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig);
    bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig);

    // dispatch for contiguous inputs
    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

        // a null entry means no contiguous kernel; fall through to the
        // strided path below
        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
                                     0, depends);
            sycl::event ht_ev =
                dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &rhs_strides = rhs.get_strides_vector();
    auto const &lhs_strides = lhs.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_rhs_strides;
    shT simplified_lhs_strides;
    py::ssize_t rhs_offset(0);
    py::ssize_t lhs_offset(0);

    int nd = lhs_nd;
    const py::ssize_t *shape = rhs_shape;

    // collapse the iteration space to the fewest dimensions possible;
    // nd, shapes and strides are rewritten into the simplified_* outputs
    simplify_iteration_space(nd, shape, rhs_strides, lhs_strides,
                             // outputs
                             simplified_shape, simplified_rhs_strides,
                             simplified_lhs_strides, rhs_offset, lhs_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        // after simplification the arrays may have become effectively 1-D
        // unit-stride; reuse the contiguous kernel with offsets if so
        if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) &&
            isEqual(simplified_lhs_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev =
                    contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset,
                              lhs_data, lhs_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {rhs, lhs}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_rhs_strides, one_zero_strides) &&
                isEqual(simplified_lhs_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                              [lhs_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    std::size_t n0 = simplified_shape[1];
                    std::size_t n1 = simplified_shape[0];
                    sycl::event comp_ev = row_matrix_broadcast_fn(
                        exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                        lhs_data, lhs_offset, depends);

                    return std::make_pair(dpctl::utils::keep_args_alive(
                                              exec_q, {lhs, rhs}, host_tasks),
                                          comp_ev);
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for rhs_typeid=" +
            std::to_string(rhs_typeid) +
            " and lhs_typeid=" + std::to_string(lhs_typeid));
    }

    // pack simplified shape and strides into a single device allocation;
    // the kernel must wait on copy_shape_ev before reading it
    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
        simplified_lhs_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
                   lhs_data, lhs_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
        strided_fn_ev);
}
1340} // namespace dpnp::extensions::py_internal