DPNP C++ backend kernel library 0.20.0dev1
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include "dpctl_capi.h"
32
33#include <complex>
34#include <cstddef> // for std::size_t for C++ linkage
35#include <memory>
36#include <stddef.h> // for size_t for C linkage
37#include <stdexcept>
38#include <utility>
39#include <vector>
40
41#include <pybind11/pybind11.h>
42
43#include <sycl/sycl.hpp>
44
45namespace py = pybind11;
46
47namespace dpctl
48{
49namespace detail
50{
51// Lookup a type according to its size, and return a value corresponding to the
52// NumPy typenum.
// Terminal case of the size-based lookup: no candidate types remain, so the
// "not found" value -1 is produced.
template <typename Concrete>
constexpr int platform_typeid_lookup()
{
    constexpr int not_found = -1;
    return not_found;
}

// Walks the candidate list (T, Ts...) in lock-step with the integer list
// (I, Is...). Returns the integer paired with the first candidate whose
// sizeof equals sizeof(Concrete); falls through to the terminal overload
// (and hence -1) when no candidate has a matching size.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_typeid_lookup(int I, Ints... Is)
{
    if (sizeof(Concrete) == sizeof(T)) {
        return I;
    }
    return platform_typeid_lookup<Concrete, Ts...>(Is...);
}
66
68{
69public:
70 // dpctl type objects
71 PyTypeObject *Py_SyclDeviceType_;
72 PyTypeObject *PySyclDeviceType_;
73 PyTypeObject *Py_SyclContextType_;
74 PyTypeObject *PySyclContextType_;
75 PyTypeObject *Py_SyclEventType_;
76 PyTypeObject *PySyclEventType_;
77 PyTypeObject *Py_SyclQueueType_;
78 PyTypeObject *PySyclQueueType_;
79 PyTypeObject *Py_MemoryType_;
80 PyTypeObject *PyMemoryUSMDeviceType_;
81 PyTypeObject *PyMemoryUSMSharedType_;
82 PyTypeObject *PyMemoryUSMHostType_;
83 PyTypeObject *PyUSMArrayType_;
84 PyTypeObject *PySyclProgramType_;
85 PyTypeObject *PySyclKernelType_;
86
87 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
88 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
89
90 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
91 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
92
93 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
94 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
95
96 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
97 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
98
99 // memory
100 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
101 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
102 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
103 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
104 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
105 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
106 size_t,
107 DPCTLSyclQueueRef,
108 PyObject *);
109
110 // program
111 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
112 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);
113
114 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
115 PySyclProgramObject *);
116 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
117
118 // tensor
119 char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
120 int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
121 py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
122 py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
123 int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
124 int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
125 int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
126 DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
127 py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
128 PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
129 void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
130 PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
131 const py::ssize_t *,
132 int,
133 Py_MemoryObject *,
134 py::ssize_t,
135 char);
136 PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
137 int,
138 DPCTLSyclUSMRef,
139 DPCTLSyclQueueRef,
140 PyObject *);
141 PyObject *(*UsmNDArray_MakeFromPtr_)(int,
142 const py::ssize_t *,
143 int,
144 const py::ssize_t *,
145 DPCTLSyclUSMRef,
146 DPCTLSyclQueueRef,
147 py::ssize_t,
148 PyObject *);
149
150 int USM_ARRAY_C_CONTIGUOUS_;
151 int USM_ARRAY_F_CONTIGUOUS_;
152 int USM_ARRAY_WRITABLE_;
153 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
154 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
155 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
156 UAR_HALF_;
157 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
158 UAR_INT64_, UAR_UINT64_;
159
160 bool PySyclDevice_Check_(PyObject *obj) const
161 {
162 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
163 }
164 bool PySyclContext_Check_(PyObject *obj) const
165 {
166 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
167 }
168 bool PySyclEvent_Check_(PyObject *obj) const
169 {
170 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
171 }
172 bool PySyclQueue_Check_(PyObject *obj) const
173 {
174 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
175 }
176 bool PySyclKernel_Check_(PyObject *obj) const
177 {
178 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
179 }
180 bool PySyclProgram_Check_(PyObject *obj) const
181 {
182 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
183 }
184
186 {
187 as_usm_memory_.reset();
188 default_usm_ndarray_.reset();
189 default_usm_memory_.reset();
190 default_sycl_queue_.reset();
191 };
192
193 static auto &get()
194 {
195 static dpctl_capi api{};
196 return api;
197 }
198
199 py::object default_sycl_queue_pyobj()
200 {
201 return *default_sycl_queue_;
202 }
203 py::object default_usm_memory_pyobj()
204 {
205 return *default_usm_memory_;
206 }
207 py::object default_usm_ndarray_pyobj()
208 {
209 return *default_usm_ndarray_;
210 }
211 py::object as_usm_memory_pyobj()
212 {
213 return *as_usm_memory_;
214 }
215
216private:
217 struct Deleter
218 {
219 void operator()(py::object *p) const
220 {
221 const bool initialized = Py_IsInitialized();
222#if PY_VERSION_HEX < 0x30d0000
223 const bool finalizing = _Py_IsFinalizing();
224#else
225 const bool finalizing = Py_IsFinalizing();
226#endif
227 const bool guard = initialized && !finalizing;
228
229 if (guard) {
230 delete p;
231 }
232 }
233 };
234
// Python objects created once in the constructor and handed out by the
// *_pyobj() accessors. Each pointer is constructed with Deleter as its
// custom deleter.
235 std::shared_ptr<py::object> default_sycl_queue_;
236 std::shared_ptr<py::object> default_usm_memory_;
237 std::shared_ptr<py::object> default_usm_ndarray_;
238 std::shared_ptr<py::object> as_usm_memory_;
239
240 dpctl_capi()
241 : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
242 Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
243 Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
244 Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
245 Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
246 PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
247 PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
248 PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
249 SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
250 SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
251 SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
252 SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
253 Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
254 Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
255 Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
256 SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
257 SyclProgram_Make_(nullptr), UsmNDArray_GetData_(nullptr),
258 UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
259 UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
260 UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
261 UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
262 UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
263 UsmNDArray_MakeSimpleFromMemory_(nullptr),
264 UsmNDArray_MakeSimpleFromPtr_(nullptr),
265 UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
266 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
267 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
268 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
269 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
270 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
271 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
272 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
273 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
274 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
275
276 {
277 // Import Cython-generated C-API for dpctl
278 // This imports python modules and initializes
279 // static variables such as function pointers for C-API,
280 // e.g. SyclDevice_GetDeviceRef, etc.
281 // pointers to Python types, i.e. PySyclDeviceType, etc.
282 // and exported constants, i.e. USM_ARRAY_C_CONTIGUOUS, etc.
283 import_dpctl();
284
285 // Python type objects for classes implemented by dpctl
286 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
287 this->PySyclDeviceType_ = &PySyclDeviceType;
288 this->Py_SyclContextType_ = &Py_SyclContextType;
289 this->PySyclContextType_ = &PySyclContextType;
290 this->Py_SyclEventType_ = &Py_SyclEventType;
291 this->PySyclEventType_ = &PySyclEventType;
292 this->Py_SyclQueueType_ = &Py_SyclQueueType;
293 this->PySyclQueueType_ = &PySyclQueueType;
294 this->Py_MemoryType_ = &Py_MemoryType;
295 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
296 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
297 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
298 this->PyUSMArrayType_ = &PyUSMArrayType;
299 this->PySyclProgramType_ = &PySyclProgramType;
300 this->PySyclKernelType_ = &PySyclKernelType;
301
302 // SyclDevice API
303 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
304 this->SyclDevice_Make_ = SyclDevice_Make;
305
306 // SyclContext API
307 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
308 this->SyclContext_Make_ = SyclContext_Make;
309
310 // SyclEvent API
311 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
312 this->SyclEvent_Make_ = SyclEvent_Make;
313
314 // SyclQueue API
315 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
316 this->SyclQueue_Make_ = SyclQueue_Make;
317
318 // dpctl.memory API
319 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
320 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
321 this->Memory_GetContextRef_ = Memory_GetContextRef;
322 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
323 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
324 this->Memory_Make_ = Memory_Make;
325
326 // dpctl.program API
327 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
328 this->SyclKernel_Make_ = SyclKernel_Make;
329 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
330 this->SyclProgram_Make_ = SyclProgram_Make;
331
332 // dpctl.tensor.usm_ndarray API
333 this->UsmNDArray_GetData_ = UsmNDArray_GetData;
334 this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
335 this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
336 this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
337 this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
338 this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
339 this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
340 this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
341 this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
342 this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
343 this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
344 this->UsmNDArray_MakeSimpleFromMemory_ =
345 UsmNDArray_MakeSimpleFromMemory;
346 this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
347 this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
348
349 // constants
350 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
351 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
352 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
353 this->UAR_BOOL_ = UAR_BOOL;
354 this->UAR_BYTE_ = UAR_BYTE;
355 this->UAR_UBYTE_ = UAR_UBYTE;
356 this->UAR_SHORT_ = UAR_SHORT;
357 this->UAR_USHORT_ = UAR_USHORT;
358 this->UAR_INT_ = UAR_INT;
359 this->UAR_UINT_ = UAR_UINT;
360 this->UAR_LONG_ = UAR_LONG;
361 this->UAR_ULONG_ = UAR_ULONG;
362 this->UAR_LONGLONG_ = UAR_LONGLONG;
363 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
364 this->UAR_FLOAT_ = UAR_FLOAT;
365 this->UAR_DOUBLE_ = UAR_DOUBLE;
366 this->UAR_CFLOAT_ = UAR_CFLOAT;
367 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
368 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
369 this->UAR_HALF_ = UAR_HALF;
370
371 // deduced disjoint types
372 this->UAR_INT8_ = UAR_BYTE;
373 this->UAR_UINT8_ = UAR_UBYTE;
374 this->UAR_INT16_ = UAR_SHORT;
375 this->UAR_UINT16_ = UAR_USHORT;
376 this->UAR_INT32_ =
377 platform_typeid_lookup<std::int32_t, long, int, short>(
378 UAR_LONG, UAR_INT, UAR_SHORT);
379 this->UAR_UINT32_ =
380 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
381 unsigned short>(UAR_ULONG, UAR_UINT,
382 UAR_USHORT);
383 this->UAR_INT64_ =
384 platform_typeid_lookup<std::int64_t, long, long long, int>(
385 UAR_LONG, UAR_LONGLONG, UAR_INT);
386 this->UAR_UINT64_ =
387 platform_typeid_lookup<std::uint64_t, unsigned long,
388 unsigned long long, unsigned int>(
389 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
390
391 // create shared pointers to python objects used in type-casters
392 // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
393 sycl::queue q_{};
394 PySyclQueueObject *py_q_tmp =
395 SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
396 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
397 reinterpret_cast<PyObject *>(py_q_tmp));
398
399 default_sycl_queue_ = std::shared_ptr<py::object>(
400 new py::object(py_sycl_queue), Deleter{});
401
402 py::module_ mod_memory = py::module_::import("dpctl.memory");
403 const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
404 as_usm_memory_ = std::shared_ptr<py::object>(
405 new py::object{py_as_usm_memory}, Deleter{});
406
407 auto mem_kl = mod_memory.attr("MemoryUSMHost");
408 const py::object &py_default_usm_memory =
409 mem_kl(1, py::arg("queue") = py_sycl_queue);
410 default_usm_memory_ = std::shared_ptr<py::object>(
411 new py::object{py_default_usm_memory}, Deleter{});
412
413 py::module_ mod_usmarray =
414 py::module_::import("dpctl.tensor._usmarray");
415 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
416
417 const py::object &py_default_usm_ndarray =
418 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
419 py::arg("buffer") = py_default_usm_memory);
420
421 default_usm_ndarray_ = std::shared_ptr<py::object>(
422 new py::object{py_default_usm_ndarray}, Deleter{});
423 }
424
// Copy construction, copy assignment and move assignment are defaulted
// here, in the private section of the class.
425 dpctl_capi(dpctl_capi const &) = default;
426 dpctl_capi &operator=(dpctl_capi const &) = default;
427 dpctl_capi &operator=(dpctl_capi &&) = default;
428
429}; // struct dpctl_capi
430} // namespace detail
431} // namespace dpctl
432
433namespace pybind11::detail
434{
/* DPCTL_TYPE_CASTER(type, py_name) expands, inside a type_caster
 * specialisation, to the members shared by the casters in this file:
 *   - a protected `std::unique_ptr<type> value` that `load` fills in;
 *   - a public `static constexpr name` set to py_name;
 *   - a pointer overload of `cast` that returns None for a null `src`,
 *     moves from `*src` and deletes `src` when the policy is
 *     take_ownership, and otherwise forwards `*src` to the by-value
 *     `cast` that the specialisation defines itself;
 *   - conversion operators to `type *`, `type &` and `type &&`;
 *   - the `cast_op_type` alias (movable_cast_op_type).
 * The expansion ends without a trailing semicolon, so the use site
 * supplies it.
 */
435#define DPCTL_TYPE_CASTER(type, py_name) \
436protected: \
437 std::unique_ptr<type> value; \
438 \
439public: \
440 static constexpr auto name = py_name; \
441 template < \
442 typename T_, \
443 ::pybind11::detail::enable_if_t< \
444 std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
445 int> = 0> \
446 static ::pybind11::handle cast(T_ *src, \
447 ::pybind11::return_value_policy policy, \
448 ::pybind11::handle parent) \
449 { \
450 if (!src) \
451 return ::pybind11::none().release(); \
452 if (policy == ::pybind11::return_value_policy::take_ownership) { \
453 auto h = cast(std::move(*src), policy, parent); \
454 delete src; \
455 return h; \
456 } \
457 return cast(*src, policy, parent); \
458 } \
459 operator type *() \
460 { \
461 return value.get(); \
462 } /* NOLINT(bugprone-macro-parentheses) */ \
463 operator type &() \
464 { \
465 return *value; \
466 } /* NOLINT(bugprone-macro-parentheses) */ \
467 operator type &&() && \
468 { \
469 return std::move(*value); \
470 } /* NOLINT(bugprone-macro-parentheses) */ \
471 template <typename T_> \
472 using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
473
474/* This type caster associates ``sycl::queue`` C++ class with
475 * :class:`dpctl.SyclQueue` for the purposes of generation of
476 * Python bindings by pybind11.
477 */
478template <>
479struct type_caster<sycl::queue>
480{
481public:
482 bool load(handle src, bool)
483 {
484 PyObject *source = src.ptr();
485 auto const &api = ::dpctl::detail::dpctl_capi::get();
486 if (api.PySyclQueue_Check_(source)) {
487 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
488 reinterpret_cast<PySyclQueueObject *>(source));
489 value = std::make_unique<sycl::queue>(
490 *(reinterpret_cast<sycl::queue *>(QRef)));
491 return true;
492 }
493 else {
494 throw py::type_error(
495 "Input is of unexpected type, expected dpctl.SyclQueue");
496 }
497 }
498
499 static handle cast(sycl::queue src, return_value_policy, handle)
500 {
501 auto const &api = ::dpctl::detail::dpctl_capi::get();
502 auto tmp =
503 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
504 return handle(reinterpret_cast<PyObject *>(tmp));
505 }
506
507 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
508};
509
510/* This type caster associates ``sycl::device`` C++ class with
511 * :class:`dpctl.SyclDevice` for the purposes of generation of
512 * Python bindings by pybind11.
513 */
514template <>
515struct type_caster<sycl::device>
516{
517public:
518 bool load(handle src, bool)
519 {
520 PyObject *source = src.ptr();
521 auto const &api = ::dpctl::detail::dpctl_capi::get();
522 if (api.PySyclDevice_Check_(source)) {
523 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
524 reinterpret_cast<PySyclDeviceObject *>(source));
525 value = std::make_unique<sycl::device>(
526 *(reinterpret_cast<sycl::device *>(DRef)));
527 return true;
528 }
529 else {
530 throw py::type_error(
531 "Input is of unexpected type, expected dpctl.SyclDevice");
532 }
533 }
534
535 static handle cast(sycl::device src, return_value_policy, handle)
536 {
537 auto const &api = ::dpctl::detail::dpctl_capi::get();
538 auto tmp =
539 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
540 return handle(reinterpret_cast<PyObject *>(tmp));
541 }
542
543 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
544};
545
546/* This type caster associates ``sycl::context`` C++ class with
547 * :class:`dpctl.SyclContext` for the purposes of generation of
548 * Python bindings by pybind11.
549 */
550template <>
551struct type_caster<sycl::context>
552{
553public:
554 bool load(handle src, bool)
555 {
556 PyObject *source = src.ptr();
557 auto const &api = ::dpctl::detail::dpctl_capi::get();
558 if (api.PySyclContext_Check_(source)) {
559 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
560 reinterpret_cast<PySyclContextObject *>(source));
561 value = std::make_unique<sycl::context>(
562 *(reinterpret_cast<sycl::context *>(CRef)));
563 return true;
564 }
565 else {
566 throw py::type_error(
567 "Input is of unexpected type, expected dpctl.SyclContext");
568 }
569 }
570
571 static handle cast(sycl::context src, return_value_policy, handle)
572 {
573 auto const &api = ::dpctl::detail::dpctl_capi::get();
574 auto tmp =
575 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
576 return handle(reinterpret_cast<PyObject *>(tmp));
577 }
578
579 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
580};
581
582/* This type caster associates ``sycl::event`` C++ class with
583 * :class:`dpctl.SyclEvent` for the purposes of generation of
584 * Python bindings by pybind11.
585 */
586template <>
587struct type_caster<sycl::event>
588{
589public:
590 bool load(handle src, bool)
591 {
592 PyObject *source = src.ptr();
593 auto const &api = ::dpctl::detail::dpctl_capi::get();
594 if (api.PySyclEvent_Check_(source)) {
595 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
596 reinterpret_cast<PySyclEventObject *>(source));
597 value = std::make_unique<sycl::event>(
598 *(reinterpret_cast<sycl::event *>(ERef)));
599 return true;
600 }
601 else {
602 throw py::type_error(
603 "Input is of unexpected type, expected dpctl.SyclEvent");
604 }
605 }
606
607 static handle cast(sycl::event src, return_value_policy, handle)
608 {
609 auto const &api = ::dpctl::detail::dpctl_capi::get();
610 auto tmp =
611 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
612 return handle(reinterpret_cast<PyObject *>(tmp));
613 }
614
615 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
616};
617
618/* This type caster associates ``sycl::kernel`` C++ class with
619 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
620 * Python bindings by pybind11.
621 */
622template <>
623struct type_caster<sycl::kernel>
624{
625public:
626 bool load(handle src, bool)
627 {
628 PyObject *source = src.ptr();
629 auto const &api = ::dpctl::detail::dpctl_capi::get();
630 if (api.PySyclKernel_Check_(source)) {
631 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
632 reinterpret_cast<PySyclKernelObject *>(source));
633 value = std::make_unique<sycl::kernel>(
634 *(reinterpret_cast<sycl::kernel *>(KRef)));
635 return true;
636 }
637 else {
638 throw py::type_error("Input is of unexpected type, expected "
639 "dpctl.program.SyclKernel");
640 }
641 }
642
643 static handle cast(sycl::kernel src, return_value_policy, handle)
644 {
645 auto const &api = ::dpctl::detail::dpctl_capi::get();
646 auto tmp =
647 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
648 "dpctl4pybind11_kernel");
649 return handle(reinterpret_cast<PyObject *>(tmp));
650 }
651
652 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
653};
654
655/* This type caster associates
656 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
657 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
658 * Python bindings by pybind11.
659 */
660template <>
661struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
662{
663public:
664 bool load(handle src, bool)
665 {
666 PyObject *source = src.ptr();
667 auto const &api = ::dpctl::detail::dpctl_capi::get();
668 if (api.PySyclProgram_Check_(source)) {
669 DPCTLSyclKernelBundleRef KBRef =
670 api.SyclProgram_GetKernelBundleRef_(
671 reinterpret_cast<PySyclProgramObject *>(source));
672 value = std::make_unique<
673 sycl::kernel_bundle<sycl::bundle_state::executable>>(
674 *(reinterpret_cast<
675 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
676 KBRef)));
677 return true;
678 }
679 else {
680 throw py::type_error("Input is of unexpected type, expected "
681 "dpctl.program.SyclProgram");
682 }
683 }
684
685 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
686 return_value_policy,
687 handle)
688 {
689 auto const &api = ::dpctl::detail::dpctl_capi::get();
690 auto tmp = api.SyclProgram_Make_(
691 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
692 return handle(reinterpret_cast<PyObject *>(tmp));
693 }
694
695 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
696 _("dpctl.program.SyclProgram"));
697};
698
699/* This type caster associates
700 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
701 * of generation of Python bindings by pybind11.
702 */
703template <>
704struct type_caster<sycl::half>
705{
706public:
707 bool load(handle src, bool convert)
708 {
709 double py_value;
710
711 if (!src) {
712 return false;
713 }
714
715 PyObject *source = src.ptr();
716
717 if (convert || PyFloat_Check(source)) {
718 py_value = PyFloat_AsDouble(source);
719 }
720 else {
721 return false;
722 }
723
724 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
725
726 if (py_err) {
727 PyErr_Clear();
728 if (convert && (PyNumber_Check(source) != 0)) {
729 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
730 return load(tmp, false);
731 }
732 return false;
733 }
734 value = static_cast<sycl::half>(py_value);
735 return true;
736 }
737
738 static handle cast(sycl::half src, return_value_policy, handle)
739 {
740 return PyFloat_FromDouble(static_cast<double>(src));
741 }
742
743 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
744};
745} // namespace pybind11::detail
746
747namespace dpctl
748{
749namespace memory
750{
751// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
752// this allows to avoid compilation error
753using pybind11::error_already_set;
754
755class usm_memory : public py::object
756{
757public:
758 PYBIND11_OBJECT_CVT(
760 py::object,
761 [](PyObject *o) -> bool {
762 return PyObject_TypeCheck(
763 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
764 0;
765 },
766 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
767
768 usm_memory()
769 : py::object(
770 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
771 borrowed_t{})
772 {
773 if (!m_ptr)
774 throw py::error_already_set();
775 }
776
780 usm_memory(void *usm_ptr,
781 std::size_t nbytes,
782 const sycl::queue &q,
783 std::shared_ptr<void> shptr)
784 {
785 auto const &api = ::dpctl::detail::dpctl_capi::get();
786 DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
787 auto q_uptr = std::make_unique<sycl::queue>(q);
788 DPCTLSyclQueueRef QRef =
789 reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());
790
791 auto vacuous_destructor = []() {};
792 py::capsule mock_owner(vacuous_destructor);
793
794 // create memory object owned by mock_owner, it is a new reference
795 PyObject *_memory =
796 api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
797 auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };
798
799 using py_uptrT =
800 std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;
801
802 if (!_memory) {
803 throw py::error_already_set();
804 }
805
806 auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
807 std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);
808
809 Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
810 // replace mock_owner capsule as the owner
811 memobj->refobj = Py_None;
812 // set opaque ptr field, usm_memory now knowns that USM is managed
813 // by smart pointer
814 memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);
815
816 // _memory will delete created copies of sycl::queue, and
817 // std::shared_ptr and the deleter of the shared_ptr<void> is
818 // supposed to free the USM allocation
819 m_ptr = _memory;
820 q_uptr.release();
821 memory_uptr.release();
822 }
823
824 sycl::queue get_queue() const
825 {
826 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
827 auto const &api = ::dpctl::detail::dpctl_capi::get();
828 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
829 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
830 return *obj_q;
831 }
832
833 char *get_pointer() const
834 {
835 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
836 auto const &api = ::dpctl::detail::dpctl_capi::get();
837 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
838 return reinterpret_cast<char *>(MRef);
839 }
840
841 std::size_t get_nbytes() const
842 {
843 auto const &api = ::dpctl::detail::dpctl_capi::get();
844 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
845 return api.Memory_GetNumBytes_(mem_obj);
846 }
847
848 bool is_managed_by_smart_ptr() const
849 {
850 auto const &api = ::dpctl::detail::dpctl_capi::get();
851 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
852 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
853
854 return bool(opaque_ptr);
855 }
856
857 const std::shared_ptr<void> &get_smart_ptr_owner() const
858 {
859 auto const &api = ::dpctl::detail::dpctl_capi::get();
860 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
861 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
862
863 if (opaque_ptr) {
864 auto shptr_ptr =
865 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
866 return *shptr_ptr;
867 }
868 else {
869 throw std::runtime_error(
870 "Memory object does not have smart pointer "
871 "managing lifetime of USM allocation");
872 }
873 }
874
875protected:
876 static PyObject *as_usm_memory(PyObject *o)
877 {
878 if (o == nullptr) {
879 PyErr_SetString(PyExc_ValueError,
880 "cannot create a usm_memory from a nullptr");
881 return nullptr;
882 }
883
884 auto converter =
885 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
886
887 py::object res;
888 try {
889 res = converter(py::handle(o));
890 } catch (const py::error_already_set &e) {
891 return nullptr;
892 }
893 return res.ptr();
894 }
895};
896} // end namespace memory
897
898namespace tensor
899{
900inline std::vector<py::ssize_t>
901 c_contiguous_strides(int nd,
902 const py::ssize_t *shape,
903 py::ssize_t element_size = 1)
904{
905 if (nd > 0) {
906 std::vector<py::ssize_t> c_strides(nd, element_size);
907 for (int ic = nd - 1; ic > 0;) {
908 py::ssize_t next_v = c_strides[ic] * shape[ic];
909 c_strides[--ic] = next_v;
910 }
911 return c_strides;
912 }
913 else {
914 return std::vector<py::ssize_t>();
915 }
916}
917
918inline std::vector<py::ssize_t>
919 f_contiguous_strides(int nd,
920 const py::ssize_t *shape,
921 py::ssize_t element_size = 1)
922{
923 if (nd > 0) {
924 std::vector<py::ssize_t> f_strides(nd, element_size);
925 for (int i = 0; i < nd - 1;) {
926 py::ssize_t next_v = f_strides[i] * shape[i];
927 f_strides[++i] = next_v;
928 }
929 return f_strides;
930 }
931 else {
932 return std::vector<py::ssize_t>();
933 }
934}
935
936inline std::vector<py::ssize_t>
937 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
938 py::ssize_t element_size = 1)
939{
940 return c_contiguous_strides(shape.size(), shape.data(), element_size);
941}
942
943inline std::vector<py::ssize_t>
944 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
945 py::ssize_t element_size = 1)
946{
947 return f_contiguous_strides(shape.size(), shape.data(), element_size);
948}
949
950class usm_ndarray : public py::object
951{
952public:
953 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
954 return PyObject_TypeCheck(
955 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
956 })
957
959 : py::object(
960 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
961 borrowed_t{})
962 {
963 if (!m_ptr)
964 throw py::error_already_set();
965 }
966
967 char *get_data() const
968 {
969 PyUSMArrayObject *raw_ar = usm_array_ptr();
970
971 auto const &api = ::dpctl::detail::dpctl_capi::get();
972 return api.UsmNDArray_GetData_(raw_ar);
973 }
974
975 template <typename T>
976 T *get_data() const
977 {
978 return reinterpret_cast<T *>(get_data());
979 }
980
981 int get_ndim() const
982 {
983 PyUSMArrayObject *raw_ar = usm_array_ptr();
984
985 auto const &api = ::dpctl::detail::dpctl_capi::get();
986 return api.UsmNDArray_GetNDim_(raw_ar);
987 }
988
989 const py::ssize_t *get_shape_raw() const
990 {
991 PyUSMArrayObject *raw_ar = usm_array_ptr();
992
993 auto const &api = ::dpctl::detail::dpctl_capi::get();
994 return api.UsmNDArray_GetShape_(raw_ar);
995 }
996
997 std::vector<py::ssize_t> get_shape_vector() const
998 {
999 auto raw_sh = get_shape_raw();
1000 auto nd = get_ndim();
1001
1002 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
1003 return shape_vector;
1004 }
1005
1006 py::ssize_t get_shape(int i) const
1007 {
1008 auto shape_ptr = get_shape_raw();
1009 return shape_ptr[i];
1010 }
1011
1012 const py::ssize_t *get_strides_raw() const
1013 {
1014 PyUSMArrayObject *raw_ar = usm_array_ptr();
1015
1016 auto const &api = ::dpctl::detail::dpctl_capi::get();
1017 return api.UsmNDArray_GetStrides_(raw_ar);
1018 }
1019
1020 std::vector<py::ssize_t> get_strides_vector() const
1021 {
1022 auto raw_st = get_strides_raw();
1023 auto nd = get_ndim();
1024
1025 if (raw_st == nullptr) {
1026 auto is_c_contig = is_c_contiguous();
1027 auto is_f_contig = is_f_contiguous();
1028 auto raw_sh = get_shape_raw();
1029 if (is_c_contig) {
1030 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
1031 return contig_strides;
1032 }
1033 else if (is_f_contig) {
1034 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
1035 return contig_strides;
1036 }
1037 else {
1038 throw std::runtime_error("Invalid array encountered when "
1039 "building strides");
1040 }
1041 }
1042 else {
1043 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
1044 return st_vec;
1045 }
1046 }
1047
1048 py::ssize_t get_size() const
1049 {
1050 PyUSMArrayObject *raw_ar = usm_array_ptr();
1051
1052 auto const &api = ::dpctl::detail::dpctl_capi::get();
1053 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
1054 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1055
1056 py::ssize_t nelems = 1;
1057 for (int i = 0; i < ndim; ++i) {
1058 nelems *= shape[i];
1059 }
1060
1061 assert(nelems >= 0);
1062 return nelems;
1063 }
1064
1065 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
1066 {
1067 PyUSMArrayObject *raw_ar = usm_array_ptr();
1068
1069 auto const &api = ::dpctl::detail::dpctl_capi::get();
1070 int nd = api.UsmNDArray_GetNDim_(raw_ar);
1071 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1072 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
1073
1074 py::ssize_t offset_min = 0;
1075 py::ssize_t offset_max = 0;
1076 if (strides == nullptr) {
1077 py::ssize_t stride(1);
1078 for (int i = 0; i < nd; ++i) {
1079 offset_max += stride * (shape[i] - 1);
1080 stride *= shape[i];
1081 }
1082 }
1083 else {
1084 for (int i = 0; i < nd; ++i) {
1085 py::ssize_t delta = strides[i] * (shape[i] - 1);
1086 if (strides[i] > 0) {
1087 offset_max += delta;
1088 }
1089 else {
1090 offset_min += delta;
1091 }
1092 }
1093 }
1094 return std::make_pair(offset_min, offset_max);
1095 }
1096
1097 sycl::queue get_queue() const
1098 {
1099 PyUSMArrayObject *raw_ar = usm_array_ptr();
1100
1101 auto const &api = ::dpctl::detail::dpctl_capi::get();
1102 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1103 return *(reinterpret_cast<sycl::queue *>(QRef));
1104 }
1105
1106 sycl::device get_device() const
1107 {
1108 PyUSMArrayObject *raw_ar = usm_array_ptr();
1109
1110 auto const &api = ::dpctl::detail::dpctl_capi::get();
1111 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1112 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1113 }
1114
1115 int get_typenum() const
1116 {
1117 PyUSMArrayObject *raw_ar = usm_array_ptr();
1118
1119 auto const &api = ::dpctl::detail::dpctl_capi::get();
1120 return api.UsmNDArray_GetTypenum_(raw_ar);
1121 }
1122
1123 int get_flags() const
1124 {
1125 PyUSMArrayObject *raw_ar = usm_array_ptr();
1126
1127 auto const &api = ::dpctl::detail::dpctl_capi::get();
1128 return api.UsmNDArray_GetFlags_(raw_ar);
1129 }
1130
1131 int get_elemsize() const
1132 {
1133 PyUSMArrayObject *raw_ar = usm_array_ptr();
1134
1135 auto const &api = ::dpctl::detail::dpctl_capi::get();
1136 return api.UsmNDArray_GetElementSize_(raw_ar);
1137 }
1138
1139 bool is_c_contiguous() const
1140 {
1141 int flags = get_flags();
1142 auto const &api = ::dpctl::detail::dpctl_capi::get();
1143 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1144 }
1145
1146 bool is_f_contiguous() const
1147 {
1148 int flags = get_flags();
1149 auto const &api = ::dpctl::detail::dpctl_capi::get();
1150 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1151 }
1152
1153 bool is_writable() const
1154 {
1155 int flags = get_flags();
1156 auto const &api = ::dpctl::detail::dpctl_capi::get();
1157 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1158 }
1159
1161 py::object get_usm_data() const
1162 {
1163 PyUSMArrayObject *raw_ar = usm_array_ptr();
1164
1165 auto const &api = ::dpctl::detail::dpctl_capi::get();
1166 // UsmNDArray_GetUSMData_ gives a new reference
1167 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1168
1169 // pass reference ownership to py::object
1170 return py::reinterpret_steal<py::object>(usm_data);
1171 }
1172
1173 bool is_managed_by_smart_ptr() const
1174 {
1175 PyUSMArrayObject *raw_ar = usm_array_ptr();
1176
1177 auto const &api = ::dpctl::detail::dpctl_capi::get();
1178 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1179
1180 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1181 Py_DECREF(usm_data);
1182 return false;
1183 }
1184
1185 Py_MemoryObject *mem_obj =
1186 reinterpret_cast<Py_MemoryObject *>(usm_data);
1187 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1188
1189 Py_DECREF(usm_data);
1190 return bool(opaque_ptr);
1191 }
1192
1193 const std::shared_ptr<void> &get_smart_ptr_owner() const
1194 {
1195 PyUSMArrayObject *raw_ar = usm_array_ptr();
1196
1197 auto const &api = ::dpctl::detail::dpctl_capi::get();
1198
1199 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1200
1201 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1202 Py_DECREF(usm_data);
1203 throw std::runtime_error(
1204 "usm_ndarray object does not have Memory object "
1205 "managing lifetime of USM allocation");
1206 }
1207
1208 Py_MemoryObject *mem_obj =
1209 reinterpret_cast<Py_MemoryObject *>(usm_data);
1210 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1211 Py_DECREF(usm_data);
1212
1213 if (opaque_ptr) {
1214 auto shptr_ptr =
1215 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1216 return *shptr_ptr;
1217 }
1218 else {
1219 throw std::runtime_error(
1220 "Memory object underlying usm_ndarray does not have "
1221 "smart pointer managing lifetime of USM allocation");
1222 }
1223 }
1224
1225private:
1226 PyUSMArrayObject *usm_array_ptr() const
1227 {
1228 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1229 }
1230};
1231} // end namespace tensor
1232
1233namespace utils
1234{
1235namespace detail
1236{
1238{
1239
1240 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1241 {
1242 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1243 const auto &usm_memory_inst =
1244 py::cast<dpctl::memory::usm_memory>(h);
1245 return usm_memory_inst.is_managed_by_smart_ptr();
1246 }
1247 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1248 const auto &usm_array_inst =
1249 py::cast<dpctl::tensor::usm_ndarray>(h);
1250 return usm_array_inst.is_managed_by_smart_ptr();
1251 }
1252
1253 return false;
1254 }
1255
1256 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1257 {
1258 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1259 const auto &usm_memory_inst =
1260 py::cast<dpctl::memory::usm_memory>(h);
1261 return usm_memory_inst.get_smart_ptr_owner();
1262 }
1263 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1264 const auto &usm_array_inst =
1265 py::cast<dpctl::tensor::usm_ndarray>(h);
1266 return usm_array_inst.get_smart_ptr_owner();
1267 }
1268
1269 throw std::runtime_error(
1270 "Attempted extraction of shared_ptr on an unrecognized type");
1271 }
1272};
1273} // end of namespace detail
1274
1275template <std::size_t num>
1276sycl::event keep_args_alive(sycl::queue &q,
1277 const py::object (&py_objs)[num],
1278 const std::vector<sycl::event> &depends = {})
1279{
1280 std::size_t n_objects_held = 0;
1281 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1282
1283 std::size_t n_usm_owners_held = 0;
1284 std::array<std::shared_ptr<void>, num> shp_usm{};
1285
1286 for (std::size_t i = 0; i < num; ++i) {
1287 const auto &py_obj_i = py_objs[i];
1288 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1289 const auto &shp =
1290 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1291 shp_usm[n_usm_owners_held] = shp;
1292 ++n_usm_owners_held;
1293 }
1294 else {
1295 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1296 shp_arr[n_objects_held]->inc_ref();
1297 ++n_objects_held;
1298 }
1299 }
1300
1301 bool use_depends = true;
1302 sycl::event host_task_ev;
1303
1304 if (n_usm_owners_held > 0) {
1305 host_task_ev = q.submit([&](sycl::handler &cgh) {
1306 if (use_depends) {
1307 cgh.depends_on(depends);
1308 use_depends = false;
1309 }
1310 else {
1311 cgh.depends_on(host_task_ev);
1312 }
1313 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1314 // no body, but shared pointers are captured in
1315 // the lambda, ensuring that USM allocation is
1316 // kept alive
1317 });
1318 });
1319 }
1320
1321 if (n_objects_held > 0) {
1322 host_task_ev = q.submit([&](sycl::handler &cgh) {
1323 if (use_depends) {
1324 cgh.depends_on(depends);
1325 use_depends = false;
1326 }
1327 else {
1328 cgh.depends_on(host_task_ev);
1329 }
1330 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1331 py::gil_scoped_acquire acquire;
1332
1333 for (std::size_t i = 0; i < n_objects_held; ++i) {
1334 shp_arr[i]->dec_ref();
1335 }
1336 });
1337 });
1338 }
1339
1340 return host_task_ev;
1341}
1342
1345template <std::size_t num>
1346bool queues_are_compatible(const sycl::queue &exec_q,
1347 const sycl::queue (&alloc_qs)[num])
1348{
1349 for (std::size_t i = 0; i < num; ++i) {
1350
1351 if (exec_q != alloc_qs[i]) {
1352 return false;
1353 }
1354 }
1355 return true;
1356}
1357
1360template <std::size_t num>
1361bool queues_are_compatible(const sycl::queue &exec_q,
1362 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1363{
1364 for (std::size_t i = 0; i < num; ++i) {
1365
1366 if (exec_q != arrs[i].get_queue()) {
1367 return false;
1368 }
1369 }
1370 return true;
1371}
1372} // end namespace utils
1373} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.