// DPNP C++ backend kernel library 0.20.0dev4
// Data Parallel Extension for NumPy*
// elementwise_functions.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <cstddef>
32#include <exception>
33#include <stdexcept>
34#include <utility>
35#include <vector>
36
37#include <sycl/sycl.hpp>
38
39#include "dpctl4pybind11.hpp"
40#include <pybind11/numpy.h>
41#include <pybind11/pybind11.h>
42#include <pybind11/stl.h>
43
44#include "elementwise_functions_type_utils.hpp"
45#include "simplify_iteration_space.hpp"
46
47// dpctl tensor headers
48#include "kernels/alignment.hpp"
49#include "utils/memory_overlap.hpp"
50#include "utils/offset_utils.hpp"
51#include "utils/output_validation.hpp"
52#include "utils/sycl_alloc_utils.hpp"
53#include "utils/type_dispatch.hpp"
54
55static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);
56
57namespace dpnp::extensions::py_internal
58{
59namespace py = pybind11;
60namespace td_ns = dpctl::tensor::type_dispatch;
61
62using dpctl::tensor::kernels::alignment_utils::is_aligned;
63using dpctl::tensor::kernels::alignment_utils::required_alignment;
64
65using type_utils::_result_typeid;
66
/**
 * @brief Host-side driver for a unary elementwise operation, dst = op(src).
 *
 * Validates argument compatibility (result type, queue compatibility,
 * writability, dimensions, shapes, destination capacity, memory overlap),
 * then dispatches to either a contiguous or a strided device implementation
 * selected by the source array's type id.
 *
 * @param src Input array.
 * @param dst Output array; its type must match the table entry for src's type.
 * @param q Execution queue; must be compatible with both allocation queues.
 * @param depends Events the submitted kernels must wait on.
 * @param output_type_vec Lookup: src type id -> expected dst type id.
 * @param contig_dispatch_vector Lookup: src type id -> contiguous kernel.
 * @param strided_dispatch_vector Lookup: src type id -> strided kernel.
 * @return Pair of (host-task event keeping Python args alive, computation
 *         event).
 * @throws py::value_error on any validation failure.
 * @throws std::runtime_error when no kernel exists for the source type.
 */
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
                   const dpctl::tensor::usm_ndarray &dst,
                   sycl::queue &q,
                   const std::vector<sycl::event> &depends,
                   //
                   const output_typesT &output_type_vec,
                   const contig_dispatchT &contig_dispatch_vector,
                   const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int func_output_typeid = output_type_vec[src_typeid];

    // check that types are supported
    if (dst_typeid != func_output_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same; accumulate the element count on the way
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return trivial (already complete) events
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap; overlap is allowed only when src and dst are
    // views of the same logical tensor (in-place application)
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
    bool both_f_contig = (is_src_f_contig && is_dst_f_contig);

    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev = contig_fn(q, src_nelems, src_data, dst_data, depends);
        // host task keeps the Python arrays alive until computation completes
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space
    // if 1d with strides 1 - input is contig
    // dispatch to strided

    auto const &src_strides = src.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space(nd, shape, src_strides, dst_strides,
                             // output
                             simplified_shape, simplified_src_strides,
                             simplified_dst_strides, src_offset, dst_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst_strides[0] == 1) {
        // Special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        // offsets are in elements; convert to bytes for the raw char pointers
        int src_elem_size = src.get_elemsize();
        int dst_elem_size = dst.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst_data + dst_elem_size * dst_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    // pack shape and both stride vectors into one device-accessible buffer
    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
                   dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks),
        strided_fn_ev);
}
248
251template <typename output_typesT>
252py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
253 const output_typesT &output_types)
254{
255 int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
256 int src_typeid = -1;
257
258 auto array_types = td_ns::usm_ndarray_types();
259
260 try {
261 src_typeid = array_types.typenum_to_lookup_id(tn);
262 } catch (const std::exception &e) {
263 throw py::value_error(e.what());
264 }
265
266 int dst_typeid = _result_typeid(src_typeid, output_types);
267 if (dst_typeid < 0) {
268 auto res = py::none();
269 return py::cast<py::object>(res);
270 }
271 else {
272 using type_utils::_dtype_from_typenum;
273
274 auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
275 auto dt = _dtype_from_typenum(dst_typenum_t);
276
277 return py::cast<py::object>(dt);
278 }
279}
280
/**
 * @brief Host-side driver for a unary elementwise operation producing two
 *        outputs, (dst1, dst2) = op(src) — presumably for ops such as
 *        modf/frexp-style functions; confirm against callers.
 *
 * Mirrors py_unary_ufunc: validates result types, queue compatibility,
 * writability, dimensions, shapes, destination capacity and memory overlap,
 * then dispatches to a contiguous or strided kernel keyed by src's type id.
 *
 * @param src Input array.
 * @param dst1 First output array; type must match the table's first entry.
 * @param dst2 Second output array; type must match the table's second entry.
 * @param q Execution queue; must be compatible with all allocation queues.
 * @param depends Events the submitted kernels must wait on.
 * @param output_type_vec Lookup: src type id -> pair of expected dst type ids.
 * @param contig_dispatch_vector Lookup: src type id -> contiguous kernel.
 * @param strided_dispatch_vector Lookup: src type id -> strided kernel.
 * @return Pair of (host-task event keeping Python args alive, computation
 *         event).
 * @throws py::value_error on any validation failure.
 * @throws std::runtime_error when no kernel exists for the source type.
 */
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src,
                               const dpctl::tensor::usm_ndarray &dst1,
                               const dpctl::tensor::usm_ndarray &dst2,
                               sycl::queue &q,
                               const std::vector<sycl::event> &depends,
                               //
                               const output_typesT &output_type_vec,
                               const contig_dispatchT &contig_dispatch_vector,
                               const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst1_typenum = dst1.get_typenum();
    int dst2_typenum = dst2.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);

    std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];

    // check that types are supported
    if (dst1_typeid != func_output_typeids.first ||
        dst2_typeid != func_output_typeids.second) {
        throw py::value_error(
            "One of destination arrays has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst1, dst2})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst1.get_ndim() || src_nd != dst2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same; accumulate the element count on the way
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
                       (src_shape[i] == dst2_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return trivial (already complete) events
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
                                                               src_nelems);
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
                                                               src_nelems);

    // check memory overlap; overlap between any pair is allowed only when
    // the two arrays are views of the same logical tensor
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src, dst1) && !same_logical_tensors(src, dst1)) ||
        (overlap(src, dst2) && !same_logical_tensors(src, dst2)) ||
        (overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2))) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst1_data = dst1.get_data();
    char *dst2_data = dst2.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst1_c_contig = dst1.is_c_contiguous();
    bool is_dst1_f_contig = dst1.is_f_contiguous();

    bool is_dst2_c_contig = dst2.is_c_contiguous();
    bool is_dst2_f_contig = dst2.is_f_contiguous();

    bool all_c_contig =
        (is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);
    bool all_f_contig =
        (is_src_f_contig && is_dst1_f_contig && is_dst2_f_contig);

    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev =
            contig_fn(q, src_nelems, src_data, dst1_data, dst2_data, depends);
        // host task keeps the Python arrays alive until computation completes
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space
    // if 1d with strides 1 - input is contig
    // dispatch to strided

    auto const &src_strides = src.get_strides_vector();
    auto const &dst1_strides = dst1.get_strides_vector();
    auto const &dst2_strides = dst2.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst1_strides;
    shT simplified_dst2_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst1_offset(0);
    py::ssize_t dst2_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space_3(
        nd, shape, src_strides, dst1_strides, dst2_strides,
        // output
        simplified_shape, simplified_src_strides, simplified_dst1_strides,
        simplified_dst2_strides, src_offset, dst1_offset, dst2_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst1_strides[0] == 1 && simplified_dst2_strides[0] == 1) {
        // Special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        // offsets are in elements; convert to bytes for the raw char pointers
        int src_elem_size = src.get_elemsize();
        int dst1_elem_size = dst1.get_elemsize();
        int dst2_elem_size = dst2.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst1_data + dst1_elem_size * dst1_offset,
                      dst2_data + dst2_elem_size * dst2_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    // pack shape and all stride vectors into one device-accessible buffer
    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst1_strides, simplified_dst2_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        q, src_nelems, nd, shape_strides, src_data, src_offset, dst1_data,
        dst1_offset, dst2_data, dst2_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, host_tasks),
        strided_fn_ev);
}
490
495template <typename output_typesT>
496std::pair<py::object, py::object>
497 py_unary_two_outputs_ufunc_result_type(const py::dtype &input_dtype,
498 const output_typesT &output_types)
499{
500 int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
501 int src_typeid = -1;
502
503 auto array_types = td_ns::usm_ndarray_types();
504
505 try {
506 src_typeid = array_types.typenum_to_lookup_id(tn);
507 } catch (const std::exception &e) {
508 throw py::value_error(e.what());
509 }
510
511 std::pair<int, int> dst_typeids = _result_typeid(src_typeid, output_types);
512 int dst1_typeid = dst_typeids.first;
513 int dst2_typeid = dst_typeids.second;
514
515 if (dst1_typeid < 0 || dst2_typeid < 0) {
516 auto res = py::none();
517 auto py_res = py::cast<py::object>(res);
518 return std::make_pair(py_res, py_res);
519 }
520 else {
521 using type_utils::_dtype_from_typenum;
522
523 auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
524 auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
525 auto dt1 = _dtype_from_typenum(dst1_typenum_t);
526 auto dt2 = _dtype_from_typenum(dst2_typenum_t);
527
528 return std::make_pair(py::cast<py::object>(dt1),
529 py::cast<py::object>(dt2));
530 }
531}
532
533// ======================== Binary functions ===========================
534
namespace
{
// Returns true iff container `c` holds exactly the elements of `l`,
// element-by-element in the same order (sizes must match too).
template <class Container, class T>
bool isEqual(Container const &c, std::initializer_list<T> const &l)
{
    auto c_it = std::begin(c);
    auto c_end = std::end(c);
    auto l_it = std::begin(l);
    auto l_end = std::end(l);

    for (; c_it != c_end && l_it != l_end; ++c_it, ++l_it) {
        if (!(*c_it == *l_it)) {
            return false;
        }
    }
    // equal only if both sequences were exhausted simultaneously
    return (c_it == c_end) && (l_it == l_end);
}
} // namespace
543
/**
 * @brief Host-side driver for a binary elementwise operation,
 *        dst = op(src1, src2). Broadcasting is assumed to have been done by
 *        the caller; all three arrays must already have identical shapes.
 *
 * Validates result type, queue compatibility, writability, dimensions,
 * shapes, destination capacity and memory overlap, then dispatches in order
 * of preference: contiguous kernel, matrix-row / row-matrix broadcast
 * kernels (2D special cases), and finally the general strided kernel.
 *
 * @param src1 First input array.
 * @param src2 Second input array.
 * @param dst Output array; type must match the table entry for the inputs.
 * @param exec_q Execution queue; must be compatible with allocation queues.
 * @param depends Events the submitted kernels must wait on.
 * @param output_type_table Lookup: (src1, src2) type ids -> dst type id.
 * @param contig_dispatch_table Lookup: type ids -> contiguous kernel.
 * @param strided_dispatch_table Lookup: type ids -> strided kernel.
 * @param contig_matrix_row_broadcast_dispatch_table Lookup: type ids ->
 *        kernel for C-contiguous matrix (src1) with broadcast row (src2).
 * @param contig_row_matrix_broadcast_dispatch_table Lookup: type ids ->
 *        kernel for broadcast row (src1) with F-ordered matrix (src2).
 * @return Pair of (host-task event keeping Python args alive, computation
 *         event).
 * @throws py::value_error on any validation failure.
 * @throws std::runtime_error when no strided kernel exists for the types.
 */
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_matrix_row_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event> py_binary_ufunc(
    const dpctl::tensor::usm_ndarray &src1,
    const dpctl::tensor::usm_ndarray &src2,
    const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
    sycl::queue &exec_q,
    const std::vector<sycl::event> &depends,
    //
    const output_typesT &output_type_table,
    const contig_dispatchT &contig_dispatch_table,
    const strided_dispatchT &strided_dispatch_table,
    const contig_matrix_row_dispatchT
        &contig_matrix_row_broadcast_dispatch_table,
    const contig_row_matrix_dispatchT
        &contig_row_matrix_broadcast_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int output_typeid = output_type_table[src1_typeid][src2_typeid];

    if (output_typeid != dst_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same; accumulate the element count on the way
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return trivial (already complete) events
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap; an input may overlap the output only when they
    // are views of the same logical tensor (in-place application)
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) ||
        (overlap(src2, dst) && !same_logical_tensors(src2, dst))) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    bool all_f_contig =
        (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig);

    // dispatch for contiguous inputs; note: unlike the unary driver, a
    // missing contiguous kernel here falls through to the strided path
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0,
                                     src2_data, 0, dst_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst_offset(0);

    int nd = dst_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_3(
        nd, shape, src1_strides, src2_strides, dst_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst_strides, src1_offset, src2_offset, dst_offset);

    std::vector<sycl::event> host_tasks{};
    // low-dimensional special cases discovered after stride simplification
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        // 1d with unit strides everywhere: contiguous data with offsets
        if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
            isEqual(simplified_src2_strides, unit_stride) &&
            isEqual(simplified_dst_strides, unit_stride)) {
            auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev = contig_fn(exec_q, src_nelems, src1_data,
                                         src1_offset, src2_data, src2_offset,
                                         dst_data, dst_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {src1, src2, dst}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto zero_one_strides =
                std::initializer_list<py::ssize_t>{0, 1};
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_src2_strides, zero_one_strides) &&
                isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
                isEqual(simplified_dst_strides, {simplified_shape[1], one})) {
                auto matrix_row_broadcast_fn =
                    contig_matrix_row_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (matrix_row_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    // the specialized kernel requires aligned base pointers;
                    // otherwise fall through to the strided implementation
                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize)) {
                        std::size_t n0 = simplified_shape[0];
                        std::size_t n1 = simplified_shape[1];
                        sycl::event comp_ev = matrix_row_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
            // special case of a row broadcast against an F-ordered matrix
            if (isEqual(simplified_src1_strides, one_zero_strides) &&
                isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
                isEqual(simplified_dst_strides, {one, simplified_shape[0]})) {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (row_matrix_broadcast_fn != nullptr) {

                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    // same alignment requirement as the matrix-row case
                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize)) {
                        // note the axes swap relative to the matrix-row case
                        std::size_t n0 = simplified_shape[1];
                        std::size_t n1 = simplified_shape[0];
                        sycl::event comp_ev = row_matrix_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    // pack shape and all stride vectors into one device-accessible buffer
    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
        src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks),
        strided_fn_ev);
}
811
813template <typename output_typesT>
814py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
815 const py::dtype &input2_dtype,
816 const output_typesT &output_types_table)
817{
818 int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
819 int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
820 int src1_typeid = -1;
821 int src2_typeid = -1;
822
823 auto array_types = td_ns::usm_ndarray_types();
824
825 try {
826 src1_typeid = array_types.typenum_to_lookup_id(tn1);
827 src2_typeid = array_types.typenum_to_lookup_id(tn2);
828 } catch (const std::exception &e) {
829 throw py::value_error(e.what());
830 }
831
832 if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
833 src2_typeid >= td_ns::num_types) {
834 throw std::runtime_error("binary output type lookup failed");
835 }
836 int dst_typeid = output_types_table[src1_typeid][src2_typeid];
837
838 if (dst_typeid < 0) {
839 auto res = py::none();
840 return py::cast<py::object>(res);
841 }
842 else {
843 using type_utils::_dtype_from_typenum;
844
845 auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
846 auto dt = _dtype_from_typenum(dst_typenum_t);
847
848 return py::cast<py::object>(dt);
849 }
850}
851
854template <typename output_typesT,
855 typename contig_dispatchT,
856 typename strided_dispatchT>
857std::pair<sycl::event, sycl::event>
    py_binary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src1,
                                const dpctl::tensor::usm_ndarray &src2,
                                const dpctl::tensor::usm_ndarray &dst1,
                                const dpctl::tensor::usm_ndarray &dst2,
                                sycl::queue &exec_q,
                                const std::vector<sycl::event> &depends,
                                //
                                const output_typesT &output_types_table,
                                const contig_dispatchT &contig_dispatch_table,
                                const strided_dispatchT &strided_dispatch_table)
{
    // Host-side dispatcher for a binary elementwise function that writes two
    // output arrays: (src1, src2) -> (dst1, dst2). Validates data types,
    // queues, shapes and memory overlap, then routes to a contiguous or
    // strided device kernel obtained from the supplied dispatch tables.
    // Returns a pair (host-task event keeping arrays alive, computation
    // event).

    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst1_typenum = dst1.get_typenum();
    int dst2_typenum = dst2.get_typenum();

    // map dpctl type numbers to the compact ids used to index the tables
    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);

    // both destination types must match what the type table prescribes
    std::pair<int, int> output_typeids =
        output_types_table[src1_typeid][src2_typeid];

    if (dst1_typeid != output_typeids.first ||
        dst2_typeid != output_typeids.second) {
        throw py::value_error(
            "One of destination arrays has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q,
                                             {src1, src2, dst1, dst2})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int src1_nd = src1.get_ndim();
    int src2_nd = src2.get_ndim();
    int dst1_nd = dst1.get_ndim();
    int dst2_nd = dst2.get_ndim();

    if (dst1_nd != src1_nd || dst1_nd != src2_nd || dst1_nd != dst2_nd) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same, counting total elements along the way
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < dst1_nd; ++i) {
        const auto &sh_i = dst1_shape[i];
        src_nelems *= static_cast<std::size_t>(src1_shape[i]);
        shapes_equal =
            shapes_equal && (src1_shape[i] == sh_i && src2_shape[i] == sh_i &&
                             dst2_shape[i] == sh_i);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return trivially-complete events
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    // both destinations must be able to hold src_nelems elements
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
                                                               src_nelems);
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
                                                               src_nelems);

    // check memory overlap: an output may only overlap an input when both
    // are the same logical tensor, and the two outputs may never overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst1) && !same_logical_tensors(src1, dst1)) ||
        (overlap(src1, dst2) && !same_logical_tensors(src1, dst2)) ||
        (overlap(src2, dst1) && !same_logical_tensors(src2, dst1)) ||
        (overlap(src2, dst2) && !same_logical_tensors(src2, dst2)) ||
        (overlap(dst1, dst2))) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst1_data = dst1.get_data();
    char *dst2_data = dst2.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst1_c_contig = dst1.is_c_contiguous();
    bool is_dst1_f_contig = dst1.is_f_contiguous();

    bool is_dst2_c_contig = dst2.is_c_contiguous();
    bool is_dst2_f_contig = dst2.is_f_contiguous();

    bool all_c_contig = (is_src1_c_contig && is_src2_c_contig &&
                         is_dst1_c_contig && is_dst2_c_contig);
    bool all_f_contig = (is_src1_f_contig && is_src2_f_contig &&
                         is_dst1_f_contig && is_dst2_f_contig);

    // dispatch for contiguous inputs
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            // all element offsets are zero for fully contiguous arrays
            auto comp_ev =
                contig_fn(exec_q, src_nelems, src1_data, 0, src2_data, 0,
                          dst1_data, 0, dst2_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst1, dst2}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst1_strides = dst1.get_strides_vector();
    auto const &dst2_strides = dst2.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst1_strides;
    shT simplified_dst2_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst1_offset(0);
    py::ssize_t dst2_offset(0);

    int nd = dst1_nd;
    const py::ssize_t *shape = src1_shape;

    // coalesce adjacent dimensions where strides permit; nd is updated
    // in place to the simplified dimensionality
    simplify_iteration_space_4(
        nd, shape, src1_strides, src2_strides, dst1_strides, dst2_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst1_strides, simplified_dst2_strides, src1_offset,
        src2_offset, dst1_offset, dst2_offset);

    std::vector<sycl::event> host_tasks{};
    static constexpr auto unit_stride = std::initializer_list<py::ssize_t>{1};

    // simplification may expose a 1D unit-stride (contiguous-with-offset)
    // case not caught by the contiguity flags above
    if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
        isEqual(simplified_src2_strides, unit_stride) &&
        isEqual(simplified_dst1_strides, unit_stride) &&
        isEqual(simplified_dst2_strides, unit_stride)) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev =
                contig_fn(exec_q, src_nelems, src1_data, src1_offset, src2_data,
                          src2_offset, dst1_data, dst1_offset, dst2_data,
                          dst2_offset, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst1, dst2}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    // pack shape and all strides into a single device-accessible buffer
    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst1_strides,
        simplified_dst2_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    // strided kernel depends on both caller events and the packing copy
    sycl::event strided_fn_ev =
        strided_fn(exec_q, src_nelems, nd, shape_strides, src1_data,
                   src1_offset, src2_data, src2_offset, dst1_data, dst1_offset,
                   dst2_data, dst2_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);
    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(dpctl::utils::keep_args_alive(
                              exec_q, {src1, src2, dst1, dst2}, host_tasks),
                          strided_fn_ev);
}
1072
1077template <typename output_typesT>
1078std::pair<py::object, py::object> py_binary_two_outputs_ufunc_result_type(
1079 const py::dtype &input1_dtype,
1080 const py::dtype &input2_dtype,
1081 const output_typesT &output_types_table)
1082{
1083 int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
1084 int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
1085 int src1_typeid = -1;
1086 int src2_typeid = -1;
1087
1088 auto array_types = td_ns::usm_ndarray_types();
1089
1090 try {
1091 src1_typeid = array_types.typenum_to_lookup_id(tn1);
1092 src2_typeid = array_types.typenum_to_lookup_id(tn2);
1093 } catch (const std::exception &e) {
1094 throw py::value_error(e.what());
1095 }
1096
1097 if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
1098 src2_typeid >= td_ns::num_types) {
1099 throw std::runtime_error("binary output type lookup failed");
1100 }
1101 std::pair<int, int> dst_typeids =
1102 output_types_table[src1_typeid][src2_typeid];
1103 int dst1_typeid = dst_typeids.first;
1104 int dst2_typeid = dst_typeids.second;
1105
1106 if (dst1_typeid < 0 || dst2_typeid < 0) {
1107 auto res = py::none();
1108 auto py_res = py::cast<py::object>(res);
1109 return std::make_pair(py_res, py_res);
1110 }
1111 else {
1112 using type_utils::_dtype_from_typenum;
1113
1114 auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
1115 auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
1116 auto dt1 = _dtype_from_typenum(dst1_typenum_t);
1117 auto dt2 = _dtype_from_typenum(dst2_typenum_t);
1118
1119 return std::make_pair(py::cast<py::object>(dt1),
1120 py::cast<py::object>(dt2));
1121 }
1122}
1123
1124// ==================== Inplace binary functions =======================
1125
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
                            const dpctl::tensor::usm_ndarray &rhs,
                            sycl::queue &exec_q,
                            const std::vector<sycl::event> &depends,
                            //
                            const output_typesT &output_type_table,
                            const contig_dispatchT &contig_dispatch_table,
                            const strided_dispatchT &strided_dispatch_table,
                            const contig_row_matrix_dispatchT
                                &contig_row_matrix_broadcast_dispatch_table)
{
    // In-place binary operation updating lhs from (lhs, rhs). Validates
    // arguments, then dispatches to a contiguous kernel, a specialized
    // row/matrix broadcast kernel, or a generic strided kernel, each looked
    // up from the supplied tables. Returns a pair (host-task event keeping
    // arrays alive, computation event).
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs);

    // check type_nums
    int rhs_typenum = rhs.get_typenum();
    int lhs_typenum = lhs.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
    int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum);

    // the in-place result type must coincide with lhs's type
    int output_typeid = output_type_table[rhs_typeid][lhs_typeid];

    if (output_typeid != lhs_typeid) {
        throw py::value_error(
            "Left-hand side array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {rhs, lhs})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int lhs_nd = lhs.get_ndim();
    if (lhs_nd != rhs.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same, counting total elements along the way
    const py::ssize_t *rhs_shape = rhs.get_shape_raw();
    const py::ssize_t *lhs_shape = lhs.get_shape_raw();
    bool shapes_equal(true);
    std::size_t rhs_nelems(1);

    for (int i = 0; i < lhs_nd; ++i) {
        rhs_nelems *= static_cast<std::size_t>(rhs_shape[i]);
        shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return trivially-complete events
    if (rhs_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems);

    // check memory overlap: partial overlap is only allowed when both
    // operands are the same logical tensor
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }
    // raw pointers to the arrays' underlying data
    const char *rhs_data = rhs.get_data();
    char *lhs_data = lhs.get_data();

    // handle contiguous inputs
    bool is_rhs_c_contig = rhs.is_c_contiguous();
    bool is_rhs_f_contig = rhs.is_f_contiguous();

    bool is_lhs_c_contig = lhs.is_c_contiguous();
    bool is_lhs_f_contig = lhs.is_f_contiguous();

    bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig);
    bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig);

    // dispatch for contiguous inputs
    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

        if (contig_fn != nullptr) {
            // element offsets are zero for fully contiguous arrays
            auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
                                     0, depends);
            sycl::event ht_ev =
                dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &rhs_strides = rhs.get_strides_vector();
    auto const &lhs_strides = lhs.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_rhs_strides;
    shT simplified_lhs_strides;
    py::ssize_t rhs_offset(0);
    py::ssize_t lhs_offset(0);

    int nd = lhs_nd;
    const py::ssize_t *shape = rhs_shape;

    // coalesce adjacent dimensions where strides permit; nd is updated
    // in place to the simplified dimensionality
    simplify_iteration_space(nd, shape, rhs_strides, lhs_strides,
                             // outputs
                             simplified_shape, simplified_rhs_strides,
                             simplified_lhs_strides, rhs_offset, lhs_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        // simplification may expose a 1D unit-stride (contiguous-with-offset)
        // case not caught by the contiguity flags above
        if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) &&
            isEqual(simplified_lhs_strides, unit_stride)) {
            auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev =
                    contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset,
                              lhs_data, lhs_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {rhs, lhs}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_rhs_strides, one_zero_strides) &&
                isEqual(simplified_lhs_strides, {one, simplified_shape[0]})) {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                              [lhs_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    std::size_t n0 = simplified_shape[1];
                    std::size_t n1 = simplified_shape[0];
                    sycl::event comp_ev = row_matrix_broadcast_fn(
                        exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                        lhs_data, lhs_offset, depends);

                    return std::make_pair(dpctl::utils::keep_args_alive(
                                              exec_q, {lhs, rhs}, host_tasks),
                                          comp_ev);
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for rhs_typeid=" +
            std::to_string(rhs_typeid) +
            " and lhs_typeid=" + std::to_string(lhs_typeid));
    }

    // pack shape and strides into a single device-accessible buffer
    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
        simplified_lhs_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    // strided kernel depends on both caller events and the packing copy
    sycl::event strided_fn_ev =
        strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
                   lhs_data, lhs_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
        strided_fn_ev);
}
1324} // namespace dpnp::extensions::py_internal