DPNP C++ backend kernel library 0.18.0dev0
Data Parallel Extension for NumPy*
elementwise_functions.hpp
//*****************************************************************************
// Copyright (c) 2024-2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <exception>
#include <stdexcept>

#include <sycl/sycl.hpp>

#include "dpctl4pybind11.hpp"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "elementwise_functions_type_utils.hpp"
#include "simplify_iteration_space.hpp"

// dpctl tensor headers
#include "kernels/alignment.hpp"
// #include "kernels/dpctl_tensor_types.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/offset_utils.hpp"
#include "utils/output_validation.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include "utils/type_dispatch.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);

namespace dpnp::extensions::py_internal
{

using dpctl::tensor::kernels::alignment_utils::is_aligned;
using dpctl::tensor::kernels::alignment_utils::required_alignment;

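/*! @brief Template implementing Python API for unary elementwise functions.
 *
 *  Validates the source and destination arrays (types, queues, shapes,
 *  writability, memory overlap), then dispatches to a contiguous kernel when
 *  possible and to a strided kernel otherwise. Returns a pair of events: a
 *  host-task event keeping the argument arrays alive, and the computation
 *  event for callers to depend on.
 */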
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
                   const dpctl::tensor::usm_ndarray &dst,
                   sycl::queue &q,
                   const std::vector<sycl::event> &depends,
                   //
                   const output_typesT &output_type_vec,
                   const contig_dispatchT &contig_dispatch_vector,
                   const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int func_output_typeid = output_type_vec[src_typeid];

    // check that types are supported
    if (dst_typeid != func_output_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
    bool both_f_contig = (is_src_f_contig && is_dst_f_contig);

    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev = contig_fn(q, src_nelems, src_data, dst_data, depends);
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Simplify the iteration space by collapsing contiguous runs of axes.
    // If the simplified space is 1D with unit strides, the data are
    // effectively contiguous and the contiguous kernel can be used;
    // otherwise dispatch to the strided kernel.

    auto const &src_strides = src.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space(nd, shape, src_strides, dst_strides,
                             // output
                             simplified_shape, simplified_src_strides,
                             simplified_dst_strides, src_offset, dst_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst_strides[0] == 1) {
        // Special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst_elem_size = dst.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst_data + dst_elem_size * dst_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

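    // Pack the simplified shape and strides into a single USM device
    // allocation. The returned triple is the owning smart pointer, the
    // allocation size, and the event signalling completion of the
    // host-to-device copy.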
    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
                   dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks),
        strided_fn_ev);
}
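// A minimal registration sketch, assuming dispatch vectors populated
// elsewhere (the "_abs" name and the impl::abs_* vectors below are
// hypothetical, shown only to illustrate how py_unary_ufunc is typically
// wired into a pybind11 module):
//
//     m.def("_abs",
//           [](const dpctl::tensor::usm_ndarray &src,
//              const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q,
//              const std::vector<sycl::event> &depends) {
//               return py_unary_ufunc(src, dst, exec_q, depends,
//                                     impl::abs_output_typeid_vector,
//                                     impl::abs_contig_dispatch_vector,
//                                     impl::abs_strided_dispatch_vector);
//           },
//           py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
//           py::arg("depends") = py::list());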
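/*! @brief Template implementing Python API for querying the result type of a
 *  unary elementwise function, returning ``None`` when the input type is not
 *  supported.
 */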
template <typename output_typesT>
py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
                                      const output_typesT &output_types)
{
    int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src_typeid = array_types.typenum_to_lookup_id(tn);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    using type_utils::_result_typeid;
    int dst_typeid = _result_typeid(src_typeid, output_types);

    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}

// ======================== Binary functions ===========================

namespace
{
template <class Container, class T>
bool isEqual(Container const &c, std::initializer_list<T> const &l)
{
    return std::equal(std::begin(c), std::end(c), std::begin(l), std::end(l));
}
} // namespace

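/*! @brief Template implementing Python API for binary elementwise functions,
 *  computing ``dst = op(src1, src2)``.
 *
 *  Mirrors py_unary_ufunc, with additional fast paths for two contiguous
 *  matrix/row broadcast patterns. Broadcasting of the inputs to a common
 *  shape is assumed to have been done by the caller.
 */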
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_matrix_row_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event> py_binary_ufunc(
    const dpctl::tensor::usm_ndarray &src1,
    const dpctl::tensor::usm_ndarray &src2,
    const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
    sycl::queue &exec_q,
    const std::vector<sycl::event> &depends,
    //
    const output_typesT &output_type_table,
    const contig_dispatchT &contig_dispatch_table,
    const strided_dispatchT &strided_dispatch_table,
    const contig_matrix_row_dispatchT
        &contig_matrix_row_broadcast_dispatch_table,
    const contig_row_matrix_dispatchT
        &contig_row_matrix_broadcast_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int output_typeid = output_type_table[src1_typeid][src2_typeid];

    if (output_typeid != dst_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) ||
        (overlap(src2, dst) && !same_logical_tensors(src2, dst)))
    {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    bool all_f_contig =
        (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig);

    // dispatch for contiguous inputs
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0,
                                     src2_data, 0, dst_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst_offset(0);

    int nd = dst_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_3(
        nd, shape, src1_strides, src2_strides, dst_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst_strides, src1_offset, src2_offset, dst_offset);

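    // Fast paths for low-dimensional data: a 1D iteration space with unit
    // strides can go to the contiguous kernel, and a 2D space may match one
    // of the matrix-row / row-matrix broadcast kernels below, provided the
    // data pointers satisfy the required alignment.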
    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
            isEqual(simplified_src2_strides, unit_stride) &&
            isEqual(simplified_dst_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev = contig_fn(exec_q, src_nelems, src1_data,
                                         src1_offset, src2_data, src2_offset,
                                         dst_data, dst_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {src1, src2, dst}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto zero_one_strides =
                std::initializer_list<py::ssize_t>{0, 1};
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_src2_strides, zero_one_strides) &&
                isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
                isEqual(simplified_dst_strides, {simplified_shape[1], one}))
            {
                auto matrix_row_broadcast_fn =
                    contig_matrix_row_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (matrix_row_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        size_t n0 = simplified_shape[0];
                        size_t n1 = simplified_shape[1];
                        sycl::event comp_ev = matrix_row_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
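            // special case of a row and a C-contiguous matrix (mirrors the
            // branch above with the roles of the operands swapped)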
            if (isEqual(simplified_src1_strides, one_zero_strides) &&
                isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
                isEqual(simplified_dst_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        size_t n0 = simplified_shape[1];
                        size_t n1 = simplified_shape[0];
                        sycl::event comp_ev = row_matrix_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
        src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks),
        strided_fn_ev);
}
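// A minimal registration sketch, assuming dispatch tables populated
// elsewhere, e.g. with td_ns::DispatchTableBuilder (the "_add" name and the
// impl::add_* tables below are hypothetical):
//
//     m.def("_add",
//           [](const dpctl::tensor::usm_ndarray &src1,
//              const dpctl::tensor::usm_ndarray &src2,
//              const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q,
//              const std::vector<sycl::event> &depends) {
//               return py_binary_ufunc(
//                   src1, src2, dst, exec_q, depends,
//                   impl::add_output_typeid_table,
//                   impl::add_contig_dispatch_table,
//                   impl::add_strided_dispatch_table,
//                   impl::add_contig_matrix_row_dispatch_table,
//                   impl::add_contig_row_matrix_dispatch_table);
//           },
//           py::arg("src1"), py::arg("src2"), py::arg("dst"),
//           py::arg("sycl_queue"), py::arg("depends") = py::list());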
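/*! @brief Template implementing Python API for querying the result type of a
 *  binary elementwise function, returning ``None`` when the combination of
 *  input types is not supported.
 */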
template <typename output_typesT>
py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
                                       const py::dtype &input2_dtype,
                                       const output_typesT &output_types_table)
{
    int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
    int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src1_typeid = -1;
    int src2_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src1_typeid = array_types.typenum_to_lookup_id(tn1);
        src2_typeid = array_types.typenum_to_lookup_id(tn2);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
        src2_typeid >= td_ns::num_types)
    {
        throw std::runtime_error("binary output type lookup failed");
    }
    int dst_typeid = output_types_table[src1_typeid][src2_typeid];

    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}

// ==================== Inplace binary functions =======================

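/*! @brief Template implementing Python API for inplace binary elementwise
 *  functions, computing ``lhs = op(lhs, rhs)``.
 *
 *  Follows the same validation and dispatch structure as py_binary_ufunc,
 *  with the left-hand side array serving as both an input and the
 *  destination.
 */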
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
                            const dpctl::tensor::usm_ndarray &rhs,
                            sycl::queue &exec_q,
                            const std::vector<sycl::event> &depends,
                            //
                            const output_typesT &output_type_table,
                            const contig_dispatchT &contig_dispatch_table,
                            const strided_dispatchT &strided_dispatch_table,
                            const contig_row_matrix_dispatchT
                                &contig_row_matrix_broadcast_dispatch_table)
{
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs);

    // check type_nums
    int rhs_typenum = rhs.get_typenum();
    int lhs_typenum = lhs.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
    int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum);

    int output_typeid = output_type_table[rhs_typeid][lhs_typeid];

    if (output_typeid != lhs_typeid) {
        throw py::value_error(
            "Left-hand side array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {rhs, lhs})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int lhs_nd = lhs.get_ndim();
    if (lhs_nd != rhs.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *rhs_shape = rhs.get_shape_raw();
    const py::ssize_t *lhs_shape = lhs.get_shape_raw();
    bool shapes_equal(true);
    size_t rhs_nelems(1);

    for (int i = 0; i < lhs_nd; ++i) {
        rhs_nelems *= static_cast<size_t>(rhs_shape[i]);
        shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (rhs_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems);

    // check memory overlap
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *rhs_data = rhs.get_data();
    char *lhs_data = lhs.get_data();

    // handle contiguous inputs
    bool is_rhs_c_contig = rhs.is_c_contiguous();
    bool is_rhs_f_contig = rhs.is_f_contiguous();

    bool is_lhs_c_contig = lhs.is_c_contiguous();
    bool is_lhs_f_contig = lhs.is_f_contiguous();

    bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig);
    bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig);

    // dispatch for contiguous inputs
    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
                                     0, depends);
            sycl::event ht_ev =
                dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &rhs_strides = rhs.get_strides_vector();
    auto const &lhs_strides = lhs.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_rhs_strides;
    shT simplified_lhs_strides;
    py::ssize_t rhs_offset(0);
    py::ssize_t lhs_offset(0);

    int nd = lhs_nd;
    const py::ssize_t *shape = rhs_shape;

    simplify_iteration_space(nd, shape, rhs_strides, lhs_strides,
                             // outputs
                             simplified_shape, simplified_rhs_strides,
                             simplified_lhs_strides, rhs_offset, lhs_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) &&
            isEqual(simplified_lhs_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev =
                    contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset,
                              lhs_data, lhs_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {rhs, lhs}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            constexpr py::ssize_t one{1};
            // special case of a row being broadcast over a C-contiguous
            // matrix (in the transposed view of the simplified strides)
            if (isEqual(simplified_rhs_strides, one_zero_strides) &&
                isEqual(simplified_lhs_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                              [lhs_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    size_t n0 = simplified_shape[1];
                    size_t n1 = simplified_shape[0];
                    sycl::event comp_ev = row_matrix_broadcast_fn(
                        exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                        lhs_data, lhs_offset, depends);

                    return std::make_pair(dpctl::utils::keep_args_alive(
                                              exec_q, {lhs, rhs}, host_tasks),
                                          comp_ev);
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for rhs_typeid=" +
            std::to_string(rhs_typeid) +
            " and lhs_typeid=" + std::to_string(lhs_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
        simplified_lhs_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
                   lhs_data, lhs_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
        strided_fn_ev);
}
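// A minimal registration sketch, assuming populated dispatch tables (the
// "_add_inplace" name and the impl::add_inplace_* tables below are
// hypothetical):
//
//     m.def("_add_inplace",
//           [](const dpctl::tensor::usm_ndarray &lhs,
//              const dpctl::tensor::usm_ndarray &rhs, sycl::queue &exec_q,
//              const std::vector<sycl::event> &depends) {
//               return py_binary_inplace_ufunc(
//                   lhs, rhs, exec_q, depends,
//                   impl::add_inplace_output_typeid_table,
//                   impl::add_inplace_contig_dispatch_table,
//                   impl::add_inplace_strided_dispatch_table,
//                   impl::add_inplace_row_matrix_dispatch_table);
//           },
//           py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"),
//           py::arg("depends") = py::list());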

} // namespace dpnp::extensions::py_internal