DPNP C++ backend kernel library 0.21.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// Include for dpctl_capi struct and casters
32#include "dpctl4pybind11.hpp"
33
34// Include generated Cython headers for usm_ndarray
35#include "dpnp/tensor/_usmarray.h"
36#include "dpnp/tensor/_usmarray_api.h"
37// Include usm_ndarray constants (flags, type numbers)
38#include "../../tensor/include/usm_ndarray_constants.h"
39
40#include <array>
41#include <cassert>
42#include <cstddef> // for std::size_t for C++ linkage
43#include <cstdint>
44#include <memory>
45#include <stdexcept>
46#include <utility>
47#include <vector>
48
49#include <pybind11/pybind11.h>
50
51#include <sycl/sycl.hpp>
52
53namespace py = pybind11;
54
55namespace dpnp
56{
57namespace detail
58{
59// Lookup a type according to its size, and return a value corresponding to the
60// NumPy typenum.
61
62template <typename Concrete>
63constexpr int platform_typeid_lookup()
64{
65 return -1;
66}
67
68template <typename Concrete, typename T, typename... Ts, typename... Ints>
69constexpr int platform_typeid_lookup(int I, Ints... Is)
70{
71 return sizeof(Concrete) == sizeof(T)
72 ? I
73 : platform_typeid_lookup<Concrete, Ts...>(Is...);
74}
75
77{
78public:
79 PyTypeObject *PyUSMArrayType_;
80
81 char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
82 int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
83 py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
84 py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
85 int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
86 int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
87 int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
88 DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
89 py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
90 PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
91 void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
92 PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
93 const py::ssize_t *,
94 int,
95 Py_MemoryObject *,
96 py::ssize_t,
97 char);
98 PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
99 int,
100 DPCTLSyclUSMRef,
101 DPCTLSyclQueueRef,
102 PyObject *);
103 PyObject *(*UsmNDArray_MakeFromPtr_)(int,
104 const py::ssize_t *,
105 int,
106 const py::ssize_t *,
107 DPCTLSyclUSMRef,
108 DPCTLSyclQueueRef,
109 py::ssize_t,
110 PyObject *);
111
112 int USM_ARRAY_C_CONTIGUOUS_;
113 int USM_ARRAY_F_CONTIGUOUS_;
114 int USM_ARRAY_WRITABLE_;
115 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
116 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
117 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
118 UAR_HALF_;
119 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
120 UAR_INT64_, UAR_UINT64_;
121
122 ~dpnp_capi() { default_usm_ndarray_.reset(); };
123
124 static auto &get()
125 {
126 static dpnp_capi api{};
127 return api;
128 }
129
130 py::object default_usm_ndarray_pyobj() { return *default_usm_ndarray_; }
131
132private:
133 struct Deleter
134 {
135 void operator()(py::object *p) const
136 {
137 const bool initialized = Py_IsInitialized();
138#if PY_VERSION_HEX < 0x30d0000
139 const bool finalizing = _Py_IsFinalizing();
140#else
141 const bool finalizing = Py_IsFinalizing();
142#endif
143 const bool guard = initialized && !finalizing;
144
145 if (guard) {
146 delete p;
147 }
148 }
149 };
150
151 std::shared_ptr<py::object> default_usm_ndarray_;
152
153 dpnp_capi()
154 : PyUSMArrayType_(nullptr), UsmNDArray_GetData_(nullptr),
155 UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
156 UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
157 UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
158 UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
159 UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
160 UsmNDArray_MakeSimpleFromMemory_(nullptr),
161 UsmNDArray_MakeSimpleFromPtr_(nullptr),
162 UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
163 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
164 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
165 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
166 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
167 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
168 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
169 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
170 UAR_INT64_(-1), UAR_UINT64_(-1), default_usm_ndarray_{}
171
172 {
173 // Import dpnp tensor module for PyUSMArrayType
174 import_dpnp__tensor___usmarray();
175
176 this->PyUSMArrayType_ = &PyUSMArrayType;
177
178 // dpnp.tensor.usm_ndarray API
179 this->UsmNDArray_GetData_ = UsmNDArray_GetData;
180 this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
181 this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
182 this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
183 this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
184 this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
185 this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
186 this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
187 this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
188 this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
189 this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
190 this->UsmNDArray_MakeSimpleFromMemory_ =
191 UsmNDArray_MakeSimpleFromMemory;
192 this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
193 this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
194
195 // constants from usm_ndarray_constants.h
196 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS_VALUE;
197 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS_VALUE;
198 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE_VALUE;
199 this->UAR_BOOL_ = UAR_BOOL_VALUE;
200 this->UAR_BYTE_ = UAR_BYTE_VALUE;
201 this->UAR_UBYTE_ = UAR_UBYTE_VALUE;
202 this->UAR_SHORT_ = UAR_SHORT_VALUE;
203 this->UAR_USHORT_ = UAR_USHORT_VALUE;
204 this->UAR_INT_ = UAR_INT_VALUE;
205 this->UAR_UINT_ = UAR_UINT_VALUE;
206 this->UAR_LONG_ = UAR_LONG_VALUE;
207 this->UAR_ULONG_ = UAR_ULONG_VALUE;
208 this->UAR_LONGLONG_ = UAR_LONGLONG_VALUE;
209 this->UAR_ULONGLONG_ = UAR_ULONGLONG_VALUE;
210 this->UAR_FLOAT_ = UAR_FLOAT_VALUE;
211 this->UAR_DOUBLE_ = UAR_DOUBLE_VALUE;
212 this->UAR_CFLOAT_ = UAR_CFLOAT_VALUE;
213 this->UAR_CDOUBLE_ = UAR_CDOUBLE_VALUE;
214 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL_VALUE;
215 this->UAR_HALF_ = UAR_HALF_VALUE;
216
217 // deduced disjoint types
218 this->UAR_INT8_ = UAR_BYTE_VALUE;
219 this->UAR_UINT8_ = UAR_UBYTE_VALUE;
220 this->UAR_INT16_ = UAR_SHORT_VALUE;
221 this->UAR_UINT16_ = UAR_USHORT_VALUE;
222 this->UAR_INT32_ =
223 platform_typeid_lookup<std::int32_t, long, int, short>(
224 UAR_LONG_VALUE, UAR_INT_VALUE, UAR_SHORT_VALUE);
225 this->UAR_UINT32_ =
226 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
227 unsigned short>(
228 UAR_ULONG_VALUE, UAR_UINT_VALUE, UAR_USHORT_VALUE);
229 this->UAR_INT64_ =
230 platform_typeid_lookup<std::int64_t, long, long long, int>(
231 UAR_LONG_VALUE, UAR_LONGLONG_VALUE, UAR_INT_VALUE);
232 this->UAR_UINT64_ =
233 platform_typeid_lookup<std::uint64_t, unsigned long,
234 unsigned long long, unsigned int>(
235 UAR_ULONG_VALUE, UAR_ULONGLONG_VALUE, UAR_UINT_VALUE);
236
237 py::object py_default_usm_memory =
238 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj();
239
240 py::module_ mod_usmarray = py::module_::import("dpnp.tensor._usmarray");
241 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
242
243 const py::object &py_default_usm_ndarray =
244 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
245 py::arg("buffer") = py_default_usm_memory);
246
247 default_usm_ndarray_ = std::shared_ptr<py::object>(
248 new py::object{py_default_usm_ndarray}, Deleter{});
249 }
250
251 dpnp_capi(dpnp_capi const &) = default;
252 dpnp_capi &operator=(dpnp_capi const &) = default;
253 dpnp_capi &operator=(dpnp_capi &&) = default;
254
255}; // struct dpnp_capi
256} // namespace detail
257
258namespace tensor
259{
260inline std::vector<py::ssize_t>
261 c_contiguous_strides(int nd,
262 const py::ssize_t *shape,
263 py::ssize_t element_size = 1)
264{
265 if (nd > 0) {
266 std::vector<py::ssize_t> c_strides(nd, element_size);
267 for (int ic = nd - 1; ic > 0;) {
268 py::ssize_t next_v = c_strides[ic] * shape[ic];
269 c_strides[--ic] = next_v;
270 }
271 return c_strides;
272 }
273 else {
274 return std::vector<py::ssize_t>();
275 }
276}
277
278inline std::vector<py::ssize_t>
279 f_contiguous_strides(int nd,
280 const py::ssize_t *shape,
281 py::ssize_t element_size = 1)
282{
283 if (nd > 0) {
284 std::vector<py::ssize_t> f_strides(nd, element_size);
285 for (int i = 0; i < nd - 1;) {
286 py::ssize_t next_v = f_strides[i] * shape[i];
287 f_strides[++i] = next_v;
288 }
289 return f_strides;
290 }
291 else {
292 return std::vector<py::ssize_t>();
293 }
294}
295
296inline std::vector<py::ssize_t>
297 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
298 py::ssize_t element_size = 1)
299{
300 return c_contiguous_strides(shape.size(), shape.data(), element_size);
301}
302
303inline std::vector<py::ssize_t>
304 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
305 py::ssize_t element_size = 1)
306{
307 return f_contiguous_strides(shape.size(), shape.data(), element_size);
308}
309
310class usm_ndarray : public py::object
311{
312public:
313 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
314 return PyObject_TypeCheck(
315 o, detail::dpnp_capi::get().PyUSMArrayType_) != 0;
316 })
317
319 : py::object(detail::dpnp_capi::get().default_usm_ndarray_pyobj(),
320 borrowed_t{})
321 {
322 if (!m_ptr)
323 throw py::error_already_set();
324 }
325
326 char *get_data() const
327 {
328 PyUSMArrayObject *raw_ar = usm_array_ptr();
329
330 auto const &api = detail::dpnp_capi::get();
331 return api.UsmNDArray_GetData_(raw_ar);
332 }
333
334 template <typename T>
335 T *get_data() const
336 {
337 return reinterpret_cast<T *>(get_data());
338 }
339
340 int get_ndim() const
341 {
342 PyUSMArrayObject *raw_ar = usm_array_ptr();
343
344 auto const &api = detail::dpnp_capi::get();
345 return api.UsmNDArray_GetNDim_(raw_ar);
346 }
347
348 const py::ssize_t *get_shape_raw() const
349 {
350 PyUSMArrayObject *raw_ar = usm_array_ptr();
351
352 auto const &api = detail::dpnp_capi::get();
353 return api.UsmNDArray_GetShape_(raw_ar);
354 }
355
356 std::vector<py::ssize_t> get_shape_vector() const
357 {
358 auto raw_sh = get_shape_raw();
359 auto nd = get_ndim();
360
361 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
362 return shape_vector;
363 }
364
365 py::ssize_t get_shape(int i) const
366 {
367 auto shape_ptr = get_shape_raw();
368 return shape_ptr[i];
369 }
370
371 const py::ssize_t *get_strides_raw() const
372 {
373 PyUSMArrayObject *raw_ar = usm_array_ptr();
374
375 auto const &api = detail::dpnp_capi::get();
376 return api.UsmNDArray_GetStrides_(raw_ar);
377 }
378
379 std::vector<py::ssize_t> get_strides_vector() const
380 {
381 auto raw_st = get_strides_raw();
382 auto nd = get_ndim();
383
384 if (raw_st == nullptr) {
385 auto is_c_contig = is_c_contiguous();
386 auto is_f_contig = is_f_contiguous();
387 auto raw_sh = get_shape_raw();
388 if (is_c_contig) {
389 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
390 return contig_strides;
391 }
392 else if (is_f_contig) {
393 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
394 return contig_strides;
395 }
396 else {
397 throw std::runtime_error("Invalid array encountered when "
398 "building strides");
399 }
400 }
401 else {
402 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
403 return st_vec;
404 }
405 }
406
407 py::ssize_t get_size() const
408 {
409 PyUSMArrayObject *raw_ar = usm_array_ptr();
410
411 auto const &api = detail::dpnp_capi::get();
412 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
413 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
414
415 py::ssize_t nelems = 1;
416 for (int i = 0; i < ndim; ++i) {
417 nelems *= shape[i];
418 }
419
420 assert(nelems >= 0);
421 return nelems;
422 }
423
424 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
425 {
426 PyUSMArrayObject *raw_ar = usm_array_ptr();
427
428 auto const &api = detail::dpnp_capi::get();
429 int nd = api.UsmNDArray_GetNDim_(raw_ar);
430 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
431 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
432
433 py::ssize_t offset_min = 0;
434 py::ssize_t offset_max = 0;
435 if (strides == nullptr) {
436 py::ssize_t stride(1);
437 for (int i = 0; i < nd; ++i) {
438 offset_max += stride * (shape[i] - 1);
439 stride *= shape[i];
440 }
441 }
442 else {
443 for (int i = 0; i < nd; ++i) {
444 py::ssize_t delta = strides[i] * (shape[i] - 1);
445 if (strides[i] > 0) {
446 offset_max += delta;
447 }
448 else {
449 offset_min += delta;
450 }
451 }
452 }
453 return std::make_pair(offset_min, offset_max);
454 }
455
456 sycl::queue get_queue() const
457 {
458 PyUSMArrayObject *raw_ar = usm_array_ptr();
459
460 auto const &api = detail::dpnp_capi::get();
461 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
462 return *(reinterpret_cast<sycl::queue *>(QRef));
463 }
464
465 sycl::device get_device() const
466 {
467 PyUSMArrayObject *raw_ar = usm_array_ptr();
468
469 auto const &api = detail::dpnp_capi::get();
470 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
471 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
472 }
473
474 int get_typenum() const
475 {
476 PyUSMArrayObject *raw_ar = usm_array_ptr();
477
478 auto const &api = detail::dpnp_capi::get();
479 return api.UsmNDArray_GetTypenum_(raw_ar);
480 }
481
482 int get_flags() const
483 {
484 PyUSMArrayObject *raw_ar = usm_array_ptr();
485
486 auto const &api = detail::dpnp_capi::get();
487 return api.UsmNDArray_GetFlags_(raw_ar);
488 }
489
490 int get_elemsize() const
491 {
492 PyUSMArrayObject *raw_ar = usm_array_ptr();
493
494 auto const &api = detail::dpnp_capi::get();
495 return api.UsmNDArray_GetElementSize_(raw_ar);
496 }
497
498 bool is_c_contiguous() const
499 {
500 int flags = get_flags();
501 auto const &api = detail::dpnp_capi::get();
502 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
503 }
504
505 bool is_f_contiguous() const
506 {
507 int flags = get_flags();
508 auto const &api = detail::dpnp_capi::get();
509 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
510 }
511
512 bool is_writable() const
513 {
514 int flags = get_flags();
515 auto const &api = detail::dpnp_capi::get();
516 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
517 }
518
520 py::object get_usm_data() const
521 {
522 PyUSMArrayObject *raw_ar = usm_array_ptr();
523
524 auto const &api = detail::dpnp_capi::get();
525 // base_ is the Memory object - return new reference
526 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
527
528 // pass reference ownership to py::object
529 return py::reinterpret_steal<py::object>(usm_data);
530 }
531
532 bool is_managed_by_smart_ptr() const
533 {
534 PyUSMArrayObject *raw_ar = usm_array_ptr();
535
536 auto const &api = detail::dpnp_capi::get();
537 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
538
539 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
540 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
541 Py_DECREF(usm_data);
542 return false;
543 }
544
545 Py_MemoryObject *mem_obj =
546 reinterpret_cast<Py_MemoryObject *>(usm_data);
547 const void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
548
549 Py_DECREF(usm_data);
550 return bool(opaque_ptr);
551 }
552
553 const std::shared_ptr<void> &get_smart_ptr_owner() const
554 {
555 PyUSMArrayObject *raw_ar = usm_array_ptr();
556
557 auto const &api = detail::dpnp_capi::get();
558 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
559
560 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
561 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
562 Py_DECREF(usm_data);
563 throw std::runtime_error(
564 "usm_ndarray object does not have Memory object "
565 "managing lifetime of USM allocation");
566 }
567
568 Py_MemoryObject *mem_obj =
569 reinterpret_cast<Py_MemoryObject *>(usm_data);
570 void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
571 Py_DECREF(usm_data);
572
573 if (opaque_ptr) {
574 auto shptr_ptr =
575 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
576 return *shptr_ptr;
577 }
578 else {
579 throw std::runtime_error(
580 "Memory object underlying usm_ndarray does not have "
581 "smart pointer managing lifetime of USM allocation");
582 }
583 }
584
585private:
586 PyUSMArrayObject *usm_array_ptr() const
587 {
588 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
589 }
590};
591} // end namespace tensor
592
593namespace utils
594{
595namespace detail
596{
597// TODO: future version of dpctl will include a more general way of passing
598// shared_ptrs to keep_args_alive, so that future overload can be used here
599// instead of reimplementing keep_args_alive
600
602{
603 // TODO: do we need to check for memory here? Or can we assume only
604 // dpnp::tensor::usm_ndarray will be passed?
605 static bool is_usm_managed_by_shared_ptr(const py::object &h)
606 {
607
608 if (py::isinstance<::dpctl::memory::usm_memory>(h)) {
609 const auto &usm_memory_inst =
610 py::cast<::dpctl::memory::usm_memory>(h);
611 return usm_memory_inst.is_managed_by_smart_ptr();
612 }
613 else if (py::isinstance<tensor::usm_ndarray>(h)) {
614 const auto &usm_array_inst = py::cast<tensor::usm_ndarray>(h);
615 return usm_array_inst.is_managed_by_smart_ptr();
616 }
617
618 return false;
619 }
620
621 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
622 {
623 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
624 const auto &usm_memory_inst =
625 py::cast<dpctl::memory::usm_memory>(h);
626 return usm_memory_inst.get_smart_ptr_owner();
627 }
628 else if (py::isinstance<tensor::usm_ndarray>(h)) {
629 const auto &usm_array_inst = py::cast<tensor::usm_ndarray>(h);
630 return usm_array_inst.get_smart_ptr_owner();
631 }
632
633 throw std::runtime_error(
634 "Attempted extraction of shared_ptr on an unrecognized type");
635 }
636};
637} // end of namespace detail
638
639template <std::size_t num>
640sycl::event keep_args_alive(sycl::queue &q,
641 const py::object (&py_objs)[num],
642 const std::vector<sycl::event> &depends = {})
643{
644 std::size_t n_objects_held = 0;
645 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
646
647 std::size_t n_usm_owners_held = 0;
648 std::array<std::shared_ptr<void>, num> shp_usm{};
649
650 for (std::size_t i = 0; i < num; ++i) {
651 const auto &py_obj_i = py_objs[i];
652 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
653 const auto &shp =
654 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
655 shp_usm[n_usm_owners_held] = shp;
656 ++n_usm_owners_held;
657 }
658 else {
659 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
660 shp_arr[n_objects_held]->inc_ref();
661 ++n_objects_held;
662 }
663 }
664
665 bool use_depends = true;
666 sycl::event host_task_ev;
667
668 if (n_usm_owners_held > 0) {
669 host_task_ev = q.submit([&](sycl::handler &cgh) {
670 if (use_depends) {
671 cgh.depends_on(depends);
672 use_depends = false;
673 }
674 else {
675 cgh.depends_on(host_task_ev);
676 }
677 cgh.host_task([shp_usm = std::move(shp_usm)]() {
678 // no body, but shared pointers are captured in
679 // the lambda, ensuring that USM allocation is
680 // kept alive
681 });
682 });
683 }
684
685 if (n_objects_held > 0) {
686 host_task_ev = q.submit([&](sycl::handler &cgh) {
687 if (use_depends) {
688 cgh.depends_on(depends);
689 use_depends = false;
690 }
691 else {
692 cgh.depends_on(host_task_ev);
693 }
694 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
695 py::gil_scoped_acquire acquire;
696
697 for (std::size_t i = 0; i < n_objects_held; ++i) {
698 shp_arr[i]->dec_ref();
699 }
700 });
701 });
702 }
703
704 return host_task_ev;
705}
706
707// add to namespace for convenience
708using ::dpctl::utils::queues_are_compatible;
709
712template <std::size_t num>
713bool queues_are_compatible(const sycl::queue &exec_q,
714 const tensor::usm_ndarray (&arrs)[num])
715{
716 for (std::size_t i = 0; i < num; ++i) {
717
718 if (exec_q != arrs[i].get_queue()) {
719 return false;
720 }
721 }
722 return true;
723}
724} // end namespace utils
725} // end namespace dpnp
py::object get_usm_data() const
Get usm_data property of array.