DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
elementwise_functions.hpp
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
// may be used to endorse or promote products derived from this software
// without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <exception>
#include <stdexcept>

#include <sycl/sycl.hpp>

#include "dpctl4pybind11.hpp"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "elementwise_functions_type_utils.hpp"
#include "simplify_iteration_space.hpp"

// dpctl tensor headers
#include "kernels/alignment.hpp"
// #include "kernels/dpctl_tensor_types.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/offset_utils.hpp"
#include "utils/output_validation.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include "utils/type_dispatch.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);

namespace dpnp::extensions::py_internal
{

using dpctl::tensor::kernels::alignment_utils::is_aligned;
using dpctl::tensor::kernels::alignment_utils::required_alignment;

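// The templates below are parameterized by dispatch tables supplied by each
// concrete element-wise extension. As a hedged sketch (inferred from the call
// sites in this file, not a definitive API), the unary launchers are expected
// to be callable as:
//
//   // contiguous kernel launcher
//   sycl::event contig_fn(sycl::queue &q, size_t nelems, const char *src,
//                         char *dst, const std::vector<sycl::event> &depends);
//
//   // strided kernel launcher; shape_strides is a packed device allocation
//   // of [shape, src_strides, dst_strides]
//   sycl::event strided_fn(sycl::queue &q, size_t nelems, int nd,
//                          const py::ssize_t *shape_strides, const char *src,
//                          py::ssize_t src_offset, char *dst,
//                          py::ssize_t dst_offset,
//                          const std::vector<sycl::event> &depends,
//                          const std::vector<sycl::event> &extra_depends);
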
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
                   const dpctl::tensor::usm_ndarray &dst,
                   sycl::queue &q,
                   const std::vector<sycl::event> &depends,
                   //
                   const output_typesT &output_type_vec,
                   const contig_dispatchT &contig_dispatch_vector,
                   const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int func_output_typeid = output_type_vec[src_typeid];

    // check that the destination type matches the function's output type
    if (dst_typeid != func_output_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
    bool both_f_contig = (is_src_f_contig && is_dst_f_contig);

    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev = contig_fn(q, src_nelems, src_data, dst_data, depends);
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Simplify the iteration space; if the simplified result is 1D with unit
    // strides, the inputs are effectively contiguous and are dispatched to
    // the contiguous implementation, otherwise to the strided one.

    auto const &src_strides = src.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space(nd, shape, src_strides, dst_strides,
                             // output
                             simplified_shape, simplified_src_strides,
                             simplified_dst_strides, src_offset, dst_offset);

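    // A hedged worked example of the simplification above (illustrative
    // values): a C-contiguous pair of arrays of shape (2, 3, 4) with strides
    // (12, 4, 1) collapses to nd == 1 with shape (24,) and unit strides, and
    // is therefore routed to the contiguous kernel below.
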
    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst_strides[0] == 1) {
        // Special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst_elem_size = dst.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst_data + dst_elem_size * dst_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
                   dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks),
        strided_fn_ev);
}

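// A minimal usage sketch for py_unary_ufunc in a pybind11 module definition.
// The names "_abs", abs_output_typeid_vector, abs_contig_dispatch_vector and
// abs_strided_dispatch_vector are hypothetical; a real extension fills such
// tables per supported source type:
//
//   m.def("_abs",
//         [](const dpctl::tensor::usm_ndarray &src,
//            const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q,
//            const std::vector<sycl::event> &depends) {
//             return py_unary_ufunc(src, dst, exec_q, depends,
//                                   abs_output_typeid_vector,
//                                   abs_contig_dispatch_vector,
//                                   abs_strided_dispatch_vector);
//         },
//         py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
//         py::arg("depends") = py::list());
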
template <typename output_typesT>
py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
                                      const output_typesT &output_types)
{
    int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src_typeid = array_types.typenum_to_lookup_id(tn);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    using type_utils::_result_typeid;
    int dst_typeid = _result_typeid(src_typeid, output_types);

    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}

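// A hedged sketch of the matching result-type binding ("_abs_result_type" and
// abs_output_typeid_vector are hypothetical names). The bound callable maps a
// NumPy dtype to the kernel's output dtype, or returns None for unsupported
// input types:
//
//   m.def("_abs_result_type", [](const py::dtype &dtype) {
//       return py_unary_ufunc_result_type(dtype, abs_output_typeid_vector);
//   });
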
// ======================== Binary functions ===========================

namespace
{
template <class Container, class T>
bool isEqual(Container const &c, std::initializer_list<T> const &l)
{
    return std::equal(std::begin(c), std::end(c), std::begin(l), std::end(l));
}
} // namespace

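// For example, in the dispatch logic below isEqual(strides, {0, 1}) detects a
// 2D operand whose leading axis is broadcast (stride 0) over the rows of a
// C-contiguous matrix.
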
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_matrix_row_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event> py_binary_ufunc(
    const dpctl::tensor::usm_ndarray &src1,
    const dpctl::tensor::usm_ndarray &src2,
    const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
    sycl::queue &exec_q,
    const std::vector<sycl::event> depends,
    //
    const output_typesT &output_type_table,
    const contig_dispatchT &contig_dispatch_table,
    const strided_dispatchT &strided_dispatch_table,
    const contig_matrix_row_dispatchT
        &contig_matrix_row_broadcast_dispatch_table,
    const contig_row_matrix_dispatchT
        &contig_row_matrix_broadcast_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int output_typeid = output_type_table[src1_typeid][src2_typeid];

    if (output_typeid != dst_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // Check shapes; broadcasting is assumed to have been done by the caller.
    // First check that dimensions are the same.
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) ||
        (overlap(src2, dst) && !same_logical_tensors(src2, dst)))
    {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    bool all_f_contig =
        (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig);

    // dispatch for contiguous inputs
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0,
                                     src2_data, 0, dst_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst_offset(0);

    int nd = dst_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_3(
        nd, shape, src1_strides, src2_strides, dst_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst_strides, src1_offset, src2_offset, dst_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
            isEqual(simplified_src2_strides, unit_stride) &&
            isEqual(simplified_dst_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev = contig_fn(exec_q, src_nelems, src1_data,
                                         src1_offset, src2_data, src2_offset,
                                         dst_data, dst_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {src1, src2, dst}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto zero_one_strides =
                std::initializer_list<py::ssize_t>{0, 1};
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
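            // For instance (illustrative values): for a C-contiguous (n0, n1)
            // matrix src1 and dst, the simplified strides are {n1, 1}, while
            // a row src2 broadcast along axis 0 carries strides {0, 1}.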
            if (isEqual(simplified_src2_strides, zero_one_strides) &&
                isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
                isEqual(simplified_dst_strides, {simplified_shape[1], one}))
            {
                auto matrix_row_broadcast_fn =
                    contig_matrix_row_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (matrix_row_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        size_t n0 = simplified_shape[0];
                        size_t n1 = simplified_shape[1];
                        sycl::event comp_ev = matrix_row_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
            if (isEqual(simplified_src1_strides, one_zero_strides) &&
                isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
                isEqual(simplified_dst_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (row_matrix_broadcast_fn != nullptr) {

                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        size_t n0 = simplified_shape[1];
                        size_t n1 = simplified_shape[0];
                        sycl::event comp_ev = row_matrix_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
        src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks),
        strided_fn_ev);
}

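// A minimal usage sketch for py_binary_ufunc with hypothetical dispatch
// tables for an "add" extension; entries in the two broadcast tables may be
// nullptr, in which case those fast paths are simply skipped:
//
//   m.def("_add",
//         [](const dpctl::tensor::usm_ndarray &src1,
//            const dpctl::tensor::usm_ndarray &src2,
//            const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q,
//            const std::vector<sycl::event> &depends) {
//             return py_binary_ufunc(
//                 src1, src2, dst, exec_q, depends, add_output_typeid_table,
//                 add_contig_dispatch_table, add_strided_dispatch_table,
//                 add_contig_matrix_row_dispatch_table,
//                 add_contig_row_matrix_dispatch_table);
//         });
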
template <typename output_typesT>
py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
                                       const py::dtype &input2_dtype,
                                       const output_typesT &output_types_table)
{
    int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
    int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src1_typeid = -1;
    int src2_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src1_typeid = array_types.typenum_to_lookup_id(tn1);
        src2_typeid = array_types.typenum_to_lookup_id(tn2);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
        src2_typeid >= td_ns::num_types)
    {
        throw std::runtime_error("binary output type lookup failed");
    }
    int dst_typeid = output_types_table[src1_typeid][src2_typeid];

    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}

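// The corresponding result-type binding, again a hedged sketch with
// hypothetical names; a None return signals an unsupported dtype combination:
//
//   m.def("_add_result_type",
//         [](const py::dtype &dt1, const py::dtype &dt2) {
//             return py_binary_ufunc_result_type(dt1, dt2,
//                                                add_output_typeid_table);
//         });
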
// ==================== Inplace binary functions =======================

template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
                            const dpctl::tensor::usm_ndarray &rhs,
                            sycl::queue &exec_q,
                            const std::vector<sycl::event> depends,
                            //
                            const output_typesT &output_type_table,
                            const contig_dispatchT &contig_dispatch_table,
                            const strided_dispatchT &strided_dispatch_table,
                            const contig_row_matrix_dispatchT
                                &contig_row_matrix_broadcast_dispatch_table)
{
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs);

    // check type_nums
    int rhs_typenum = rhs.get_typenum();
    int lhs_typenum = lhs.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
    int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum);

    int output_typeid = output_type_table[rhs_typeid][lhs_typeid];

    if (output_typeid != lhs_typeid) {
        throw py::value_error(
            "Left-hand side array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {rhs, lhs})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // Check shapes; broadcasting is assumed to have been done by the caller.
    // First check that dimensions are the same.
    int lhs_nd = lhs.get_ndim();
    if (lhs_nd != rhs.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *rhs_shape = rhs.get_shape_raw();
    const py::ssize_t *lhs_shape = lhs.get_shape_raw();
    bool shapes_equal(true);
    size_t rhs_nelems(1);

    for (int i = 0; i < lhs_nd; ++i) {
        rhs_nelems *= static_cast<size_t>(rhs_shape[i]);
        shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (rhs_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems);

    // check memory overlap
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }
    const char *rhs_data = rhs.get_data();
    char *lhs_data = lhs.get_data();

    // handle contiguous inputs
    bool is_rhs_c_contig = rhs.is_c_contiguous();
    bool is_rhs_f_contig = rhs.is_f_contiguous();

    bool is_lhs_c_contig = lhs.is_c_contiguous();
    bool is_lhs_f_contig = lhs.is_f_contiguous();

    bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig);
    bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig);

    // dispatch for contiguous inputs
    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
                                     0, depends);
            sycl::event ht_ev =
                dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &rhs_strides = rhs.get_strides_vector();
    auto const &lhs_strides = lhs.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_rhs_strides;
    shT simplified_lhs_strides;
    py::ssize_t rhs_offset(0);
    py::ssize_t lhs_offset(0);

    int nd = lhs_nd;
    const py::ssize_t *shape = rhs_shape;

    simplify_iteration_space(nd, shape, rhs_strides, lhs_strides,
                             // outputs
                             simplified_shape, simplified_rhs_strides,
                             simplified_lhs_strides, rhs_offset, lhs_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) &&
            isEqual(simplified_lhs_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev =
                    contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset,
                              lhs_data, lhs_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {rhs, lhs}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_rhs_strides, one_zero_strides) &&
                isEqual(simplified_lhs_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                              [lhs_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    size_t n0 = simplified_shape[1];
                    size_t n1 = simplified_shape[0];
                    sycl::event comp_ev = row_matrix_broadcast_fn(
                        exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                        lhs_data, lhs_offset, depends);

                    return std::make_pair(dpctl::utils::keep_args_alive(
                                              exec_q, {lhs, rhs}, host_tasks),
                                          comp_ev);
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for rhs_typeid=" +
            std::to_string(rhs_typeid) +
            " and lhs_typeid=" + std::to_string(lhs_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
        simplified_lhs_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
                   lhs_data, lhs_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
        strided_fn_ev);
}

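// A minimal usage sketch for py_binary_inplace_ufunc (lhs op= rhs), with
// hypothetical in-place "add" dispatch tables:
//
//   m.def("_add_inplace",
//         [](const dpctl::tensor::usm_ndarray &lhs,
//            const dpctl::tensor::usm_ndarray &rhs, sycl::queue &exec_q,
//            const std::vector<sycl::event> &depends) {
//             return py_binary_inplace_ufunc(
//                 lhs, rhs, exec_q, depends,
//                 add_inplace_output_typeid_table,
//                 add_inplace_contig_dispatch_table,
//                 add_inplace_strided_dispatch_table,
//                 add_inplace_row_matrix_dispatch_table);
//         });
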
} // namespace dpnp::extensions::py_internal