DPNP C++ backend kernel library 0.20.0dev6
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// Include for dpctl_capi struct and casters
32#include "dpctl4pybind11.hpp"
33
34// Include generated Cython headers for usm_ndarray
35// (struct definition and constants only)
36#include "dpnp/tensor/_usmarray.h"
37#include "dpnp/tensor/_usmarray_api.h"
38
39#include <array>
40#include <cassert>
41#include <cstddef> // for std::size_t for C++ linkage
42#include <cstdint>
43#include <memory>
44#include <stdexcept>
45#include <utility>
46#include <vector>
47
48#include <pybind11/pybind11.h>
49
50#include <sycl/sycl.hpp>
51
52namespace py = pybind11;
53
54namespace dpnp
55{
56namespace detail
57{
58// Lookup a type according to its size, and return a value corresponding to the
59// NumPy typenum.
60
// End of the candidate list: no candidate type had the same size as
// `Concrete`, so report "not found" with -1.
template <typename Concrete>
constexpr int platform_typeid_lookup()
{
    return -1;
}

// Walks the candidate types `T, Ts...` in order together with their matching
// type numbers `I, Is...` and returns the type number of the first candidate
// whose size equals sizeof(Concrete). Falls through to the terminal overload
// (which returns -1) once the candidates are exhausted.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_typeid_lookup(int I, Ints... Is)
{
    if (sizeof(Concrete) != sizeof(T)) {
        // Sizes differ: drop the head candidate and keep searching the tail.
        return platform_typeid_lookup<Concrete, Ts...>(Is...);
    }
    return I;
}
74
76{
77public:
78 PyTypeObject *PyUSMArrayType_;
79
80 char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
81 int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
82 py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
83 py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
84 int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
85 int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
86 int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
87 DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
88 py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
89 PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
90 void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
91 PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
92 const py::ssize_t *,
93 int,
94 Py_MemoryObject *,
95 py::ssize_t,
96 char);
97 PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
98 int,
99 DPCTLSyclUSMRef,
100 DPCTLSyclQueueRef,
101 PyObject *);
102 PyObject *(*UsmNDArray_MakeFromPtr_)(int,
103 const py::ssize_t *,
104 int,
105 const py::ssize_t *,
106 DPCTLSyclUSMRef,
107 DPCTLSyclQueueRef,
108 py::ssize_t,
109 PyObject *);
110
111 int USM_ARRAY_C_CONTIGUOUS_;
112 int USM_ARRAY_F_CONTIGUOUS_;
113 int USM_ARRAY_WRITABLE_;
114 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
115 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
116 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
117 UAR_HALF_;
118 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
119 UAR_INT64_, UAR_UINT64_;
120
121 ~dpnp_capi() { default_usm_ndarray_.reset(); };
122
123 static auto &get()
124 {
125 static dpnp_capi api{};
126 return api;
127 }
128
129 py::object default_usm_ndarray_pyobj() { return *default_usm_ndarray_; }
130
131private:
132 struct Deleter
133 {
134 void operator()(py::object *p) const
135 {
136 const bool initialized = Py_IsInitialized();
137#if PY_VERSION_HEX < 0x30d0000
138 const bool finalizing = _Py_IsFinalizing();
139#else
140 const bool finalizing = Py_IsFinalizing();
141#endif
142 const bool guard = initialized && !finalizing;
143
144 if (guard) {
145 delete p;
146 }
147 }
148 };
149
150 std::shared_ptr<py::object> default_usm_ndarray_;
151
152 dpnp_capi()
153 : PyUSMArrayType_(nullptr), UsmNDArray_GetData_(nullptr),
154 UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
155 UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
156 UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
157 UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
158 UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
159 UsmNDArray_MakeSimpleFromMemory_(nullptr),
160 UsmNDArray_MakeSimpleFromPtr_(nullptr),
161 UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
162 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
163 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
164 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
165 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
166 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
167 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
168 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
169 UAR_INT64_(-1), UAR_UINT64_(-1), default_usm_ndarray_{}
170
171 {
172 // Import dpnp tensor module for PyUSMArrayType
173 import_dpnp__tensor___usmarray();
174
175 this->PyUSMArrayType_ = &PyUSMArrayType;
176
177 // dpnp.tensor.usm_ndarray API
178 this->UsmNDArray_GetData_ = UsmNDArray_GetData;
179 this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
180 this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
181 this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
182 this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
183 this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
184 this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
185 this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
186 this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
187 this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
188 this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
189 this->UsmNDArray_MakeSimpleFromMemory_ =
190 UsmNDArray_MakeSimpleFromMemory;
191 this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
192 this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
193
194 // constants
195 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
196 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
197 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
198 this->UAR_BOOL_ = UAR_BOOL;
199 this->UAR_BYTE_ = UAR_BYTE;
200 this->UAR_UBYTE_ = UAR_UBYTE;
201 this->UAR_SHORT_ = UAR_SHORT;
202 this->UAR_USHORT_ = UAR_USHORT;
203 this->UAR_INT_ = UAR_INT;
204 this->UAR_UINT_ = UAR_UINT;
205 this->UAR_LONG_ = UAR_LONG;
206 this->UAR_ULONG_ = UAR_ULONG;
207 this->UAR_LONGLONG_ = UAR_LONGLONG;
208 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
209 this->UAR_FLOAT_ = UAR_FLOAT;
210 this->UAR_DOUBLE_ = UAR_DOUBLE;
211 this->UAR_CFLOAT_ = UAR_CFLOAT;
212 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
213 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
214 this->UAR_HALF_ = UAR_HALF;
215
216 // deduced disjoint types
217 this->UAR_INT8_ = UAR_BYTE;
218 this->UAR_UINT8_ = UAR_UBYTE;
219 this->UAR_INT16_ = UAR_SHORT;
220 this->UAR_UINT16_ = UAR_USHORT;
221 this->UAR_INT32_ =
222 platform_typeid_lookup<std::int32_t, long, int, short>(
223 UAR_LONG, UAR_INT, UAR_SHORT);
224 this->UAR_UINT32_ =
225 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
226 unsigned short>(UAR_ULONG, UAR_UINT,
227 UAR_USHORT);
228 this->UAR_INT64_ =
229 platform_typeid_lookup<std::int64_t, long, long long, int>(
230 UAR_LONG, UAR_LONGLONG, UAR_INT);
231 this->UAR_UINT64_ =
232 platform_typeid_lookup<std::uint64_t, unsigned long,
233 unsigned long long, unsigned int>(
234 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
235
236 py::object py_default_usm_memory =
237 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj();
238
239 py::module_ mod_usmarray = py::module_::import("dpnp.tensor._usmarray");
240 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
241
242 const py::object &py_default_usm_ndarray =
243 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
244 py::arg("buffer") = py_default_usm_memory);
245
246 default_usm_ndarray_ = std::shared_ptr<py::object>(
247 new py::object{py_default_usm_ndarray}, Deleter{});
248 }
249
250 dpnp_capi(dpnp_capi const &) = default;
251 dpnp_capi &operator=(dpnp_capi const &) = default;
252 dpnp_capi &operator=(dpnp_capi &&) = default;
253
254}; // struct dpnp_capi
255} // namespace detail
256
257namespace tensor
258{
259inline std::vector<py::ssize_t>
260 c_contiguous_strides(int nd,
261 const py::ssize_t *shape,
262 py::ssize_t element_size = 1)
263{
264 if (nd > 0) {
265 std::vector<py::ssize_t> c_strides(nd, element_size);
266 for (int ic = nd - 1; ic > 0;) {
267 py::ssize_t next_v = c_strides[ic] * shape[ic];
268 c_strides[--ic] = next_v;
269 }
270 return c_strides;
271 }
272 else {
273 return std::vector<py::ssize_t>();
274 }
275}
276
277inline std::vector<py::ssize_t>
278 f_contiguous_strides(int nd,
279 const py::ssize_t *shape,
280 py::ssize_t element_size = 1)
281{
282 if (nd > 0) {
283 std::vector<py::ssize_t> f_strides(nd, element_size);
284 for (int i = 0; i < nd - 1;) {
285 py::ssize_t next_v = f_strides[i] * shape[i];
286 f_strides[++i] = next_v;
287 }
288 return f_strides;
289 }
290 else {
291 return std::vector<py::ssize_t>();
292 }
293}
294
295inline std::vector<py::ssize_t>
296 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
297 py::ssize_t element_size = 1)
298{
299 return c_contiguous_strides(shape.size(), shape.data(), element_size);
300}
301
302inline std::vector<py::ssize_t>
303 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
304 py::ssize_t element_size = 1)
305{
306 return f_contiguous_strides(shape.size(), shape.data(), element_size);
307}
308
309class usm_ndarray : public py::object
310{
311public:
312 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
313 return PyObject_TypeCheck(
314 o, detail::dpnp_capi::get().PyUSMArrayType_) != 0;
315 })
316
318 : py::object(detail::dpnp_capi::get().default_usm_ndarray_pyobj(),
319 borrowed_t{})
320 {
321 if (!m_ptr)
322 throw py::error_already_set();
323 }
324
325 char *get_data() const
326 {
327 PyUSMArrayObject *raw_ar = usm_array_ptr();
328
329 auto const &api = detail::dpnp_capi::get();
330 return api.UsmNDArray_GetData_(raw_ar);
331 }
332
333 template <typename T>
334 T *get_data() const
335 {
336 return reinterpret_cast<T *>(get_data());
337 }
338
339 int get_ndim() const
340 {
341 PyUSMArrayObject *raw_ar = usm_array_ptr();
342
343 auto const &api = detail::dpnp_capi::get();
344 return api.UsmNDArray_GetNDim_(raw_ar);
345 }
346
347 const py::ssize_t *get_shape_raw() const
348 {
349 PyUSMArrayObject *raw_ar = usm_array_ptr();
350
351 auto const &api = detail::dpnp_capi::get();
352 return api.UsmNDArray_GetShape_(raw_ar);
353 }
354
355 std::vector<py::ssize_t> get_shape_vector() const
356 {
357 auto raw_sh = get_shape_raw();
358 auto nd = get_ndim();
359
360 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
361 return shape_vector;
362 }
363
364 py::ssize_t get_shape(int i) const
365 {
366 auto shape_ptr = get_shape_raw();
367 return shape_ptr[i];
368 }
369
370 const py::ssize_t *get_strides_raw() const
371 {
372 PyUSMArrayObject *raw_ar = usm_array_ptr();
373
374 auto const &api = detail::dpnp_capi::get();
375 return api.UsmNDArray_GetStrides_(raw_ar);
376 }
377
378 std::vector<py::ssize_t> get_strides_vector() const
379 {
380 auto raw_st = get_strides_raw();
381 auto nd = get_ndim();
382
383 if (raw_st == nullptr) {
384 auto is_c_contig = is_c_contiguous();
385 auto is_f_contig = is_f_contiguous();
386 auto raw_sh = get_shape_raw();
387 if (is_c_contig) {
388 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
389 return contig_strides;
390 }
391 else if (is_f_contig) {
392 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
393 return contig_strides;
394 }
395 else {
396 throw std::runtime_error("Invalid array encountered when "
397 "building strides");
398 }
399 }
400 else {
401 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
402 return st_vec;
403 }
404 }
405
406 py::ssize_t get_size() const
407 {
408 PyUSMArrayObject *raw_ar = usm_array_ptr();
409
410 auto const &api = detail::dpnp_capi::get();
411 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
412 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
413
414 py::ssize_t nelems = 1;
415 for (int i = 0; i < ndim; ++i) {
416 nelems *= shape[i];
417 }
418
419 assert(nelems >= 0);
420 return nelems;
421 }
422
423 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
424 {
425 PyUSMArrayObject *raw_ar = usm_array_ptr();
426
427 auto const &api = detail::dpnp_capi::get();
428 int nd = api.UsmNDArray_GetNDim_(raw_ar);
429 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
430 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
431
432 py::ssize_t offset_min = 0;
433 py::ssize_t offset_max = 0;
434 if (strides == nullptr) {
435 py::ssize_t stride(1);
436 for (int i = 0; i < nd; ++i) {
437 offset_max += stride * (shape[i] - 1);
438 stride *= shape[i];
439 }
440 }
441 else {
442 for (int i = 0; i < nd; ++i) {
443 py::ssize_t delta = strides[i] * (shape[i] - 1);
444 if (strides[i] > 0) {
445 offset_max += delta;
446 }
447 else {
448 offset_min += delta;
449 }
450 }
451 }
452 return std::make_pair(offset_min, offset_max);
453 }
454
455 sycl::queue get_queue() const
456 {
457 PyUSMArrayObject *raw_ar = usm_array_ptr();
458
459 auto const &api = detail::dpnp_capi::get();
460 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
461 return *(reinterpret_cast<sycl::queue *>(QRef));
462 }
463
464 sycl::device get_device() const
465 {
466 PyUSMArrayObject *raw_ar = usm_array_ptr();
467
468 auto const &api = detail::dpnp_capi::get();
469 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
470 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
471 }
472
473 int get_typenum() const
474 {
475 PyUSMArrayObject *raw_ar = usm_array_ptr();
476
477 auto const &api = detail::dpnp_capi::get();
478 return api.UsmNDArray_GetTypenum_(raw_ar);
479 }
480
481 int get_flags() const
482 {
483 PyUSMArrayObject *raw_ar = usm_array_ptr();
484
485 auto const &api = detail::dpnp_capi::get();
486 return api.UsmNDArray_GetFlags_(raw_ar);
487 }
488
489 int get_elemsize() const
490 {
491 PyUSMArrayObject *raw_ar = usm_array_ptr();
492
493 auto const &api = detail::dpnp_capi::get();
494 return api.UsmNDArray_GetElementSize_(raw_ar);
495 }
496
497 bool is_c_contiguous() const
498 {
499 int flags = get_flags();
500 auto const &api = detail::dpnp_capi::get();
501 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
502 }
503
504 bool is_f_contiguous() const
505 {
506 int flags = get_flags();
507 auto const &api = detail::dpnp_capi::get();
508 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
509 }
510
511 bool is_writable() const
512 {
513 int flags = get_flags();
514 auto const &api = detail::dpnp_capi::get();
515 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
516 }
517
519 py::object get_usm_data() const
520 {
521 PyUSMArrayObject *raw_ar = usm_array_ptr();
522
523 auto const &api = detail::dpnp_capi::get();
524 // base_ is the Memory object - return new reference
525 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
526
527 // pass reference ownership to py::object
528 return py::reinterpret_steal<py::object>(usm_data);
529 }
530
531 bool is_managed_by_smart_ptr() const
532 {
533 PyUSMArrayObject *raw_ar = usm_array_ptr();
534
535 auto const &api = detail::dpnp_capi::get();
536 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
537
538 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
539 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
540 Py_DECREF(usm_data);
541 return false;
542 }
543
544 Py_MemoryObject *mem_obj =
545 reinterpret_cast<Py_MemoryObject *>(usm_data);
546 const void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
547
548 Py_DECREF(usm_data);
549 return bool(opaque_ptr);
550 }
551
552 const std::shared_ptr<void> &get_smart_ptr_owner() const
553 {
554 PyUSMArrayObject *raw_ar = usm_array_ptr();
555
556 auto const &api = detail::dpnp_capi::get();
557 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
558
559 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
560 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
561 Py_DECREF(usm_data);
562 throw std::runtime_error(
563 "usm_ndarray object does not have Memory object "
564 "managing lifetime of USM allocation");
565 }
566
567 Py_MemoryObject *mem_obj =
568 reinterpret_cast<Py_MemoryObject *>(usm_data);
569 void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
570 Py_DECREF(usm_data);
571
572 if (opaque_ptr) {
573 auto shptr_ptr =
574 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
575 return *shptr_ptr;
576 }
577 else {
578 throw std::runtime_error(
579 "Memory object underlying usm_ndarray does not have "
580 "smart pointer managing lifetime of USM allocation");
581 }
582 }
583
584private:
585 PyUSMArrayObject *usm_array_ptr() const
586 {
587 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
588 }
589};
590} // end namespace tensor
591
592namespace utils
593{
594namespace detail
595{
596// TODO: future version of dpctl will include a more general way of passing
597// shared_ptrs to keep_args_alive, so that future overload can be used here
598// instead of reimplementing keep_args_alive
599
601{
602 // TODO: do we need to check for memory here? Or can we assume only
603 // dpnp::tensor::usm_ndarray will be passed?
604 static bool is_usm_managed_by_shared_ptr(const py::object &h)
605 {
606
607 if (py::isinstance<::dpctl::memory::usm_memory>(h)) {
608 const auto &usm_memory_inst =
609 py::cast<::dpctl::memory::usm_memory>(h);
610 return usm_memory_inst.is_managed_by_smart_ptr();
611 }
612 else if (py::isinstance<tensor::usm_ndarray>(h)) {
613 const auto &usm_array_inst = py::cast<tensor::usm_ndarray>(h);
614 return usm_array_inst.is_managed_by_smart_ptr();
615 }
616
617 return false;
618 }
619
620 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
621 {
622 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
623 const auto &usm_memory_inst =
624 py::cast<dpctl::memory::usm_memory>(h);
625 return usm_memory_inst.get_smart_ptr_owner();
626 }
627 else if (py::isinstance<tensor::usm_ndarray>(h)) {
628 const auto &usm_array_inst = py::cast<tensor::usm_ndarray>(h);
629 return usm_array_inst.get_smart_ptr_owner();
630 }
631
632 throw std::runtime_error(
633 "Attempted extraction of shared_ptr on an unrecognized type");
634 }
635};
636} // end of namespace detail
637
638template <std::size_t num>
639sycl::event keep_args_alive(sycl::queue &q,
640 const py::object (&py_objs)[num],
641 const std::vector<sycl::event> &depends = {})
642{
643 std::size_t n_objects_held = 0;
644 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
645
646 std::size_t n_usm_owners_held = 0;
647 std::array<std::shared_ptr<void>, num> shp_usm{};
648
649 for (std::size_t i = 0; i < num; ++i) {
650 const auto &py_obj_i = py_objs[i];
651 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
652 const auto &shp =
653 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
654 shp_usm[n_usm_owners_held] = shp;
655 ++n_usm_owners_held;
656 }
657 else {
658 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
659 shp_arr[n_objects_held]->inc_ref();
660 ++n_objects_held;
661 }
662 }
663
664 bool use_depends = true;
665 sycl::event host_task_ev;
666
667 if (n_usm_owners_held > 0) {
668 host_task_ev = q.submit([&](sycl::handler &cgh) {
669 if (use_depends) {
670 cgh.depends_on(depends);
671 use_depends = false;
672 }
673 else {
674 cgh.depends_on(host_task_ev);
675 }
676 cgh.host_task([shp_usm = std::move(shp_usm)]() {
677 // no body, but shared pointers are captured in
678 // the lambda, ensuring that USM allocation is
679 // kept alive
680 });
681 });
682 }
683
684 if (n_objects_held > 0) {
685 host_task_ev = q.submit([&](sycl::handler &cgh) {
686 if (use_depends) {
687 cgh.depends_on(depends);
688 use_depends = false;
689 }
690 else {
691 cgh.depends_on(host_task_ev);
692 }
693 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
694 py::gil_scoped_acquire acquire;
695
696 for (std::size_t i = 0; i < n_objects_held; ++i) {
697 shp_arr[i]->dec_ref();
698 }
699 });
700 });
701 }
702
703 return host_task_ev;
704}
705
706// add to namespace for convenience
707using ::dpctl::utils::queues_are_compatible;
708
711template <std::size_t num>
712bool queues_are_compatible(const sycl::queue &exec_q,
713 const tensor::usm_ndarray (&arrs)[num])
714{
715 for (std::size_t i = 0; i < num; ++i) {
716
717 if (exec_q != arrs[i].get_queue()) {
718 return false;
719 }
720 }
721 return true;
722}
723} // end namespace utils
724} // end namespace dpnp
py::object get_usm_data() const
Get usm_data property of array.