DPNP C++ backend kernel library 0.20.0dev3
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// Include dpnp tensor C-API (provides unified access to both dpctl and dpnp
32// tensor) This includes:
33// - dpctl C-API (from external dpctl package - SYCL interface)
34// - dpnp tensor C-API (tensor interface: usm_ndarray)
35#include "dpnp_tensor_capi.h"
36
37#include <array>
38#include <cassert>
39#include <complex>
40#include <cstddef> // for std::size_t for C++ linkage
41#include <cstdint>
42#include <memory>
43#include <stddef.h> // for size_t for C linkage
44#include <stdexcept>
45#include <type_traits>
46#include <utility>
47#include <vector>
48
49#include <pybind11/pybind11.h>
50
51#include <sycl/sycl.hpp>
52
53namespace py = pybind11;
54
55namespace dpctl
56{
57namespace detail
58{
59// Lookup a type according to its size, and return a value corresponding to the
60// NumPy typenum.
61template <typename Concrete>
62constexpr int platform_typeid_lookup()
63{
64 return -1;
65}
66
67template <typename Concrete, typename T, typename... Ts, typename... Ints>
68constexpr int platform_typeid_lookup(int I, Ints... Is)
69{
70 return sizeof(Concrete) == sizeof(T)
71 ? I
72 : platform_typeid_lookup<Concrete, Ts...>(Is...);
73}
74
76{
77public:
78 // dpctl type objects
79 PyTypeObject *Py_SyclDeviceType_;
80 PyTypeObject *PySyclDeviceType_;
81 PyTypeObject *Py_SyclContextType_;
82 PyTypeObject *PySyclContextType_;
83 PyTypeObject *Py_SyclEventType_;
84 PyTypeObject *PySyclEventType_;
85 PyTypeObject *Py_SyclQueueType_;
86 PyTypeObject *PySyclQueueType_;
87 PyTypeObject *Py_MemoryType_;
88 PyTypeObject *PyMemoryUSMDeviceType_;
89 PyTypeObject *PyMemoryUSMSharedType_;
90 PyTypeObject *PyMemoryUSMHostType_;
91 PyTypeObject *PyUSMArrayType_;
92 PyTypeObject *PySyclProgramType_;
93 PyTypeObject *PySyclKernelType_;
94
95 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
96 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
97
98 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
99 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
100
101 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
102 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
103
104 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
105 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
106
107 // memory
108 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
109 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
110 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
111 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
112 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
113 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
114 size_t,
115 DPCTLSyclQueueRef,
116 PyObject *);
117
118 // program
119 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
120 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);
121
122 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
123 PySyclProgramObject *);
124 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
125
126 // tensor
127 char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
128 int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
129 py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
130 py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
131 int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
132 int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
133 int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
134 DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
135 py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
136 PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
137 void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
138 PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
139 const py::ssize_t *,
140 int,
141 Py_MemoryObject *,
142 py::ssize_t,
143 char);
144 PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
145 int,
146 DPCTLSyclUSMRef,
147 DPCTLSyclQueueRef,
148 PyObject *);
149 PyObject *(*UsmNDArray_MakeFromPtr_)(int,
150 const py::ssize_t *,
151 int,
152 const py::ssize_t *,
153 DPCTLSyclUSMRef,
154 DPCTLSyclQueueRef,
155 py::ssize_t,
156 PyObject *);
157
158 int USM_ARRAY_C_CONTIGUOUS_;
159 int USM_ARRAY_F_CONTIGUOUS_;
160 int USM_ARRAY_WRITABLE_;
161 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
162 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
163 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
164 UAR_HALF_;
165 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
166 UAR_INT64_, UAR_UINT64_;
167
168 bool PySyclDevice_Check_(PyObject *obj) const
169 {
170 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
171 }
172 bool PySyclContext_Check_(PyObject *obj) const
173 {
174 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
175 }
176 bool PySyclEvent_Check_(PyObject *obj) const
177 {
178 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
179 }
180 bool PySyclQueue_Check_(PyObject *obj) const
181 {
182 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
183 }
184 bool PySyclKernel_Check_(PyObject *obj) const
185 {
186 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
187 }
188 bool PySyclProgram_Check_(PyObject *obj) const
189 {
190 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
191 }
192
194 {
195 as_usm_memory_.reset();
196 default_usm_ndarray_.reset();
197 default_usm_memory_.reset();
198 default_sycl_queue_.reset();
199 };
200
201 static auto &get()
202 {
203 static dpctl_capi api{};
204 return api;
205 }
206
207 py::object default_sycl_queue_pyobj()
208 {
209 return *default_sycl_queue_;
210 }
211 py::object default_usm_memory_pyobj()
212 {
213 return *default_usm_memory_;
214 }
215 py::object default_usm_ndarray_pyobj()
216 {
217 return *default_usm_ndarray_;
218 }
219 py::object as_usm_memory_pyobj()
220 {
221 return *as_usm_memory_;
222 }
223
224private:
225 struct Deleter
226 {
227 void operator()(py::object *p) const
228 {
229 const bool initialized = Py_IsInitialized();
230#if PY_VERSION_HEX < 0x30d0000
231 const bool finalizing = _Py_IsFinalizing();
232#else
233 const bool finalizing = Py_IsFinalizing();
234#endif
235 const bool guard = initialized && !finalizing;
236
237 if (guard) {
238 delete p;
239 }
240 }
241 };
242
243 std::shared_ptr<py::object> default_sycl_queue_;
244 std::shared_ptr<py::object> default_usm_memory_;
245 std::shared_ptr<py::object> default_usm_ndarray_;
246 std::shared_ptr<py::object> as_usm_memory_;
247
248 dpctl_capi()
249 : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
250 Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
251 Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
252 Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
253 Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
254 PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
255 PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
256 PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
257 SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
258 SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
259 SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
260 SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
261 Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
262 Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
263 Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
264 SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
265 SyclProgram_Make_(nullptr), UsmNDArray_GetData_(nullptr),
266 UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
267 UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
268 UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
269 UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
270 UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
271 UsmNDArray_MakeSimpleFromMemory_(nullptr),
272 UsmNDArray_MakeSimpleFromPtr_(nullptr),
273 UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
274 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
275 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
276 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
277 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
278 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
279 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
280 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
281 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
282 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
283
284 {
285 // Import Cython-generated C-API for dpnp tensor
286 // This imports python modules and initializes
287 // static variables such as function pointers for C-API,
288 // e.g. SyclDevice_GetDeviceRef, etc.
289 // pointers to Python types, i.e. PySyclDeviceType, etc.
290 // and exported constants, i.e. USM_ARRAY_C_CONTIGUOUS, etc.
291 import_dpnp_tensor(); // Imports both dpctl and dpnp tensor C-APIs
292
293 // Python type objects for classes implemented by dpctl
294 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
295 this->PySyclDeviceType_ = &PySyclDeviceType;
296 this->Py_SyclContextType_ = &Py_SyclContextType;
297 this->PySyclContextType_ = &PySyclContextType;
298 this->Py_SyclEventType_ = &Py_SyclEventType;
299 this->PySyclEventType_ = &PySyclEventType;
300 this->Py_SyclQueueType_ = &Py_SyclQueueType;
301 this->PySyclQueueType_ = &PySyclQueueType;
302 this->Py_MemoryType_ = &Py_MemoryType;
303 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
304 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
305 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
306 this->PyUSMArrayType_ = &PyUSMArrayType;
307 this->PySyclProgramType_ = &PySyclProgramType;
308 this->PySyclKernelType_ = &PySyclKernelType;
309
310 // SyclDevice API
311 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
312 this->SyclDevice_Make_ = SyclDevice_Make;
313
314 // SyclContext API
315 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
316 this->SyclContext_Make_ = SyclContext_Make;
317
318 // SyclEvent API
319 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
320 this->SyclEvent_Make_ = SyclEvent_Make;
321
322 // SyclQueue API
323 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
324 this->SyclQueue_Make_ = SyclQueue_Make;
325
326 // dpctl.memory API
327 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
328 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
329 this->Memory_GetContextRef_ = Memory_GetContextRef;
330 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
331 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
332 this->Memory_Make_ = Memory_Make;
333
334 // dpctl.program API
335 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
336 this->SyclKernel_Make_ = SyclKernel_Make;
337 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
338 this->SyclProgram_Make_ = SyclProgram_Make;
339
340 // dpctl.tensor.usm_ndarray API
341 this->UsmNDArray_GetData_ = UsmNDArray_GetData;
342 this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
343 this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
344 this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
345 this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
346 this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
347 this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
348 this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
349 this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
350 this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
351 this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
352 this->UsmNDArray_MakeSimpleFromMemory_ =
353 UsmNDArray_MakeSimpleFromMemory;
354 this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
355 this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
356
357 // constants
358 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
359 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
360 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
361 this->UAR_BOOL_ = UAR_BOOL;
362 this->UAR_BYTE_ = UAR_BYTE;
363 this->UAR_UBYTE_ = UAR_UBYTE;
364 this->UAR_SHORT_ = UAR_SHORT;
365 this->UAR_USHORT_ = UAR_USHORT;
366 this->UAR_INT_ = UAR_INT;
367 this->UAR_UINT_ = UAR_UINT;
368 this->UAR_LONG_ = UAR_LONG;
369 this->UAR_ULONG_ = UAR_ULONG;
370 this->UAR_LONGLONG_ = UAR_LONGLONG;
371 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
372 this->UAR_FLOAT_ = UAR_FLOAT;
373 this->UAR_DOUBLE_ = UAR_DOUBLE;
374 this->UAR_CFLOAT_ = UAR_CFLOAT;
375 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
376 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
377 this->UAR_HALF_ = UAR_HALF;
378
379 // deduced disjoint types
380 this->UAR_INT8_ = UAR_BYTE;
381 this->UAR_UINT8_ = UAR_UBYTE;
382 this->UAR_INT16_ = UAR_SHORT;
383 this->UAR_UINT16_ = UAR_USHORT;
384 this->UAR_INT32_ =
385 platform_typeid_lookup<std::int32_t, long, int, short>(
386 UAR_LONG, UAR_INT, UAR_SHORT);
387 this->UAR_UINT32_ =
388 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
389 unsigned short>(UAR_ULONG, UAR_UINT,
390 UAR_USHORT);
391 this->UAR_INT64_ =
392 platform_typeid_lookup<std::int64_t, long, long long, int>(
393 UAR_LONG, UAR_LONGLONG, UAR_INT);
394 this->UAR_UINT64_ =
395 platform_typeid_lookup<std::uint64_t, unsigned long,
396 unsigned long long, unsigned int>(
397 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
398
399 // create shared pointers to python objects used in type-casters
400 // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
401 sycl::queue q_{};
402 PySyclQueueObject *py_q_tmp =
403 SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
404 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
405 reinterpret_cast<PyObject *>(py_q_tmp));
406
407 default_sycl_queue_ = std::shared_ptr<py::object>(
408 new py::object(py_sycl_queue), Deleter{});
409
410 py::module_ mod_memory = py::module_::import("dpctl.memory");
411 const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
412 as_usm_memory_ = std::shared_ptr<py::object>(
413 new py::object{py_as_usm_memory}, Deleter{});
414
415 auto mem_kl = mod_memory.attr("MemoryUSMHost");
416 const py::object &py_default_usm_memory =
417 mem_kl(1, py::arg("queue") = py_sycl_queue);
418 default_usm_memory_ = std::shared_ptr<py::object>(
419 new py::object{py_default_usm_memory}, Deleter{});
420
421 py::module_ mod_usmarray = py::module_::import("dpnp.tensor._usmarray");
422 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
423
424 const py::object &py_default_usm_ndarray =
425 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
426 py::arg("buffer") = py_default_usm_memory);
427
428 default_usm_ndarray_ = std::shared_ptr<py::object>(
429 new py::object{py_default_usm_ndarray}, Deleter{});
430 }
431
432 dpctl_capi(dpctl_capi const &) = default;
433 dpctl_capi &operator=(dpctl_capi const &) = default;
434 dpctl_capi &operator=(dpctl_capi &&) = default;
435
436}; // struct dpctl_capi
437} // namespace detail
438} // namespace dpctl
439
440namespace pybind11::detail
441{
442#define DPCTL_TYPE_CASTER(type, py_name) \
443protected: \
444 std::unique_ptr<type> value; \
445 \
446public: \
447 static constexpr auto name = py_name; \
448 template < \
449 typename T_, \
450 ::pybind11::detail::enable_if_t< \
451 std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
452 int> = 0> \
453 static ::pybind11::handle cast(T_ *src, \
454 ::pybind11::return_value_policy policy, \
455 ::pybind11::handle parent) \
456 { \
457 if (!src) \
458 return ::pybind11::none().release(); \
459 if (policy == ::pybind11::return_value_policy::take_ownership) { \
460 auto h = cast(std::move(*src), policy, parent); \
461 delete src; \
462 return h; \
463 } \
464 return cast(*src, policy, parent); \
465 } \
466 operator type *() \
467 { \
468 return value.get(); \
469 } /* NOLINT(bugprone-macro-parentheses) */ \
470 operator type &() \
471 { \
472 return *value; \
473 } /* NOLINT(bugprone-macro-parentheses) */ \
474 operator type &&() && \
475 { \
476 return std::move(*value); \
477 } /* NOLINT(bugprone-macro-parentheses) */ \
478 template <typename T_> \
479 using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
480
481/* This type caster associates ``sycl::queue`` C++ class with
482 * :class:`dpctl.SyclQueue` for the purposes of generation of
483 * Python bindings by pybind11.
484 */
485template <>
486struct type_caster<sycl::queue>
487{
488public:
489 bool load(handle src, bool)
490 {
491 PyObject *source = src.ptr();
492 auto const &api = ::dpctl::detail::dpctl_capi::get();
493 if (api.PySyclQueue_Check_(source)) {
494 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
495 reinterpret_cast<PySyclQueueObject *>(source));
496 value = std::make_unique<sycl::queue>(
497 *(reinterpret_cast<sycl::queue *>(QRef)));
498 return true;
499 }
500 else {
501 throw py::type_error(
502 "Input is of unexpected type, expected dpctl.SyclQueue");
503 }
504 }
505
506 static handle cast(sycl::queue src, return_value_policy, handle)
507 {
508 auto const &api = ::dpctl::detail::dpctl_capi::get();
509 auto tmp =
510 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
511 return handle(reinterpret_cast<PyObject *>(tmp));
512 }
513
514 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
515};
516
517/* This type caster associates ``sycl::device`` C++ class with
518 * :class:`dpctl.SyclDevice` for the purposes of generation of
519 * Python bindings by pybind11.
520 */
521template <>
522struct type_caster<sycl::device>
523{
524public:
525 bool load(handle src, bool)
526 {
527 PyObject *source = src.ptr();
528 auto const &api = ::dpctl::detail::dpctl_capi::get();
529 if (api.PySyclDevice_Check_(source)) {
530 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
531 reinterpret_cast<PySyclDeviceObject *>(source));
532 value = std::make_unique<sycl::device>(
533 *(reinterpret_cast<sycl::device *>(DRef)));
534 return true;
535 }
536 else {
537 throw py::type_error(
538 "Input is of unexpected type, expected dpctl.SyclDevice");
539 }
540 }
541
542 static handle cast(sycl::device src, return_value_policy, handle)
543 {
544 auto const &api = ::dpctl::detail::dpctl_capi::get();
545 auto tmp =
546 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
547 return handle(reinterpret_cast<PyObject *>(tmp));
548 }
549
550 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
551};
552
553/* This type caster associates ``sycl::context`` C++ class with
554 * :class:`dpctl.SyclContext` for the purposes of generation of
555 * Python bindings by pybind11.
556 */
557template <>
558struct type_caster<sycl::context>
559{
560public:
561 bool load(handle src, bool)
562 {
563 PyObject *source = src.ptr();
564 auto const &api = ::dpctl::detail::dpctl_capi::get();
565 if (api.PySyclContext_Check_(source)) {
566 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
567 reinterpret_cast<PySyclContextObject *>(source));
568 value = std::make_unique<sycl::context>(
569 *(reinterpret_cast<sycl::context *>(CRef)));
570 return true;
571 }
572 else {
573 throw py::type_error(
574 "Input is of unexpected type, expected dpctl.SyclContext");
575 }
576 }
577
578 static handle cast(sycl::context src, return_value_policy, handle)
579 {
580 auto const &api = ::dpctl::detail::dpctl_capi::get();
581 auto tmp =
582 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
583 return handle(reinterpret_cast<PyObject *>(tmp));
584 }
585
586 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
587};
588
589/* This type caster associates ``sycl::event`` C++ class with
590 * :class:`dpctl.SyclEvent` for the purposes of generation of
591 * Python bindings by pybind11.
592 */
593template <>
594struct type_caster<sycl::event>
595{
596public:
597 bool load(handle src, bool)
598 {
599 PyObject *source = src.ptr();
600 auto const &api = ::dpctl::detail::dpctl_capi::get();
601 if (api.PySyclEvent_Check_(source)) {
602 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
603 reinterpret_cast<PySyclEventObject *>(source));
604 value = std::make_unique<sycl::event>(
605 *(reinterpret_cast<sycl::event *>(ERef)));
606 return true;
607 }
608 else {
609 throw py::type_error(
610 "Input is of unexpected type, expected dpctl.SyclEvent");
611 }
612 }
613
614 static handle cast(sycl::event src, return_value_policy, handle)
615 {
616 auto const &api = ::dpctl::detail::dpctl_capi::get();
617 auto tmp =
618 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
619 return handle(reinterpret_cast<PyObject *>(tmp));
620 }
621
622 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
623};
624
625/* This type caster associates ``sycl::kernel`` C++ class with
626 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
627 * Python bindings by pybind11.
628 */
629template <>
630struct type_caster<sycl::kernel>
631{
632public:
633 bool load(handle src, bool)
634 {
635 PyObject *source = src.ptr();
636 auto const &api = ::dpctl::detail::dpctl_capi::get();
637 if (api.PySyclKernel_Check_(source)) {
638 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
639 reinterpret_cast<PySyclKernelObject *>(source));
640 value = std::make_unique<sycl::kernel>(
641 *(reinterpret_cast<sycl::kernel *>(KRef)));
642 return true;
643 }
644 else {
645 throw py::type_error("Input is of unexpected type, expected "
646 "dpctl.program.SyclKernel");
647 }
648 }
649
650 static handle cast(sycl::kernel src, return_value_policy, handle)
651 {
652 auto const &api = ::dpctl::detail::dpctl_capi::get();
653 auto tmp =
654 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
655 "dpctl4pybind11_kernel");
656 return handle(reinterpret_cast<PyObject *>(tmp));
657 }
658
659 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
660};
661
662/* This type caster associates
663 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
664 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
665 * Python bindings by pybind11.
666 */
667template <>
668struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
669{
670public:
671 bool load(handle src, bool)
672 {
673 PyObject *source = src.ptr();
674 auto const &api = ::dpctl::detail::dpctl_capi::get();
675 if (api.PySyclProgram_Check_(source)) {
676 DPCTLSyclKernelBundleRef KBRef =
677 api.SyclProgram_GetKernelBundleRef_(
678 reinterpret_cast<PySyclProgramObject *>(source));
679 value = std::make_unique<
680 sycl::kernel_bundle<sycl::bundle_state::executable>>(
681 *(reinterpret_cast<
682 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
683 KBRef)));
684 return true;
685 }
686 else {
687 throw py::type_error("Input is of unexpected type, expected "
688 "dpctl.program.SyclProgram");
689 }
690 }
691
692 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
693 return_value_policy,
694 handle)
695 {
696 auto const &api = ::dpctl::detail::dpctl_capi::get();
697 auto tmp = api.SyclProgram_Make_(
698 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
699 return handle(reinterpret_cast<PyObject *>(tmp));
700 }
701
702 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
703 _("dpctl.program.SyclProgram"));
704};
705
706/* This type caster associates
707 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
708 * of generation of Python bindings by pybind11.
709 */
710template <>
711struct type_caster<sycl::half>
712{
713public:
714 bool load(handle src, bool convert)
715 {
716 double py_value;
717
718 if (!src) {
719 return false;
720 }
721
722 PyObject *source = src.ptr();
723
724 if (convert || PyFloat_Check(source)) {
725 py_value = PyFloat_AsDouble(source);
726 }
727 else {
728 return false;
729 }
730
731 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
732
733 if (py_err) {
734 PyErr_Clear();
735 if (convert && (PyNumber_Check(source) != 0)) {
736 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
737 return load(tmp, false);
738 }
739 return false;
740 }
741 value = static_cast<sycl::half>(py_value);
742 return true;
743 }
744
745 static handle cast(sycl::half src, return_value_policy, handle)
746 {
747 return PyFloat_FromDouble(static_cast<double>(src));
748 }
749
750 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
751};
752} // namespace pybind11::detail
753
754namespace dpctl
755{
756namespace memory
757{
758// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
759// this allows to avoid compilation error
760using pybind11::error_already_set;
761
762class usm_memory : public py::object
763{
764public:
765 PYBIND11_OBJECT_CVT(
767 py::object,
768 [](PyObject *o) -> bool {
769 return PyObject_TypeCheck(
770 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
771 0;
772 },
773 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
774
775 usm_memory()
776 : py::object(
777 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
778 borrowed_t{})
779 {
780 if (!m_ptr)
781 throw py::error_already_set();
782 }
783
787 usm_memory(void *usm_ptr,
788 std::size_t nbytes,
789 const sycl::queue &q,
790 std::shared_ptr<void> shptr)
791 {
792 auto const &api = ::dpctl::detail::dpctl_capi::get();
793 DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
794 auto q_uptr = std::make_unique<sycl::queue>(q);
795 DPCTLSyclQueueRef QRef =
796 reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());
797
798 auto vacuous_destructor = []() {};
799 py::capsule mock_owner(vacuous_destructor);
800
801 // create memory object owned by mock_owner, it is a new reference
802 PyObject *_memory =
803 api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
804 auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };
805
806 using py_uptrT =
807 std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;
808
809 if (!_memory) {
810 throw py::error_already_set();
811 }
812
813 auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
814 std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);
815
816 Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
817 // replace mock_owner capsule as the owner
818 memobj->refobj = Py_None;
819 // set opaque ptr field, usm_memory now knowns that USM is managed
820 // by smart pointer
821 memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);
822
823 // _memory will delete created copies of sycl::queue, and
824 // std::shared_ptr and the deleter of the shared_ptr<void> is
825 // supposed to free the USM allocation
826 m_ptr = _memory;
827 q_uptr.release();
828 memory_uptr.release();
829 }
830
831 sycl::queue get_queue() const
832 {
833 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
834 auto const &api = ::dpctl::detail::dpctl_capi::get();
835 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
836 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
837 return *obj_q;
838 }
839
840 char *get_pointer() const
841 {
842 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
843 auto const &api = ::dpctl::detail::dpctl_capi::get();
844 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
845 return reinterpret_cast<char *>(MRef);
846 }
847
848 std::size_t get_nbytes() const
849 {
850 auto const &api = ::dpctl::detail::dpctl_capi::get();
851 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
852 return api.Memory_GetNumBytes_(mem_obj);
853 }
854
855 bool is_managed_by_smart_ptr() const
856 {
857 auto const &api = ::dpctl::detail::dpctl_capi::get();
858 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
859 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
860
861 return bool(opaque_ptr);
862 }
863
864 const std::shared_ptr<void> &get_smart_ptr_owner() const
865 {
866 auto const &api = ::dpctl::detail::dpctl_capi::get();
867 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
868 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
869
870 if (opaque_ptr) {
871 auto shptr_ptr =
872 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
873 return *shptr_ptr;
874 }
875 else {
876 throw std::runtime_error(
877 "Memory object does not have smart pointer "
878 "managing lifetime of USM allocation");
879 }
880 }
881
882protected:
883 static PyObject *as_usm_memory(PyObject *o)
884 {
885 if (o == nullptr) {
886 PyErr_SetString(PyExc_ValueError,
887 "cannot create a usm_memory from a nullptr");
888 return nullptr;
889 }
890
891 auto converter =
892 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
893
894 py::object res;
895 try {
896 res = converter(py::handle(o));
897 } catch (const py::error_already_set &e) {
898 return nullptr;
899 }
900 return res.ptr();
901 }
902};
903} // end namespace memory
904
905namespace tensor
906{
907inline std::vector<py::ssize_t>
908 c_contiguous_strides(int nd,
909 const py::ssize_t *shape,
910 py::ssize_t element_size = 1)
911{
912 if (nd > 0) {
913 std::vector<py::ssize_t> c_strides(nd, element_size);
914 for (int ic = nd - 1; ic > 0;) {
915 py::ssize_t next_v = c_strides[ic] * shape[ic];
916 c_strides[--ic] = next_v;
917 }
918 return c_strides;
919 }
920 else {
921 return std::vector<py::ssize_t>();
922 }
923}
924
925inline std::vector<py::ssize_t>
926 f_contiguous_strides(int nd,
927 const py::ssize_t *shape,
928 py::ssize_t element_size = 1)
929{
930 if (nd > 0) {
931 std::vector<py::ssize_t> f_strides(nd, element_size);
932 for (int i = 0; i < nd - 1;) {
933 py::ssize_t next_v = f_strides[i] * shape[i];
934 f_strides[++i] = next_v;
935 }
936 return f_strides;
937 }
938 else {
939 return std::vector<py::ssize_t>();
940 }
941}
942
943inline std::vector<py::ssize_t>
944 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
945 py::ssize_t element_size = 1)
946{
947 return c_contiguous_strides(shape.size(), shape.data(), element_size);
948}
949
950inline std::vector<py::ssize_t>
951 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
952 py::ssize_t element_size = 1)
953{
954 return f_contiguous_strides(shape.size(), shape.data(), element_size);
955}
956
957class usm_ndarray : public py::object
958{
959public:
960 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
961 return PyObject_TypeCheck(
962 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
963 })
964
966 : py::object(
967 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
968 borrowed_t{})
969 {
970 if (!m_ptr)
971 throw py::error_already_set();
972 }
973
974 char *get_data() const
975 {
976 PyUSMArrayObject *raw_ar = usm_array_ptr();
977
978 auto const &api = ::dpctl::detail::dpctl_capi::get();
979 return api.UsmNDArray_GetData_(raw_ar);
980 }
981
982 template <typename T>
983 T *get_data() const
984 {
985 return reinterpret_cast<T *>(get_data());
986 }
987
988 int get_ndim() const
989 {
990 PyUSMArrayObject *raw_ar = usm_array_ptr();
991
992 auto const &api = ::dpctl::detail::dpctl_capi::get();
993 return api.UsmNDArray_GetNDim_(raw_ar);
994 }
995
996 const py::ssize_t *get_shape_raw() const
997 {
998 PyUSMArrayObject *raw_ar = usm_array_ptr();
999
1000 auto const &api = ::dpctl::detail::dpctl_capi::get();
1001 return api.UsmNDArray_GetShape_(raw_ar);
1002 }
1003
1004 std::vector<py::ssize_t> get_shape_vector() const
1005 {
1006 auto raw_sh = get_shape_raw();
1007 auto nd = get_ndim();
1008
1009 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
1010 return shape_vector;
1011 }
1012
1013 py::ssize_t get_shape(int i) const
1014 {
1015 auto shape_ptr = get_shape_raw();
1016 return shape_ptr[i];
1017 }
1018
1019 const py::ssize_t *get_strides_raw() const
1020 {
1021 PyUSMArrayObject *raw_ar = usm_array_ptr();
1022
1023 auto const &api = ::dpctl::detail::dpctl_capi::get();
1024 return api.UsmNDArray_GetStrides_(raw_ar);
1025 }
1026
1027 std::vector<py::ssize_t> get_strides_vector() const
1028 {
1029 auto raw_st = get_strides_raw();
1030 auto nd = get_ndim();
1031
1032 if (raw_st == nullptr) {
1033 auto is_c_contig = is_c_contiguous();
1034 auto is_f_contig = is_f_contiguous();
1035 auto raw_sh = get_shape_raw();
1036 if (is_c_contig) {
1037 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
1038 return contig_strides;
1039 }
1040 else if (is_f_contig) {
1041 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
1042 return contig_strides;
1043 }
1044 else {
1045 throw std::runtime_error("Invalid array encountered when "
1046 "building strides");
1047 }
1048 }
1049 else {
1050 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
1051 return st_vec;
1052 }
1053 }
1054
1055 py::ssize_t get_size() const
1056 {
1057 PyUSMArrayObject *raw_ar = usm_array_ptr();
1058
1059 auto const &api = ::dpctl::detail::dpctl_capi::get();
1060 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
1061 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1062
1063 py::ssize_t nelems = 1;
1064 for (int i = 0; i < ndim; ++i) {
1065 nelems *= shape[i];
1066 }
1067
1068 assert(nelems >= 0);
1069 return nelems;
1070 }
1071
1072 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
1073 {
1074 PyUSMArrayObject *raw_ar = usm_array_ptr();
1075
1076 auto const &api = ::dpctl::detail::dpctl_capi::get();
1077 int nd = api.UsmNDArray_GetNDim_(raw_ar);
1078 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1079 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
1080
1081 py::ssize_t offset_min = 0;
1082 py::ssize_t offset_max = 0;
1083 if (strides == nullptr) {
1084 py::ssize_t stride(1);
1085 for (int i = 0; i < nd; ++i) {
1086 offset_max += stride * (shape[i] - 1);
1087 stride *= shape[i];
1088 }
1089 }
1090 else {
1091 for (int i = 0; i < nd; ++i) {
1092 py::ssize_t delta = strides[i] * (shape[i] - 1);
1093 if (strides[i] > 0) {
1094 offset_max += delta;
1095 }
1096 else {
1097 offset_min += delta;
1098 }
1099 }
1100 }
1101 return std::make_pair(offset_min, offset_max);
1102 }
1103
1104 sycl::queue get_queue() const
1105 {
1106 PyUSMArrayObject *raw_ar = usm_array_ptr();
1107
1108 auto const &api = ::dpctl::detail::dpctl_capi::get();
1109 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1110 return *(reinterpret_cast<sycl::queue *>(QRef));
1111 }
1112
1113 sycl::device get_device() const
1114 {
1115 PyUSMArrayObject *raw_ar = usm_array_ptr();
1116
1117 auto const &api = ::dpctl::detail::dpctl_capi::get();
1118 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1119 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1120 }
1121
1122 int get_typenum() const
1123 {
1124 PyUSMArrayObject *raw_ar = usm_array_ptr();
1125
1126 auto const &api = ::dpctl::detail::dpctl_capi::get();
1127 return api.UsmNDArray_GetTypenum_(raw_ar);
1128 }
1129
1130 int get_flags() const
1131 {
1132 PyUSMArrayObject *raw_ar = usm_array_ptr();
1133
1134 auto const &api = ::dpctl::detail::dpctl_capi::get();
1135 return api.UsmNDArray_GetFlags_(raw_ar);
1136 }
1137
1138 int get_elemsize() const
1139 {
1140 PyUSMArrayObject *raw_ar = usm_array_ptr();
1141
1142 auto const &api = ::dpctl::detail::dpctl_capi::get();
1143 return api.UsmNDArray_GetElementSize_(raw_ar);
1144 }
1145
1146 bool is_c_contiguous() const
1147 {
1148 int flags = get_flags();
1149 auto const &api = ::dpctl::detail::dpctl_capi::get();
1150 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1151 }
1152
1153 bool is_f_contiguous() const
1154 {
1155 int flags = get_flags();
1156 auto const &api = ::dpctl::detail::dpctl_capi::get();
1157 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1158 }
1159
1160 bool is_writable() const
1161 {
1162 int flags = get_flags();
1163 auto const &api = ::dpctl::detail::dpctl_capi::get();
1164 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1165 }
1166
1168 py::object get_usm_data() const
1169 {
1170 PyUSMArrayObject *raw_ar = usm_array_ptr();
1171
1172 auto const &api = ::dpctl::detail::dpctl_capi::get();
1173 // UsmNDArray_GetUSMData_ gives a new reference
1174 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1175
1176 // pass reference ownership to py::object
1177 return py::reinterpret_steal<py::object>(usm_data);
1178 }
1179
1180 bool is_managed_by_smart_ptr() const
1181 {
1182 PyUSMArrayObject *raw_ar = usm_array_ptr();
1183
1184 auto const &api = ::dpctl::detail::dpctl_capi::get();
1185 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1186
1187 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1188 Py_DECREF(usm_data);
1189 return false;
1190 }
1191
1192 Py_MemoryObject *mem_obj =
1193 reinterpret_cast<Py_MemoryObject *>(usm_data);
1194 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1195
1196 Py_DECREF(usm_data);
1197 return bool(opaque_ptr);
1198 }
1199
1200 const std::shared_ptr<void> &get_smart_ptr_owner() const
1201 {
1202 PyUSMArrayObject *raw_ar = usm_array_ptr();
1203
1204 auto const &api = ::dpctl::detail::dpctl_capi::get();
1205
1206 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1207
1208 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1209 Py_DECREF(usm_data);
1210 throw std::runtime_error(
1211 "usm_ndarray object does not have Memory object "
1212 "managing lifetime of USM allocation");
1213 }
1214
1215 Py_MemoryObject *mem_obj =
1216 reinterpret_cast<Py_MemoryObject *>(usm_data);
1217 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1218 Py_DECREF(usm_data);
1219
1220 if (opaque_ptr) {
1221 auto shptr_ptr =
1222 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1223 return *shptr_ptr;
1224 }
1225 else {
1226 throw std::runtime_error(
1227 "Memory object underlying usm_ndarray does not have "
1228 "smart pointer managing lifetime of USM allocation");
1229 }
1230 }
1231
1232private:
1233 PyUSMArrayObject *usm_array_ptr() const
1234 {
1235 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1236 }
1237};
1238} // end namespace tensor
1239
1240namespace utils
1241{
1242namespace detail
1243{
1245{
1246
1247 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1248 {
1249 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1250 const auto &usm_memory_inst =
1251 py::cast<dpctl::memory::usm_memory>(h);
1252 return usm_memory_inst.is_managed_by_smart_ptr();
1253 }
1254 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1255 const auto &usm_array_inst =
1256 py::cast<dpctl::tensor::usm_ndarray>(h);
1257 return usm_array_inst.is_managed_by_smart_ptr();
1258 }
1259
1260 return false;
1261 }
1262
1263 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1264 {
1265 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1266 const auto &usm_memory_inst =
1267 py::cast<dpctl::memory::usm_memory>(h);
1268 return usm_memory_inst.get_smart_ptr_owner();
1269 }
1270 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1271 const auto &usm_array_inst =
1272 py::cast<dpctl::tensor::usm_ndarray>(h);
1273 return usm_array_inst.get_smart_ptr_owner();
1274 }
1275
1276 throw std::runtime_error(
1277 "Attempted extraction of shared_ptr on an unrecognized type");
1278 }
1279};
1280} // end of namespace detail
1281
1282template <std::size_t num>
1283sycl::event keep_args_alive(sycl::queue &q,
1284 const py::object (&py_objs)[num],
1285 const std::vector<sycl::event> &depends = {})
1286{
1287 std::size_t n_objects_held = 0;
1288 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1289
1290 std::size_t n_usm_owners_held = 0;
1291 std::array<std::shared_ptr<void>, num> shp_usm{};
1292
1293 for (std::size_t i = 0; i < num; ++i) {
1294 const auto &py_obj_i = py_objs[i];
1295 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1296 const auto &shp =
1297 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1298 shp_usm[n_usm_owners_held] = shp;
1299 ++n_usm_owners_held;
1300 }
1301 else {
1302 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1303 shp_arr[n_objects_held]->inc_ref();
1304 ++n_objects_held;
1305 }
1306 }
1307
1308 bool use_depends = true;
1309 sycl::event host_task_ev;
1310
1311 if (n_usm_owners_held > 0) {
1312 host_task_ev = q.submit([&](sycl::handler &cgh) {
1313 if (use_depends) {
1314 cgh.depends_on(depends);
1315 use_depends = false;
1316 }
1317 else {
1318 cgh.depends_on(host_task_ev);
1319 }
1320 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1321 // no body, but shared pointers are captured in
1322 // the lambda, ensuring that USM allocation is
1323 // kept alive
1324 });
1325 });
1326 }
1327
1328 if (n_objects_held > 0) {
1329 host_task_ev = q.submit([&](sycl::handler &cgh) {
1330 if (use_depends) {
1331 cgh.depends_on(depends);
1332 use_depends = false;
1333 }
1334 else {
1335 cgh.depends_on(host_task_ev);
1336 }
1337 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1338 py::gil_scoped_acquire acquire;
1339
1340 for (std::size_t i = 0; i < n_objects_held; ++i) {
1341 shp_arr[i]->dec_ref();
1342 }
1343 });
1344 });
1345 }
1346
1347 return host_task_ev;
1348}
1349
1352template <std::size_t num>
1353bool queues_are_compatible(const sycl::queue &exec_q,
1354 const sycl::queue (&alloc_qs)[num])
1355{
1356 for (std::size_t i = 0; i < num; ++i) {
1357
1358 if (exec_q != alloc_qs[i]) {
1359 return false;
1360 }
1361 }
1362 return true;
1363}
1364
1367template <std::size_t num>
1368bool queues_are_compatible(const sycl::queue &exec_q,
1369 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1370{
1371 for (std::size_t i = 0; i < num; ++i) {
1372
1373 if (exec_q != arrs[i].get_queue()) {
1374 return false;
1375 }
1376 }
1377 return true;
1378}
1379} // end namespace utils
1380} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.