DPNP C++ backend kernel library 0.20.0dev3
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// Include dpctl_ext C-API (provides unified access to both dpctl and dpctl_ext)
32// This includes:
33// - dpctl C-API (from external dpctl package - SYCL interface)
34// - dpctl_ext C-API (tensor interface: usm_ndarray)
35//
36// TODO: When dpctl_ext is renamed to dpctl.tensor:
37// - Update include: "dpctl_ext_capi.h" → "dpctl/tensor/tensor_capi.h"
38// (Use tensor_capi.h, NOT dpctl_capi.h, to avoid conflict with external
39// dpctl)
40// - Update import calls: import_dpctl_ext() → import_dpctl_tensor()
41#include "dpctl_ext_capi.h"
42
43#include <array>
44#include <cassert>
45#include <complex>
46#include <cstddef> // for std::size_t for C++ linkage
47#include <cstdint>
48#include <memory>
49#include <stddef.h> // for size_t for C linkage
50#include <stdexcept>
51#include <type_traits>
52#include <utility>
53#include <vector>
54
55#include <pybind11/pybind11.h>
56
57#include <sycl/sycl.hpp>
58
59namespace py = pybind11;
60
61namespace dpctl
62{
63namespace detail
64{
65// Lookup a type according to its size, and return a value corresponding to the
66// NumPy typenum.
67template <typename Concrete>
68constexpr int platform_typeid_lookup()
69{
70 return -1;
71}
72
73template <typename Concrete, typename T, typename... Ts, typename... Ints>
74constexpr int platform_typeid_lookup(int I, Ints... Is)
75{
76 return sizeof(Concrete) == sizeof(T)
77 ? I
78 : platform_typeid_lookup<Concrete, Ts...>(Is...);
79}
80
82{
83public:
84 // dpctl type objects
85 PyTypeObject *Py_SyclDeviceType_;
86 PyTypeObject *PySyclDeviceType_;
87 PyTypeObject *Py_SyclContextType_;
88 PyTypeObject *PySyclContextType_;
89 PyTypeObject *Py_SyclEventType_;
90 PyTypeObject *PySyclEventType_;
91 PyTypeObject *Py_SyclQueueType_;
92 PyTypeObject *PySyclQueueType_;
93 PyTypeObject *Py_MemoryType_;
94 PyTypeObject *PyMemoryUSMDeviceType_;
95 PyTypeObject *PyMemoryUSMSharedType_;
96 PyTypeObject *PyMemoryUSMHostType_;
97 PyTypeObject *PyUSMArrayType_;
98 PyTypeObject *PySyclProgramType_;
99 PyTypeObject *PySyclKernelType_;
100
101 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
102 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
103
104 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
105 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
106
107 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
108 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
109
110 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
111 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
112
113 // memory
114 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
115 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
116 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
117 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
118 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
119 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
120 size_t,
121 DPCTLSyclQueueRef,
122 PyObject *);
123
124 // program
125 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
126 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);
127
128 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
129 PySyclProgramObject *);
130 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
131
132 // tensor
133 char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
134 int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
135 py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
136 py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
137 int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
138 int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
139 int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
140 DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
141 py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
142 PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
143 void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
144 PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
145 const py::ssize_t *,
146 int,
147 Py_MemoryObject *,
148 py::ssize_t,
149 char);
150 PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
151 int,
152 DPCTLSyclUSMRef,
153 DPCTLSyclQueueRef,
154 PyObject *);
155 PyObject *(*UsmNDArray_MakeFromPtr_)(int,
156 const py::ssize_t *,
157 int,
158 const py::ssize_t *,
159 DPCTLSyclUSMRef,
160 DPCTLSyclQueueRef,
161 py::ssize_t,
162 PyObject *);
163
164 int USM_ARRAY_C_CONTIGUOUS_;
165 int USM_ARRAY_F_CONTIGUOUS_;
166 int USM_ARRAY_WRITABLE_;
167 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
168 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
169 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
170 UAR_HALF_;
171 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
172 UAR_INT64_, UAR_UINT64_;
173
174 bool PySyclDevice_Check_(PyObject *obj) const
175 {
176 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
177 }
178 bool PySyclContext_Check_(PyObject *obj) const
179 {
180 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
181 }
182 bool PySyclEvent_Check_(PyObject *obj) const
183 {
184 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
185 }
186 bool PySyclQueue_Check_(PyObject *obj) const
187 {
188 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
189 }
190 bool PySyclKernel_Check_(PyObject *obj) const
191 {
192 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
193 }
194 bool PySyclProgram_Check_(PyObject *obj) const
195 {
196 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
197 }
198
200 {
201 as_usm_memory_.reset();
202 default_usm_ndarray_.reset();
203 default_usm_memory_.reset();
204 default_sycl_queue_.reset();
205 };
206
207 static auto &get()
208 {
209 static dpctl_capi api{};
210 return api;
211 }
212
213 py::object default_sycl_queue_pyobj()
214 {
215 return *default_sycl_queue_;
216 }
217 py::object default_usm_memory_pyobj()
218 {
219 return *default_usm_memory_;
220 }
221 py::object default_usm_ndarray_pyobj()
222 {
223 return *default_usm_ndarray_;
224 }
225 py::object as_usm_memory_pyobj()
226 {
227 return *as_usm_memory_;
228 }
229
230private:
231 struct Deleter
232 {
233 void operator()(py::object *p) const
234 {
235 const bool initialized = Py_IsInitialized();
236#if PY_VERSION_HEX < 0x30d0000
237 const bool finalizing = _Py_IsFinalizing();
238#else
239 const bool finalizing = Py_IsFinalizing();
240#endif
241 const bool guard = initialized && !finalizing;
242
243 if (guard) {
244 delete p;
245 }
246 }
247 };
248
249 std::shared_ptr<py::object> default_sycl_queue_;
250 std::shared_ptr<py::object> default_usm_memory_;
251 std::shared_ptr<py::object> default_usm_ndarray_;
252 std::shared_ptr<py::object> as_usm_memory_;
253
254 dpctl_capi()
255 : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
256 Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
257 Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
258 Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
259 Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
260 PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
261 PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
262 PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
263 SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
264 SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
265 SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
266 SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
267 Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
268 Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
269 Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
270 SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
271 SyclProgram_Make_(nullptr), UsmNDArray_GetData_(nullptr),
272 UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
273 UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
274 UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
275 UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
276 UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
277 UsmNDArray_MakeSimpleFromMemory_(nullptr),
278 UsmNDArray_MakeSimpleFromPtr_(nullptr),
279 UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
280 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
281 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
282 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
283 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
284 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
285 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
286 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
287 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
288 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
289
290 {
291 // Import Cython-generated C-API for dpctl
292 // This imports python modules and initializes
293 // static variables such as function pointers for C-API,
294 // e.g. SyclDevice_GetDeviceRef, etc.
295 // pointers to Python types, i.e. PySyclDeviceType, etc.
296 // and exported constants, i.e. USM_ARRAY_C_CONTIGUOUS, etc.
297 // TODO: rename once dpctl_ext is renamed
298 import_dpctl_ext(); // Imports both dpctl and dpctl_ext C-APIs
299
300 // Python type objects for classes implemented by dpctl
301 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
302 this->PySyclDeviceType_ = &PySyclDeviceType;
303 this->Py_SyclContextType_ = &Py_SyclContextType;
304 this->PySyclContextType_ = &PySyclContextType;
305 this->Py_SyclEventType_ = &Py_SyclEventType;
306 this->PySyclEventType_ = &PySyclEventType;
307 this->Py_SyclQueueType_ = &Py_SyclQueueType;
308 this->PySyclQueueType_ = &PySyclQueueType;
309 this->Py_MemoryType_ = &Py_MemoryType;
310 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
311 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
312 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
313 this->PyUSMArrayType_ = &PyUSMArrayType;
314 this->PySyclProgramType_ = &PySyclProgramType;
315 this->PySyclKernelType_ = &PySyclKernelType;
316
317 // SyclDevice API
318 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
319 this->SyclDevice_Make_ = SyclDevice_Make;
320
321 // SyclContext API
322 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
323 this->SyclContext_Make_ = SyclContext_Make;
324
325 // SyclEvent API
326 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
327 this->SyclEvent_Make_ = SyclEvent_Make;
328
329 // SyclQueue API
330 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
331 this->SyclQueue_Make_ = SyclQueue_Make;
332
333 // dpctl.memory API
334 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
335 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
336 this->Memory_GetContextRef_ = Memory_GetContextRef;
337 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
338 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
339 this->Memory_Make_ = Memory_Make;
340
341 // dpctl.program API
342 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
343 this->SyclKernel_Make_ = SyclKernel_Make;
344 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
345 this->SyclProgram_Make_ = SyclProgram_Make;
346
347 // dpctl.tensor.usm_ndarray API
348 this->UsmNDArray_GetData_ = UsmNDArray_GetData;
349 this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
350 this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
351 this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
352 this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
353 this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
354 this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
355 this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
356 this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
357 this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
358 this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
359 this->UsmNDArray_MakeSimpleFromMemory_ =
360 UsmNDArray_MakeSimpleFromMemory;
361 this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
362 this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
363
364 // constants
365 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
366 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
367 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
368 this->UAR_BOOL_ = UAR_BOOL;
369 this->UAR_BYTE_ = UAR_BYTE;
370 this->UAR_UBYTE_ = UAR_UBYTE;
371 this->UAR_SHORT_ = UAR_SHORT;
372 this->UAR_USHORT_ = UAR_USHORT;
373 this->UAR_INT_ = UAR_INT;
374 this->UAR_UINT_ = UAR_UINT;
375 this->UAR_LONG_ = UAR_LONG;
376 this->UAR_ULONG_ = UAR_ULONG;
377 this->UAR_LONGLONG_ = UAR_LONGLONG;
378 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
379 this->UAR_FLOAT_ = UAR_FLOAT;
380 this->UAR_DOUBLE_ = UAR_DOUBLE;
381 this->UAR_CFLOAT_ = UAR_CFLOAT;
382 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
383 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
384 this->UAR_HALF_ = UAR_HALF;
385
386 // deduced disjoint types
387 this->UAR_INT8_ = UAR_BYTE;
388 this->UAR_UINT8_ = UAR_UBYTE;
389 this->UAR_INT16_ = UAR_SHORT;
390 this->UAR_UINT16_ = UAR_USHORT;
391 this->UAR_INT32_ =
392 platform_typeid_lookup<std::int32_t, long, int, short>(
393 UAR_LONG, UAR_INT, UAR_SHORT);
394 this->UAR_UINT32_ =
395 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
396 unsigned short>(UAR_ULONG, UAR_UINT,
397 UAR_USHORT);
398 this->UAR_INT64_ =
399 platform_typeid_lookup<std::int64_t, long, long long, int>(
400 UAR_LONG, UAR_LONGLONG, UAR_INT);
401 this->UAR_UINT64_ =
402 platform_typeid_lookup<std::uint64_t, unsigned long,
403 unsigned long long, unsigned int>(
404 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
405
406 // create shared pointers to python objects used in type-casters
407 // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
408 sycl::queue q_{};
409 PySyclQueueObject *py_q_tmp =
410 SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
411 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
412 reinterpret_cast<PyObject *>(py_q_tmp));
413
414 default_sycl_queue_ = std::shared_ptr<py::object>(
415 new py::object(py_sycl_queue), Deleter{});
416
417 py::module_ mod_memory = py::module_::import("dpctl.memory");
418 const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
419 as_usm_memory_ = std::shared_ptr<py::object>(
420 new py::object{py_as_usm_memory}, Deleter{});
421
422 auto mem_kl = mod_memory.attr("MemoryUSMHost");
423 const py::object &py_default_usm_memory =
424 mem_kl(1, py::arg("queue") = py_sycl_queue);
425 default_usm_memory_ = std::shared_ptr<py::object>(
426 new py::object{py_default_usm_memory}, Deleter{});
427
428 // TODO: revert to `py::module_::import("dpctl.tensor._usmarray");`
429 // when dpnp fully migrates dpctl/tensor
430 py::module_ mod_usmarray =
431 py::module_::import("dpctl_ext.tensor._usmarray");
432 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
433
434 const py::object &py_default_usm_ndarray =
435 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
436 py::arg("buffer") = py_default_usm_memory);
437
438 default_usm_ndarray_ = std::shared_ptr<py::object>(
439 new py::object{py_default_usm_ndarray}, Deleter{});
440 }
441
442 dpctl_capi(dpctl_capi const &) = default;
443 dpctl_capi &operator=(dpctl_capi const &) = default;
444 dpctl_capi &operator=(dpctl_capi &&) = default;
445
446}; // struct dpctl_capi
447} // namespace detail
448} // namespace dpctl
449
450namespace pybind11::detail
451{
452#define DPCTL_TYPE_CASTER(type, py_name) \
453protected: \
454 std::unique_ptr<type> value; \
455 \
456public: \
457 static constexpr auto name = py_name; \
458 template < \
459 typename T_, \
460 ::pybind11::detail::enable_if_t< \
461 std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
462 int> = 0> \
463 static ::pybind11::handle cast(T_ *src, \
464 ::pybind11::return_value_policy policy, \
465 ::pybind11::handle parent) \
466 { \
467 if (!src) \
468 return ::pybind11::none().release(); \
469 if (policy == ::pybind11::return_value_policy::take_ownership) { \
470 auto h = cast(std::move(*src), policy, parent); \
471 delete src; \
472 return h; \
473 } \
474 return cast(*src, policy, parent); \
475 } \
476 operator type *() \
477 { \
478 return value.get(); \
479 } /* NOLINT(bugprone-macro-parentheses) */ \
480 operator type &() \
481 { \
482 return *value; \
483 } /* NOLINT(bugprone-macro-parentheses) */ \
484 operator type &&() && \
485 { \
486 return std::move(*value); \
487 } /* NOLINT(bugprone-macro-parentheses) */ \
488 template <typename T_> \
489 using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
490
491/* This type caster associates ``sycl::queue`` C++ class with
492 * :class:`dpctl.SyclQueue` for the purposes of generation of
493 * Python bindings by pybind11.
494 */
495template <>
496struct type_caster<sycl::queue>
497{
498public:
499 bool load(handle src, bool)
500 {
501 PyObject *source = src.ptr();
502 auto const &api = ::dpctl::detail::dpctl_capi::get();
503 if (api.PySyclQueue_Check_(source)) {
504 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
505 reinterpret_cast<PySyclQueueObject *>(source));
506 value = std::make_unique<sycl::queue>(
507 *(reinterpret_cast<sycl::queue *>(QRef)));
508 return true;
509 }
510 else {
511 throw py::type_error(
512 "Input is of unexpected type, expected dpctl.SyclQueue");
513 }
514 }
515
516 static handle cast(sycl::queue src, return_value_policy, handle)
517 {
518 auto const &api = ::dpctl::detail::dpctl_capi::get();
519 auto tmp =
520 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
521 return handle(reinterpret_cast<PyObject *>(tmp));
522 }
523
524 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
525};
526
527/* This type caster associates ``sycl::device`` C++ class with
528 * :class:`dpctl.SyclDevice` for the purposes of generation of
529 * Python bindings by pybind11.
530 */
531template <>
532struct type_caster<sycl::device>
533{
534public:
535 bool load(handle src, bool)
536 {
537 PyObject *source = src.ptr();
538 auto const &api = ::dpctl::detail::dpctl_capi::get();
539 if (api.PySyclDevice_Check_(source)) {
540 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
541 reinterpret_cast<PySyclDeviceObject *>(source));
542 value = std::make_unique<sycl::device>(
543 *(reinterpret_cast<sycl::device *>(DRef)));
544 return true;
545 }
546 else {
547 throw py::type_error(
548 "Input is of unexpected type, expected dpctl.SyclDevice");
549 }
550 }
551
552 static handle cast(sycl::device src, return_value_policy, handle)
553 {
554 auto const &api = ::dpctl::detail::dpctl_capi::get();
555 auto tmp =
556 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
557 return handle(reinterpret_cast<PyObject *>(tmp));
558 }
559
560 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
561};
562
563/* This type caster associates ``sycl::context`` C++ class with
564 * :class:`dpctl.SyclContext` for the purposes of generation of
565 * Python bindings by pybind11.
566 */
567template <>
568struct type_caster<sycl::context>
569{
570public:
571 bool load(handle src, bool)
572 {
573 PyObject *source = src.ptr();
574 auto const &api = ::dpctl::detail::dpctl_capi::get();
575 if (api.PySyclContext_Check_(source)) {
576 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
577 reinterpret_cast<PySyclContextObject *>(source));
578 value = std::make_unique<sycl::context>(
579 *(reinterpret_cast<sycl::context *>(CRef)));
580 return true;
581 }
582 else {
583 throw py::type_error(
584 "Input is of unexpected type, expected dpctl.SyclContext");
585 }
586 }
587
588 static handle cast(sycl::context src, return_value_policy, handle)
589 {
590 auto const &api = ::dpctl::detail::dpctl_capi::get();
591 auto tmp =
592 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
593 return handle(reinterpret_cast<PyObject *>(tmp));
594 }
595
596 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
597};
598
599/* This type caster associates ``sycl::event`` C++ class with
600 * :class:`dpctl.SyclEvent` for the purposes of generation of
601 * Python bindings by pybind11.
602 */
603template <>
604struct type_caster<sycl::event>
605{
606public:
607 bool load(handle src, bool)
608 {
609 PyObject *source = src.ptr();
610 auto const &api = ::dpctl::detail::dpctl_capi::get();
611 if (api.PySyclEvent_Check_(source)) {
612 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
613 reinterpret_cast<PySyclEventObject *>(source));
614 value = std::make_unique<sycl::event>(
615 *(reinterpret_cast<sycl::event *>(ERef)));
616 return true;
617 }
618 else {
619 throw py::type_error(
620 "Input is of unexpected type, expected dpctl.SyclEvent");
621 }
622 }
623
624 static handle cast(sycl::event src, return_value_policy, handle)
625 {
626 auto const &api = ::dpctl::detail::dpctl_capi::get();
627 auto tmp =
628 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
629 return handle(reinterpret_cast<PyObject *>(tmp));
630 }
631
632 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
633};
634
635/* This type caster associates ``sycl::kernel`` C++ class with
636 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
637 * Python bindings by pybind11.
638 */
639template <>
640struct type_caster<sycl::kernel>
641{
642public:
643 bool load(handle src, bool)
644 {
645 PyObject *source = src.ptr();
646 auto const &api = ::dpctl::detail::dpctl_capi::get();
647 if (api.PySyclKernel_Check_(source)) {
648 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
649 reinterpret_cast<PySyclKernelObject *>(source));
650 value = std::make_unique<sycl::kernel>(
651 *(reinterpret_cast<sycl::kernel *>(KRef)));
652 return true;
653 }
654 else {
655 throw py::type_error("Input is of unexpected type, expected "
656 "dpctl.program.SyclKernel");
657 }
658 }
659
660 static handle cast(sycl::kernel src, return_value_policy, handle)
661 {
662 auto const &api = ::dpctl::detail::dpctl_capi::get();
663 auto tmp =
664 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
665 "dpctl4pybind11_kernel");
666 return handle(reinterpret_cast<PyObject *>(tmp));
667 }
668
669 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
670};
671
672/* This type caster associates
673 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
674 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
675 * Python bindings by pybind11.
676 */
677template <>
678struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
679{
680public:
681 bool load(handle src, bool)
682 {
683 PyObject *source = src.ptr();
684 auto const &api = ::dpctl::detail::dpctl_capi::get();
685 if (api.PySyclProgram_Check_(source)) {
686 DPCTLSyclKernelBundleRef KBRef =
687 api.SyclProgram_GetKernelBundleRef_(
688 reinterpret_cast<PySyclProgramObject *>(source));
689 value = std::make_unique<
690 sycl::kernel_bundle<sycl::bundle_state::executable>>(
691 *(reinterpret_cast<
692 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
693 KBRef)));
694 return true;
695 }
696 else {
697 throw py::type_error("Input is of unexpected type, expected "
698 "dpctl.program.SyclProgram");
699 }
700 }
701
702 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
703 return_value_policy,
704 handle)
705 {
706 auto const &api = ::dpctl::detail::dpctl_capi::get();
707 auto tmp = api.SyclProgram_Make_(
708 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
709 return handle(reinterpret_cast<PyObject *>(tmp));
710 }
711
712 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
713 _("dpctl.program.SyclProgram"));
714};
715
716/* This type caster associates
717 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
718 * of generation of Python bindings by pybind11.
719 */
720template <>
721struct type_caster<sycl::half>
722{
723public:
724 bool load(handle src, bool convert)
725 {
726 double py_value;
727
728 if (!src) {
729 return false;
730 }
731
732 PyObject *source = src.ptr();
733
734 if (convert || PyFloat_Check(source)) {
735 py_value = PyFloat_AsDouble(source);
736 }
737 else {
738 return false;
739 }
740
741 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
742
743 if (py_err) {
744 PyErr_Clear();
745 if (convert && (PyNumber_Check(source) != 0)) {
746 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
747 return load(tmp, false);
748 }
749 return false;
750 }
751 value = static_cast<sycl::half>(py_value);
752 return true;
753 }
754
755 static handle cast(sycl::half src, return_value_policy, handle)
756 {
757 return PyFloat_FromDouble(static_cast<double>(src));
758 }
759
760 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
761};
762} // namespace pybind11::detail
763
764namespace dpctl
765{
766namespace memory
767{
768// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
769// this allows to avoid compilation error
770using pybind11::error_already_set;
771
772class usm_memory : public py::object
773{
774public:
775 PYBIND11_OBJECT_CVT(
777 py::object,
778 [](PyObject *o) -> bool {
779 return PyObject_TypeCheck(
780 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
781 0;
782 },
783 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
784
785 usm_memory()
786 : py::object(
787 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
788 borrowed_t{})
789 {
790 if (!m_ptr)
791 throw py::error_already_set();
792 }
793
797 usm_memory(void *usm_ptr,
798 std::size_t nbytes,
799 const sycl::queue &q,
800 std::shared_ptr<void> shptr)
801 {
802 auto const &api = ::dpctl::detail::dpctl_capi::get();
803 DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
804 auto q_uptr = std::make_unique<sycl::queue>(q);
805 DPCTLSyclQueueRef QRef =
806 reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());
807
808 auto vacuous_destructor = []() {};
809 py::capsule mock_owner(vacuous_destructor);
810
811 // create memory object owned by mock_owner, it is a new reference
812 PyObject *_memory =
813 api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
814 auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };
815
816 using py_uptrT =
817 std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;
818
819 if (!_memory) {
820 throw py::error_already_set();
821 }
822
823 auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
824 std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);
825
826 Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
827 // replace mock_owner capsule as the owner
828 memobj->refobj = Py_None;
829 // set opaque ptr field, usm_memory now knowns that USM is managed
830 // by smart pointer
831 memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);
832
833 // _memory will delete created copies of sycl::queue, and
834 // std::shared_ptr and the deleter of the shared_ptr<void> is
835 // supposed to free the USM allocation
836 m_ptr = _memory;
837 q_uptr.release();
838 memory_uptr.release();
839 }
840
841 sycl::queue get_queue() const
842 {
843 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
844 auto const &api = ::dpctl::detail::dpctl_capi::get();
845 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
846 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
847 return *obj_q;
848 }
849
850 char *get_pointer() const
851 {
852 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
853 auto const &api = ::dpctl::detail::dpctl_capi::get();
854 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
855 return reinterpret_cast<char *>(MRef);
856 }
857
858 std::size_t get_nbytes() const
859 {
860 auto const &api = ::dpctl::detail::dpctl_capi::get();
861 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
862 return api.Memory_GetNumBytes_(mem_obj);
863 }
864
865 bool is_managed_by_smart_ptr() const
866 {
867 auto const &api = ::dpctl::detail::dpctl_capi::get();
868 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
869 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
870
871 return bool(opaque_ptr);
872 }
873
874 const std::shared_ptr<void> &get_smart_ptr_owner() const
875 {
876 auto const &api = ::dpctl::detail::dpctl_capi::get();
877 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
878 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
879
880 if (opaque_ptr) {
881 auto shptr_ptr =
882 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
883 return *shptr_ptr;
884 }
885 else {
886 throw std::runtime_error(
887 "Memory object does not have smart pointer "
888 "managing lifetime of USM allocation");
889 }
890 }
891
892protected:
893 static PyObject *as_usm_memory(PyObject *o)
894 {
895 if (o == nullptr) {
896 PyErr_SetString(PyExc_ValueError,
897 "cannot create a usm_memory from a nullptr");
898 return nullptr;
899 }
900
901 auto converter =
902 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
903
904 py::object res;
905 try {
906 res = converter(py::handle(o));
907 } catch (const py::error_already_set &e) {
908 return nullptr;
909 }
910 return res.ptr();
911 }
912};
913} // end namespace memory
914
915namespace tensor
916{
917inline std::vector<py::ssize_t>
918 c_contiguous_strides(int nd,
919 const py::ssize_t *shape,
920 py::ssize_t element_size = 1)
921{
922 if (nd > 0) {
923 std::vector<py::ssize_t> c_strides(nd, element_size);
924 for (int ic = nd - 1; ic > 0;) {
925 py::ssize_t next_v = c_strides[ic] * shape[ic];
926 c_strides[--ic] = next_v;
927 }
928 return c_strides;
929 }
930 else {
931 return std::vector<py::ssize_t>();
932 }
933}
934
935inline std::vector<py::ssize_t>
936 f_contiguous_strides(int nd,
937 const py::ssize_t *shape,
938 py::ssize_t element_size = 1)
939{
940 if (nd > 0) {
941 std::vector<py::ssize_t> f_strides(nd, element_size);
942 for (int i = 0; i < nd - 1;) {
943 py::ssize_t next_v = f_strides[i] * shape[i];
944 f_strides[++i] = next_v;
945 }
946 return f_strides;
947 }
948 else {
949 return std::vector<py::ssize_t>();
950 }
951}
952
953inline std::vector<py::ssize_t>
954 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
955 py::ssize_t element_size = 1)
956{
957 return c_contiguous_strides(shape.size(), shape.data(), element_size);
958}
959
960inline std::vector<py::ssize_t>
961 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
962 py::ssize_t element_size = 1)
963{
964 return f_contiguous_strides(shape.size(), shape.data(), element_size);
965}
966
967class usm_ndarray : public py::object
968{
969public:
970 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
971 return PyObject_TypeCheck(
972 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
973 })
974
976 : py::object(
977 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
978 borrowed_t{})
979 {
980 if (!m_ptr)
981 throw py::error_already_set();
982 }
983
984 char *get_data() const
985 {
986 PyUSMArrayObject *raw_ar = usm_array_ptr();
987
988 auto const &api = ::dpctl::detail::dpctl_capi::get();
989 return api.UsmNDArray_GetData_(raw_ar);
990 }
991
992 template <typename T>
993 T *get_data() const
994 {
995 return reinterpret_cast<T *>(get_data());
996 }
997
998 int get_ndim() const
999 {
1000 PyUSMArrayObject *raw_ar = usm_array_ptr();
1001
1002 auto const &api = ::dpctl::detail::dpctl_capi::get();
1003 return api.UsmNDArray_GetNDim_(raw_ar);
1004 }
1005
1006 const py::ssize_t *get_shape_raw() const
1007 {
1008 PyUSMArrayObject *raw_ar = usm_array_ptr();
1009
1010 auto const &api = ::dpctl::detail::dpctl_capi::get();
1011 return api.UsmNDArray_GetShape_(raw_ar);
1012 }
1013
1014 std::vector<py::ssize_t> get_shape_vector() const
1015 {
1016 auto raw_sh = get_shape_raw();
1017 auto nd = get_ndim();
1018
1019 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
1020 return shape_vector;
1021 }
1022
1023 py::ssize_t get_shape(int i) const
1024 {
1025 auto shape_ptr = get_shape_raw();
1026 return shape_ptr[i];
1027 }
1028
1029 const py::ssize_t *get_strides_raw() const
1030 {
1031 PyUSMArrayObject *raw_ar = usm_array_ptr();
1032
1033 auto const &api = ::dpctl::detail::dpctl_capi::get();
1034 return api.UsmNDArray_GetStrides_(raw_ar);
1035 }
1036
1037 std::vector<py::ssize_t> get_strides_vector() const
1038 {
1039 auto raw_st = get_strides_raw();
1040 auto nd = get_ndim();
1041
1042 if (raw_st == nullptr) {
1043 auto is_c_contig = is_c_contiguous();
1044 auto is_f_contig = is_f_contiguous();
1045 auto raw_sh = get_shape_raw();
1046 if (is_c_contig) {
1047 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
1048 return contig_strides;
1049 }
1050 else if (is_f_contig) {
1051 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
1052 return contig_strides;
1053 }
1054 else {
1055 throw std::runtime_error("Invalid array encountered when "
1056 "building strides");
1057 }
1058 }
1059 else {
1060 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
1061 return st_vec;
1062 }
1063 }
1064
1065 py::ssize_t get_size() const
1066 {
1067 PyUSMArrayObject *raw_ar = usm_array_ptr();
1068
1069 auto const &api = ::dpctl::detail::dpctl_capi::get();
1070 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
1071 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1072
1073 py::ssize_t nelems = 1;
1074 for (int i = 0; i < ndim; ++i) {
1075 nelems *= shape[i];
1076 }
1077
1078 assert(nelems >= 0);
1079 return nelems;
1080 }
1081
1082 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
1083 {
1084 PyUSMArrayObject *raw_ar = usm_array_ptr();
1085
1086 auto const &api = ::dpctl::detail::dpctl_capi::get();
1087 int nd = api.UsmNDArray_GetNDim_(raw_ar);
1088 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1089 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
1090
1091 py::ssize_t offset_min = 0;
1092 py::ssize_t offset_max = 0;
1093 if (strides == nullptr) {
1094 py::ssize_t stride(1);
1095 for (int i = 0; i < nd; ++i) {
1096 offset_max += stride * (shape[i] - 1);
1097 stride *= shape[i];
1098 }
1099 }
1100 else {
1101 for (int i = 0; i < nd; ++i) {
1102 py::ssize_t delta = strides[i] * (shape[i] - 1);
1103 if (strides[i] > 0) {
1104 offset_max += delta;
1105 }
1106 else {
1107 offset_min += delta;
1108 }
1109 }
1110 }
1111 return std::make_pair(offset_min, offset_max);
1112 }
1113
1114 sycl::queue get_queue() const
1115 {
1116 PyUSMArrayObject *raw_ar = usm_array_ptr();
1117
1118 auto const &api = ::dpctl::detail::dpctl_capi::get();
1119 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1120 return *(reinterpret_cast<sycl::queue *>(QRef));
1121 }
1122
1123 sycl::device get_device() const
1124 {
1125 PyUSMArrayObject *raw_ar = usm_array_ptr();
1126
1127 auto const &api = ::dpctl::detail::dpctl_capi::get();
1128 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1129 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1130 }
1131
1132 int get_typenum() const
1133 {
1134 PyUSMArrayObject *raw_ar = usm_array_ptr();
1135
1136 auto const &api = ::dpctl::detail::dpctl_capi::get();
1137 return api.UsmNDArray_GetTypenum_(raw_ar);
1138 }
1139
1140 int get_flags() const
1141 {
1142 PyUSMArrayObject *raw_ar = usm_array_ptr();
1143
1144 auto const &api = ::dpctl::detail::dpctl_capi::get();
1145 return api.UsmNDArray_GetFlags_(raw_ar);
1146 }
1147
1148 int get_elemsize() const
1149 {
1150 PyUSMArrayObject *raw_ar = usm_array_ptr();
1151
1152 auto const &api = ::dpctl::detail::dpctl_capi::get();
1153 return api.UsmNDArray_GetElementSize_(raw_ar);
1154 }
1155
1156 bool is_c_contiguous() const
1157 {
1158 int flags = get_flags();
1159 auto const &api = ::dpctl::detail::dpctl_capi::get();
1160 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1161 }
1162
1163 bool is_f_contiguous() const
1164 {
1165 int flags = get_flags();
1166 auto const &api = ::dpctl::detail::dpctl_capi::get();
1167 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1168 }
1169
1170 bool is_writable() const
1171 {
1172 int flags = get_flags();
1173 auto const &api = ::dpctl::detail::dpctl_capi::get();
1174 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1175 }
1176
1178 py::object get_usm_data() const
1179 {
1180 PyUSMArrayObject *raw_ar = usm_array_ptr();
1181
1182 auto const &api = ::dpctl::detail::dpctl_capi::get();
1183 // UsmNDArray_GetUSMData_ gives a new reference
1184 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1185
1186 // pass reference ownership to py::object
1187 return py::reinterpret_steal<py::object>(usm_data);
1188 }
1189
1190 bool is_managed_by_smart_ptr() const
1191 {
1192 PyUSMArrayObject *raw_ar = usm_array_ptr();
1193
1194 auto const &api = ::dpctl::detail::dpctl_capi::get();
1195 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1196
1197 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1198 Py_DECREF(usm_data);
1199 return false;
1200 }
1201
1202 Py_MemoryObject *mem_obj =
1203 reinterpret_cast<Py_MemoryObject *>(usm_data);
1204 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1205
1206 Py_DECREF(usm_data);
1207 return bool(opaque_ptr);
1208 }
1209
1210 const std::shared_ptr<void> &get_smart_ptr_owner() const
1211 {
1212 PyUSMArrayObject *raw_ar = usm_array_ptr();
1213
1214 auto const &api = ::dpctl::detail::dpctl_capi::get();
1215
1216 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1217
1218 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1219 Py_DECREF(usm_data);
1220 throw std::runtime_error(
1221 "usm_ndarray object does not have Memory object "
1222 "managing lifetime of USM allocation");
1223 }
1224
1225 Py_MemoryObject *mem_obj =
1226 reinterpret_cast<Py_MemoryObject *>(usm_data);
1227 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1228 Py_DECREF(usm_data);
1229
1230 if (opaque_ptr) {
1231 auto shptr_ptr =
1232 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1233 return *shptr_ptr;
1234 }
1235 else {
1236 throw std::runtime_error(
1237 "Memory object underlying usm_ndarray does not have "
1238 "smart pointer managing lifetime of USM allocation");
1239 }
1240 }
1241
1242private:
1243 PyUSMArrayObject *usm_array_ptr() const
1244 {
1245 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1246 }
1247};
1248} // end namespace tensor
1249
1250namespace utils
1251{
1252namespace detail
1253{
1255{
1256
1257 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1258 {
1259 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1260 const auto &usm_memory_inst =
1261 py::cast<dpctl::memory::usm_memory>(h);
1262 return usm_memory_inst.is_managed_by_smart_ptr();
1263 }
1264 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1265 const auto &usm_array_inst =
1266 py::cast<dpctl::tensor::usm_ndarray>(h);
1267 return usm_array_inst.is_managed_by_smart_ptr();
1268 }
1269
1270 return false;
1271 }
1272
1273 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1274 {
1275 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1276 const auto &usm_memory_inst =
1277 py::cast<dpctl::memory::usm_memory>(h);
1278 return usm_memory_inst.get_smart_ptr_owner();
1279 }
1280 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1281 const auto &usm_array_inst =
1282 py::cast<dpctl::tensor::usm_ndarray>(h);
1283 return usm_array_inst.get_smart_ptr_owner();
1284 }
1285
1286 throw std::runtime_error(
1287 "Attempted extraction of shared_ptr on an unrecognized type");
1288 }
1289};
1290} // end of namespace detail
1291
1292template <std::size_t num>
1293sycl::event keep_args_alive(sycl::queue &q,
1294 const py::object (&py_objs)[num],
1295 const std::vector<sycl::event> &depends = {})
1296{
1297 std::size_t n_objects_held = 0;
1298 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1299
1300 std::size_t n_usm_owners_held = 0;
1301 std::array<std::shared_ptr<void>, num> shp_usm{};
1302
1303 for (std::size_t i = 0; i < num; ++i) {
1304 const auto &py_obj_i = py_objs[i];
1305 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1306 const auto &shp =
1307 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1308 shp_usm[n_usm_owners_held] = shp;
1309 ++n_usm_owners_held;
1310 }
1311 else {
1312 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1313 shp_arr[n_objects_held]->inc_ref();
1314 ++n_objects_held;
1315 }
1316 }
1317
1318 bool use_depends = true;
1319 sycl::event host_task_ev;
1320
1321 if (n_usm_owners_held > 0) {
1322 host_task_ev = q.submit([&](sycl::handler &cgh) {
1323 if (use_depends) {
1324 cgh.depends_on(depends);
1325 use_depends = false;
1326 }
1327 else {
1328 cgh.depends_on(host_task_ev);
1329 }
1330 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1331 // no body, but shared pointers are captured in
1332 // the lambda, ensuring that USM allocation is
1333 // kept alive
1334 });
1335 });
1336 }
1337
1338 if (n_objects_held > 0) {
1339 host_task_ev = q.submit([&](sycl::handler &cgh) {
1340 if (use_depends) {
1341 cgh.depends_on(depends);
1342 use_depends = false;
1343 }
1344 else {
1345 cgh.depends_on(host_task_ev);
1346 }
1347 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1348 py::gil_scoped_acquire acquire;
1349
1350 for (std::size_t i = 0; i < n_objects_held; ++i) {
1351 shp_arr[i]->dec_ref();
1352 }
1353 });
1354 });
1355 }
1356
1357 return host_task_ev;
1358}
1359
1362template <std::size_t num>
1363bool queues_are_compatible(const sycl::queue &exec_q,
1364 const sycl::queue (&alloc_qs)[num])
1365{
1366 for (std::size_t i = 0; i < num; ++i) {
1367
1368 if (exec_q != alloc_qs[i]) {
1369 return false;
1370 }
1371 }
1372 return true;
1373}
1374
1377template <std::size_t num>
1378bool queues_are_compatible(const sycl::queue &exec_q,
1379 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1380{
1381 for (std::size_t i = 0; i < num; ++i) {
1382
1383 if (exec_q != arrs[i].get_queue()) {
1384 return false;
1385 }
1386 }
1387 return true;
1388}
1389} // end namespace utils
1390} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.