DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
elementwise_functions.hpp
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
//   may be used to endorse or promote products derived from this software
//   without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <cstddef>
#include <exception>
#include <stdexcept>
#include <utility>
#include <vector>

#include <sycl/sycl.hpp>

#include "dpctl4pybind11.hpp"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "elementwise_functions_type_utils.hpp"
#include "simplify_iteration_space.hpp"

// dpctl tensor headers
#include "kernels/alignment.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/offset_utils.hpp"
#include "utils/output_validation.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include "utils/type_dispatch.hpp"

static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);

namespace dpnp::extensions::py_internal
{
namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::kernels::alignment_utils::is_aligned;
using dpctl::tensor::kernels::alignment_utils::required_alignment;

using type_utils::_result_typeid;
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
                   const dpctl::tensor::usm_ndarray &dst,
                   sycl::queue &q,
                   const std::vector<sycl::event> &depends,
                   //
                   const output_typesT &output_type_vec,
                   const contig_dispatchT &contig_dispatch_vector,
                   const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int func_output_typeid = output_type_vec[src_typeid];

    // check that types are supported
    if (dst_typeid != func_output_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
    bool both_f_contig = (is_src_f_contig && is_dst_f_contig);

    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev = contig_fn(q, src_nelems, src_data, dst_data, depends);
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space
    // if 1d with strides 1 - input is contig
    // dispatch to strided

    auto const &src_strides = src.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space(nd, shape, src_strides, dst_strides,
                             // output
                             simplified_shape, simplified_src_strides,
                             simplified_dst_strides, src_offset, dst_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst_strides[0] == 1) {
        // Special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst_elem_size = dst.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst_data + dst_elem_size * dst_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
                   dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks),
        strided_fn_ev);
}
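
/* Typical use (sketch): an extension module binds py_unary_ufunc through a
   lambda that supplies its pre-populated dispatch tables. The names below
   ("_abs", abs_output_typeid_vector, abs_contig_dispatch_vector,
   abs_strided_dispatch_vector) are illustrative placeholders, not part of
   this header:

   m.def(
       "_abs",
       [](const dpctl::tensor::usm_ndarray &src,
          const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q,
          const std::vector<sycl::event> &depends) {
           return py_unary_ufunc(src, dst, exec_q, depends,
                                 abs_output_typeid_vector,
                                 abs_contig_dispatch_vector,
                                 abs_strided_dispatch_vector);
       },
       py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
       py::arg("depends") = py::list());
*/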
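
/*! @brief Template implementing Python API for querying of type support by
 *         unary elementwise functions */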
template <typename output_typesT>
py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
                                      const output_typesT &output_types)
{
    int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src_typeid = array_types.typenum_to_lookup_id(tn);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    int dst_typeid = _result_typeid(src_typeid, output_types);
    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}
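
/*! @brief Template implementing Python API for unary elementwise functions
 *         that write into two output arrays (e.g. functions returning both
 *         a fractional and an integral part). Performs the same validation
 *         and contiguous/strided dispatch as py_unary_ufunc */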
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src,
                               const dpctl::tensor::usm_ndarray &dst1,
                               const dpctl::tensor::usm_ndarray &dst2,
                               sycl::queue &q,
                               const std::vector<sycl::event> &depends,
                               //
                               const output_typesT &output_type_vec,
                               const contig_dispatchT &contig_dispatch_vector,
                               const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst1_typenum = dst1.get_typenum();
    int dst2_typenum = dst2.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);

    std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];

    // check that types are supported
    if (dst1_typeid != func_output_typeids.first ||
        dst2_typeid != func_output_typeids.second)
    {
        throw py::value_error(
            "One of the destination arrays has an unexpected elemental data "
            "type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst1, dst2})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst1.get_ndim() || src_nd != dst2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
                       (src_shape[i] == dst2_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
                                                               src_nelems);
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
                                                               src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src, dst1) && !same_logical_tensors(src, dst1)) ||
        (overlap(src, dst2) && !same_logical_tensors(src, dst2)) ||
        (overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2)))
    {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst1_data = dst1.get_data();
    char *dst2_data = dst2.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst1_c_contig = dst1.is_c_contiguous();
    bool is_dst1_f_contig = dst1.is_f_contiguous();

    bool is_dst2_c_contig = dst2.is_c_contiguous();
    bool is_dst2_f_contig = dst2.is_f_contiguous();

    bool all_c_contig =
        (is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);
    bool all_f_contig =
        (is_src_f_contig && is_dst1_f_contig && is_dst2_f_contig);

    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev =
            contig_fn(q, src_nelems, src_data, dst1_data, dst2_data, depends);
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space
    // if 1d with strides 1 - input is contig
    // dispatch to strided

    auto const &src_strides = src.get_strides_vector();
    auto const &dst1_strides = dst1.get_strides_vector();
    auto const &dst2_strides = dst2.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst1_strides;
    shT simplified_dst2_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst1_offset(0);
    py::ssize_t dst2_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space_3(
        nd, shape, src_strides, dst1_strides, dst2_strides,
        // output
        simplified_shape, simplified_src_strides, simplified_dst1_strides,
        simplified_dst2_strides, src_offset, dst1_offset, dst2_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst1_strides[0] == 1 && simplified_dst2_strides[0] == 1)
    {
        // Special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst1_elem_size = dst1.get_elemsize();
        int dst2_elem_size = dst2.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst1_data + dst1_elem_size * dst1_offset,
                      dst2_data + dst2_elem_size * dst2_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // Strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst1_strides, simplified_dst2_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        q, src_nelems, nd, shape_strides, src_data, src_offset, dst1_data,
        dst1_offset, dst2_data, dst2_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, host_tasks),
        strided_fn_ev);
}
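
/*! @brief Template implementing Python API for querying of type support by
 *         unary elementwise functions with two output arrays */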
template <typename output_typesT>
std::pair<py::object, py::object>
    py_unary_two_outputs_ufunc_result_type(const py::dtype &input_dtype,
                                           const output_typesT &output_types)
{
    int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src_typeid = array_types.typenum_to_lookup_id(tn);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    std::pair<int, int> dst_typeids = _result_typeid(src_typeid, output_types);
    int dst1_typeid = dst_typeids.first;
    int dst2_typeid = dst_typeids.second;

    if (dst1_typeid < 0 || dst2_typeid < 0) {
        auto res = py::none();
        auto py_res = py::cast<py::object>(res);
        return std::make_pair(py_res, py_res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
        auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
        auto dt1 = _dtype_from_typenum(dst1_typenum_t);
        auto dt2 = _dtype_from_typenum(dst2_typenum_t);

        return std::make_pair(py::cast<py::object>(dt1),
                              py::cast<py::object>(dt2));
    }
}

// ======================== Binary functions ===========================

namespace
{
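// Checks whether the elements of a container equal those of an
// initializer list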
template <class Container, class T>
bool isEqual(Container const &c, std::initializer_list<T> const &l)
{
    return std::equal(std::begin(c), std::end(c), std::begin(l), std::end(l));
}
} // namespace
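
/*! @brief Template implementing Python API for binary elementwise
 *         functions */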
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_matrix_row_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event> py_binary_ufunc(
    const dpctl::tensor::usm_ndarray &src1,
    const dpctl::tensor::usm_ndarray &src2,
    const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
    sycl::queue &exec_q,
    const std::vector<sycl::event> depends,
    //
    const output_typesT &output_type_table,
    const contig_dispatchT &contig_dispatch_table,
    const strided_dispatchT &strided_dispatch_table,
    const contig_matrix_row_dispatchT
        &contig_matrix_row_broadcast_dispatch_table,
    const contig_row_matrix_dispatchT
        &contig_row_matrix_broadcast_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int output_typeid = output_type_table[src1_typeid][src2_typeid];

    if (output_typeid != dst_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    std::size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) ||
        (overlap(src2, dst) && !same_logical_tensors(src2, dst)))
    {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    bool all_f_contig =
        (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig);

    // dispatch for contiguous inputs
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0,
                                     src2_data, 0, dst_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst_offset(0);

    int nd = dst_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_3(
        nd, shape, src1_strides, src2_strides, dst_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst_strides, src1_offset, src2_offset, dst_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
            isEqual(simplified_src2_strides, unit_stride) &&
            isEqual(simplified_dst_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev = contig_fn(exec_q, src_nelems, src1_data,
                                         src1_offset, src2_data, src2_offset,
                                         dst_data, dst_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {src1, src2, dst}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto zero_one_strides =
                std::initializer_list<py::ssize_t>{0, 1};
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_src2_strides, zero_one_strides) &&
                isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
                isEqual(simplified_dst_strides, {simplified_shape[1], one}))
            {
                auto matrix_row_broadcast_fn =
                    contig_matrix_row_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (matrix_row_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        std::size_t n0 = simplified_shape[0];
                        std::size_t n1 = simplified_shape[1];
                        sycl::event comp_ev = matrix_row_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
            if (isEqual(simplified_src1_strides, one_zero_strides) &&
                isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
                isEqual(simplified_dst_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        std::size_t n0 = simplified_shape[1];
                        std::size_t n1 = simplified_shape[0];
                        sycl::event comp_ev = row_matrix_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
        src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks),
        strided_fn_ev);
}
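
/*! @brief Template implementing Python API for querying of type support by
 *         binary elementwise functions */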
template <typename output_typesT>
py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
                                       const py::dtype &input2_dtype,
                                       const output_typesT &output_types_table)
{
    int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl
    int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl
    int src1_typeid = -1;
    int src2_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src1_typeid = array_types.typenum_to_lookup_id(tn1);
        src2_typeid = array_types.typenum_to_lookup_id(tn2);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    if (src1_typeid < 0 || src1_typeid >= td_ns::num_types ||
        src2_typeid < 0 || src2_typeid >= td_ns::num_types)
    {
        throw std::runtime_error("binary output type lookup failed");
    }
    int dst_typeid = output_types_table[src1_typeid][src2_typeid];

    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}

// ==================== Inplace binary functions =======================
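
/*! @brief Template implementing Python API for inplace binary elementwise
 *         functions, where the left-hand side array is updated with the
 *         result of the operation */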
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
                            const dpctl::tensor::usm_ndarray &rhs,
                            sycl::queue &exec_q,
                            const std::vector<sycl::event> depends,
                            //
                            const output_typesT &output_type_table,
                            const contig_dispatchT &contig_dispatch_table,
                            const strided_dispatchT &strided_dispatch_table,
                            const contig_row_matrix_dispatchT
                                &contig_row_matrix_broadcast_dispatch_table)
{
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs);

    // check type_nums
    int rhs_typenum = rhs.get_typenum();
    int lhs_typenum = lhs.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
    int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum);

    int output_typeid = output_type_table[rhs_typeid][lhs_typeid];

    if (output_typeid != lhs_typeid) {
        throw py::value_error(
            "Left-hand side array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {rhs, lhs})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // check shapes, broadcasting is assumed done by caller
    // check that dimensions are the same
    int lhs_nd = lhs.get_ndim();
    if (lhs_nd != rhs.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *rhs_shape = rhs.get_shape_raw();
    const py::ssize_t *lhs_shape = lhs.get_shape_raw();
    bool shapes_equal(true);
    std::size_t rhs_nelems(1);

    for (int i = 0; i < lhs_nd; ++i) {
        rhs_nelems *= static_cast<std::size_t>(rhs_shape[i]);
        shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (rhs_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems);

    // check memory overlap
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *rhs_data = rhs.get_data();
    char *lhs_data = lhs.get_data();

    // handle contiguous inputs
    bool is_rhs_c_contig = rhs.is_c_contiguous();
    bool is_rhs_f_contig = rhs.is_f_contiguous();

    bool is_lhs_c_contig = lhs.is_c_contiguous();
    bool is_lhs_f_contig = lhs.is_f_contiguous();

    bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig);
    bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig);

    // dispatch for contiguous inputs
    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
                                     0, depends);
            sycl::event ht_ev =
                dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &rhs_strides = rhs.get_strides_vector();
    auto const &lhs_strides = lhs.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_rhs_strides;
    shT simplified_lhs_strides;
    py::ssize_t rhs_offset(0);
    py::ssize_t lhs_offset(0);

    int nd = lhs_nd;
    const py::ssize_t *shape = rhs_shape;

    simplify_iteration_space(nd, shape, rhs_strides, lhs_strides,
                             // outputs
                             simplified_shape, simplified_rhs_strides,
                             simplified_lhs_strides, rhs_offset, lhs_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) &&
            isEqual(simplified_lhs_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev =
                    contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset,
                              lhs_data, lhs_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {rhs, lhs}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            static constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_rhs_strides, one_zero_strides) &&
                isEqual(simplified_lhs_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                              [lhs_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    std::size_t n0 = simplified_shape[1];
                    std::size_t n1 = simplified_shape[0];
                    sycl::event comp_ev = row_matrix_broadcast_fn(
                        exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                        lhs_data, lhs_offset, depends);

                    return std::make_pair(dpctl::utils::keep_args_alive(
                                              exec_q, {lhs, rhs}, host_tasks),
                                          comp_ev);
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for rhs_typeid=" +
            std::to_string(rhs_typeid) +
            " and lhs_typeid=" + std::to_string(lhs_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
        simplified_lhs_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
                   lhs_data, lhs_offset, depends, {copy_shape_ev});

    // async free of shape_strides temporary
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
        strided_fn_ev);
}
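
/* Populating dispatch tables (sketch): the contiguous and strided tables
   consumed by the helpers above are typically filled once at module
   initialization using dpctl's dispatch builders. The kernel-specific names
   below (unary_contig_fn_ptr_t, unary_strided_fn_ptr_t, AbsContigFactory,
   AbsStridedFactory and the vectors they fill) are illustrative placeholders,
   not part of this header:

   static unary_contig_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types];
   static unary_strided_fn_ptr_t abs_strided_dispatch_vector[td_ns::num_types];

   void populate_abs_dispatch_vectors()
   {
       td_ns::DispatchVectorBuilder<unary_contig_fn_ptr_t, AbsContigFactory,
                                    td_ns::num_types>
           contig_builder;
       contig_builder.populate_dispatch_vector(abs_contig_dispatch_vector);

       td_ns::DispatchVectorBuilder<unary_strided_fn_ptr_t, AbsStridedFactory,
                                    td_ns::num_types>
           strided_builder;
       strided_builder.populate_dispatch_vector(abs_strided_dispatch_vector);
   }
*/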
} // namespace dpnp::extensions::py_internal