DPNP C++ backend kernel library 0.20.0dev3
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// TODO: Enable dpctl_capi.h once dpctl.tensor is removed.
32// Also call `import_dpctl_ext__tensor___usmarray();` right after
33// `import_dpctl()` (line 334) to initialize the dpctl_ext tensor C-API.
34//
35// Now we include dpctl C-API headers explicitly in order to
36// integrate dpctl_ext tensor C-API.
37
38// #include "dpctl_capi.h"
39
40// clang-format off
41// Ordering of includes is important here. dpctl_sycl_types and
42// dpctl_sycl_extension_interface define types used by dpctl's Python
43// C-API headers.
44#include "syclinterface/dpctl_sycl_types.h"
45#include "syclinterface/dpctl_sycl_extension_interface.h"
46#ifdef __cplusplus
47#define CYTHON_EXTERN_C extern "C"
48#else
49#define CYTHON_EXTERN_C
50#endif
51#include "dpctl/_sycl_device.h"
52#include "dpctl/_sycl_device_api.h"
53#include "dpctl/_sycl_context.h"
54#include "dpctl/_sycl_context_api.h"
55#include "dpctl/_sycl_event.h"
56#include "dpctl/_sycl_event_api.h"
57#include "dpctl/_sycl_queue.h"
58#include "dpctl/_sycl_queue_api.h"
59#include "dpctl/memory/_memory.h"
60#include "dpctl/memory/_memory_api.h"
61#include "dpctl/program/_program.h"
62#include "dpctl/program/_program_api.h"
63
64// clang-format on
65
66// TODO: Keep these includes once `dpctl.tensor` is removed from dpctl,
67// but replace the hardcoded relative path with a proper include pathы
68#include <dpctl_ext/tensor/_usmarray.h>
69#include <dpctl_ext/tensor/_usmarray_api.h>
70
71/*
72 * Function to import dpctl and make C-API functions available.
73 * C functions can use dpctl's C-API functions without linking to
74 * shared objects defining this symbols, if they call `import_dpctl()`
75 * prior to using those symbols.
76 *
77 * It is declared inline to allow multiple definitions in
78 * different translation units
79 */
80static inline void import_dpctl(void)
81{
82 import_dpctl___sycl_device();
83 import_dpctl___sycl_context();
84 import_dpctl___sycl_event();
85 import_dpctl___sycl_queue();
86 import_dpctl__memory___memory();
87 import_dpctl_ext__tensor___usmarray();
88 import_dpctl__program___program();
89 return;
90}
91
92#include <array>
93#include <complex>
94#include <cstddef> // for std::size_t for C++ linkage
95#include <cstdint>
96#include <memory>
97#include <stddef.h> // for size_t for C linkage
98#include <stdexcept>
99#include <type_traits>
100#include <utility>
101#include <vector>
102
103#include <pybind11/pybind11.h>
104
105#include <sycl/sycl.hpp>
106
107namespace py = pybind11;
108
109namespace dpctl
110{
111namespace detail
112{
113// Lookup a type according to its size, and return a value corresponding to the
114// NumPy typenum.
115template <typename Concrete>
116constexpr int platform_typeid_lookup()
117{
118 return -1;
119}
120
121template <typename Concrete, typename T, typename... Ts, typename... Ints>
122constexpr int platform_typeid_lookup(int I, Ints... Is)
123{
124 return sizeof(Concrete) == sizeof(T)
125 ? I
126 : platform_typeid_lookup<Concrete, Ts...>(Is...);
127}
128
130{
131public:
132 // dpctl type objects
133 PyTypeObject *Py_SyclDeviceType_;
134 PyTypeObject *PySyclDeviceType_;
135 PyTypeObject *Py_SyclContextType_;
136 PyTypeObject *PySyclContextType_;
137 PyTypeObject *Py_SyclEventType_;
138 PyTypeObject *PySyclEventType_;
139 PyTypeObject *Py_SyclQueueType_;
140 PyTypeObject *PySyclQueueType_;
141 PyTypeObject *Py_MemoryType_;
142 PyTypeObject *PyMemoryUSMDeviceType_;
143 PyTypeObject *PyMemoryUSMSharedType_;
144 PyTypeObject *PyMemoryUSMHostType_;
145 PyTypeObject *PyUSMArrayType_;
146 PyTypeObject *PySyclProgramType_;
147 PyTypeObject *PySyclKernelType_;
148
149 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
150 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
151
152 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
153 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
154
155 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
156 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
157
158 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
159 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
160
161 // memory
162 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
163 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
164 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
165 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
166 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
167 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
168 size_t,
169 DPCTLSyclQueueRef,
170 PyObject *);
171
172 // program
173 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
174 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);
175
176 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
177 PySyclProgramObject *);
178 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
179
180 // tensor
181 char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
182 int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
183 py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
184 py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
185 int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
186 int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
187 int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
188 DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
189 py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
190 PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
191 void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
192 PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
193 const py::ssize_t *,
194 int,
195 Py_MemoryObject *,
196 py::ssize_t,
197 char);
198 PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
199 int,
200 DPCTLSyclUSMRef,
201 DPCTLSyclQueueRef,
202 PyObject *);
203 PyObject *(*UsmNDArray_MakeFromPtr_)(int,
204 const py::ssize_t *,
205 int,
206 const py::ssize_t *,
207 DPCTLSyclUSMRef,
208 DPCTLSyclQueueRef,
209 py::ssize_t,
210 PyObject *);
211
212 int USM_ARRAY_C_CONTIGUOUS_;
213 int USM_ARRAY_F_CONTIGUOUS_;
214 int USM_ARRAY_WRITABLE_;
215 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
216 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
217 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
218 UAR_HALF_;
219 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
220 UAR_INT64_, UAR_UINT64_;
221
222 bool PySyclDevice_Check_(PyObject *obj) const
223 {
224 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
225 }
226 bool PySyclContext_Check_(PyObject *obj) const
227 {
228 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
229 }
230 bool PySyclEvent_Check_(PyObject *obj) const
231 {
232 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
233 }
234 bool PySyclQueue_Check_(PyObject *obj) const
235 {
236 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
237 }
238 bool PySyclKernel_Check_(PyObject *obj) const
239 {
240 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
241 }
242 bool PySyclProgram_Check_(PyObject *obj) const
243 {
244 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
245 }
246
248 {
249 as_usm_memory_.reset();
250 default_usm_ndarray_.reset();
251 default_usm_memory_.reset();
252 default_sycl_queue_.reset();
253 };
254
255 static auto &get()
256 {
257 static dpctl_capi api{};
258 return api;
259 }
260
261 py::object default_sycl_queue_pyobj()
262 {
263 return *default_sycl_queue_;
264 }
265 py::object default_usm_memory_pyobj()
266 {
267 return *default_usm_memory_;
268 }
269 py::object default_usm_ndarray_pyobj()
270 {
271 return *default_usm_ndarray_;
272 }
273 py::object as_usm_memory_pyobj()
274 {
275 return *as_usm_memory_;
276 }
277
278private:
279 struct Deleter
280 {
281 void operator()(py::object *p) const
282 {
283 const bool initialized = Py_IsInitialized();
284#if PY_VERSION_HEX < 0x30d0000
285 const bool finalizing = _Py_IsFinalizing();
286#else
287 const bool finalizing = Py_IsFinalizing();
288#endif
289 const bool guard = initialized && !finalizing;
290
291 if (guard) {
292 delete p;
293 }
294 }
295 };
296
297 std::shared_ptr<py::object> default_sycl_queue_;
298 std::shared_ptr<py::object> default_usm_memory_;
299 std::shared_ptr<py::object> default_usm_ndarray_;
300 std::shared_ptr<py::object> as_usm_memory_;
301
302 dpctl_capi()
303 : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
304 Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
305 Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
306 Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
307 Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
308 PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
309 PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
310 PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
311 SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
312 SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
313 SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
314 SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
315 Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
316 Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
317 Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
318 SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
319 SyclProgram_Make_(nullptr), UsmNDArray_GetData_(nullptr),
320 UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
321 UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
322 UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
323 UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
324 UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
325 UsmNDArray_MakeSimpleFromMemory_(nullptr),
326 UsmNDArray_MakeSimpleFromPtr_(nullptr),
327 UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
328 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
329 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
330 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
331 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
332 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
333 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
334 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
335 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
336 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
337
338 {
339 // Import Cython-generated C-API for dpctl
340 // This imports python modules and initializes
341 // static variables such as function pointers for C-API,
342 // e.g. SyclDevice_GetDeviceRef, etc.
343 // pointers to Python types, i.e. PySyclDeviceType, etc.
344 // and exported constants, i.e. USM_ARRAY_C_CONTIGUOUS, etc.
345 import_dpctl();
346
347 // Python type objects for classes implemented by dpctl
348 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
349 this->PySyclDeviceType_ = &PySyclDeviceType;
350 this->Py_SyclContextType_ = &Py_SyclContextType;
351 this->PySyclContextType_ = &PySyclContextType;
352 this->Py_SyclEventType_ = &Py_SyclEventType;
353 this->PySyclEventType_ = &PySyclEventType;
354 this->Py_SyclQueueType_ = &Py_SyclQueueType;
355 this->PySyclQueueType_ = &PySyclQueueType;
356 this->Py_MemoryType_ = &Py_MemoryType;
357 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
358 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
359 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
360 this->PyUSMArrayType_ = &PyUSMArrayType;
361 this->PySyclProgramType_ = &PySyclProgramType;
362 this->PySyclKernelType_ = &PySyclKernelType;
363
364 // SyclDevice API
365 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
366 this->SyclDevice_Make_ = SyclDevice_Make;
367
368 // SyclContext API
369 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
370 this->SyclContext_Make_ = SyclContext_Make;
371
372 // SyclEvent API
373 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
374 this->SyclEvent_Make_ = SyclEvent_Make;
375
376 // SyclQueue API
377 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
378 this->SyclQueue_Make_ = SyclQueue_Make;
379
380 // dpctl.memory API
381 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
382 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
383 this->Memory_GetContextRef_ = Memory_GetContextRef;
384 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
385 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
386 this->Memory_Make_ = Memory_Make;
387
388 // dpctl.program API
389 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
390 this->SyclKernel_Make_ = SyclKernel_Make;
391 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
392 this->SyclProgram_Make_ = SyclProgram_Make;
393
394 // dpctl.tensor.usm_ndarray API
395 this->UsmNDArray_GetData_ = UsmNDArray_GetData;
396 this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
397 this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
398 this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
399 this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
400 this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
401 this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
402 this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
403 this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
404 this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
405 this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
406 this->UsmNDArray_MakeSimpleFromMemory_ =
407 UsmNDArray_MakeSimpleFromMemory;
408 this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
409 this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
410
411 // constants
412 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
413 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
414 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
415 this->UAR_BOOL_ = UAR_BOOL;
416 this->UAR_BYTE_ = UAR_BYTE;
417 this->UAR_UBYTE_ = UAR_UBYTE;
418 this->UAR_SHORT_ = UAR_SHORT;
419 this->UAR_USHORT_ = UAR_USHORT;
420 this->UAR_INT_ = UAR_INT;
421 this->UAR_UINT_ = UAR_UINT;
422 this->UAR_LONG_ = UAR_LONG;
423 this->UAR_ULONG_ = UAR_ULONG;
424 this->UAR_LONGLONG_ = UAR_LONGLONG;
425 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
426 this->UAR_FLOAT_ = UAR_FLOAT;
427 this->UAR_DOUBLE_ = UAR_DOUBLE;
428 this->UAR_CFLOAT_ = UAR_CFLOAT;
429 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
430 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
431 this->UAR_HALF_ = UAR_HALF;
432
433 // deduced disjoint types
434 this->UAR_INT8_ = UAR_BYTE;
435 this->UAR_UINT8_ = UAR_UBYTE;
436 this->UAR_INT16_ = UAR_SHORT;
437 this->UAR_UINT16_ = UAR_USHORT;
438 this->UAR_INT32_ =
439 platform_typeid_lookup<std::int32_t, long, int, short>(
440 UAR_LONG, UAR_INT, UAR_SHORT);
441 this->UAR_UINT32_ =
442 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
443 unsigned short>(UAR_ULONG, UAR_UINT,
444 UAR_USHORT);
445 this->UAR_INT64_ =
446 platform_typeid_lookup<std::int64_t, long, long long, int>(
447 UAR_LONG, UAR_LONGLONG, UAR_INT);
448 this->UAR_UINT64_ =
449 platform_typeid_lookup<std::uint64_t, unsigned long,
450 unsigned long long, unsigned int>(
451 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
452
453 // create shared pointers to python objects used in type-casters
454 // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
455 sycl::queue q_{};
456 PySyclQueueObject *py_q_tmp =
457 SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
458 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
459 reinterpret_cast<PyObject *>(py_q_tmp));
460
461 default_sycl_queue_ = std::shared_ptr<py::object>(
462 new py::object(py_sycl_queue), Deleter{});
463
464 py::module_ mod_memory = py::module_::import("dpctl.memory");
465 const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
466 as_usm_memory_ = std::shared_ptr<py::object>(
467 new py::object{py_as_usm_memory}, Deleter{});
468
469 auto mem_kl = mod_memory.attr("MemoryUSMHost");
470 const py::object &py_default_usm_memory =
471 mem_kl(1, py::arg("queue") = py_sycl_queue);
472 default_usm_memory_ = std::shared_ptr<py::object>(
473 new py::object{py_default_usm_memory}, Deleter{});
474
475 // TODO: revert to `py::module_::import("dpctl.tensor._usmarray");`
476 // when dpnp fully migrates dpctl/tensor
477 py::module_ mod_usmarray =
478 py::module_::import("dpctl_ext.tensor._usmarray");
479 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
480
481 const py::object &py_default_usm_ndarray =
482 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
483 py::arg("buffer") = py_default_usm_memory);
484
485 default_usm_ndarray_ = std::shared_ptr<py::object>(
486 new py::object{py_default_usm_ndarray}, Deleter{});
487 }
488
489 dpctl_capi(dpctl_capi const &) = default;
490 dpctl_capi &operator=(dpctl_capi const &) = default;
491 dpctl_capi &operator=(dpctl_capi &&) = default;
492
493}; // struct dpctl_capi
494} // namespace detail
495} // namespace dpctl
496
497namespace pybind11::detail
498{
499#define DPCTL_TYPE_CASTER(type, py_name) \
500protected: \
501 std::unique_ptr<type> value; \
502 \
503public: \
504 static constexpr auto name = py_name; \
505 template < \
506 typename T_, \
507 ::pybind11::detail::enable_if_t< \
508 std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
509 int> = 0> \
510 static ::pybind11::handle cast(T_ *src, \
511 ::pybind11::return_value_policy policy, \
512 ::pybind11::handle parent) \
513 { \
514 if (!src) \
515 return ::pybind11::none().release(); \
516 if (policy == ::pybind11::return_value_policy::take_ownership) { \
517 auto h = cast(std::move(*src), policy, parent); \
518 delete src; \
519 return h; \
520 } \
521 return cast(*src, policy, parent); \
522 } \
523 operator type *() \
524 { \
525 return value.get(); \
526 } /* NOLINT(bugprone-macro-parentheses) */ \
527 operator type &() \
528 { \
529 return *value; \
530 } /* NOLINT(bugprone-macro-parentheses) */ \
531 operator type &&() && \
532 { \
533 return std::move(*value); \
534 } /* NOLINT(bugprone-macro-parentheses) */ \
535 template <typename T_> \
536 using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
537
538/* This type caster associates ``sycl::queue`` C++ class with
539 * :class:`dpctl.SyclQueue` for the purposes of generation of
540 * Python bindings by pybind11.
541 */
542template <>
543struct type_caster<sycl::queue>
544{
545public:
546 bool load(handle src, bool)
547 {
548 PyObject *source = src.ptr();
549 auto const &api = ::dpctl::detail::dpctl_capi::get();
550 if (api.PySyclQueue_Check_(source)) {
551 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
552 reinterpret_cast<PySyclQueueObject *>(source));
553 value = std::make_unique<sycl::queue>(
554 *(reinterpret_cast<sycl::queue *>(QRef)));
555 return true;
556 }
557 else {
558 throw py::type_error(
559 "Input is of unexpected type, expected dpctl.SyclQueue");
560 }
561 }
562
563 static handle cast(sycl::queue src, return_value_policy, handle)
564 {
565 auto const &api = ::dpctl::detail::dpctl_capi::get();
566 auto tmp =
567 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
568 return handle(reinterpret_cast<PyObject *>(tmp));
569 }
570
571 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
572};
573
574/* This type caster associates ``sycl::device`` C++ class with
575 * :class:`dpctl.SyclDevice` for the purposes of generation of
576 * Python bindings by pybind11.
577 */
578template <>
579struct type_caster<sycl::device>
580{
581public:
582 bool load(handle src, bool)
583 {
584 PyObject *source = src.ptr();
585 auto const &api = ::dpctl::detail::dpctl_capi::get();
586 if (api.PySyclDevice_Check_(source)) {
587 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
588 reinterpret_cast<PySyclDeviceObject *>(source));
589 value = std::make_unique<sycl::device>(
590 *(reinterpret_cast<sycl::device *>(DRef)));
591 return true;
592 }
593 else {
594 throw py::type_error(
595 "Input is of unexpected type, expected dpctl.SyclDevice");
596 }
597 }
598
599 static handle cast(sycl::device src, return_value_policy, handle)
600 {
601 auto const &api = ::dpctl::detail::dpctl_capi::get();
602 auto tmp =
603 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
604 return handle(reinterpret_cast<PyObject *>(tmp));
605 }
606
607 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
608};
609
610/* This type caster associates ``sycl::context`` C++ class with
611 * :class:`dpctl.SyclContext` for the purposes of generation of
612 * Python bindings by pybind11.
613 */
614template <>
615struct type_caster<sycl::context>
616{
617public:
618 bool load(handle src, bool)
619 {
620 PyObject *source = src.ptr();
621 auto const &api = ::dpctl::detail::dpctl_capi::get();
622 if (api.PySyclContext_Check_(source)) {
623 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
624 reinterpret_cast<PySyclContextObject *>(source));
625 value = std::make_unique<sycl::context>(
626 *(reinterpret_cast<sycl::context *>(CRef)));
627 return true;
628 }
629 else {
630 throw py::type_error(
631 "Input is of unexpected type, expected dpctl.SyclContext");
632 }
633 }
634
635 static handle cast(sycl::context src, return_value_policy, handle)
636 {
637 auto const &api = ::dpctl::detail::dpctl_capi::get();
638 auto tmp =
639 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
640 return handle(reinterpret_cast<PyObject *>(tmp));
641 }
642
643 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
644};
645
646/* This type caster associates ``sycl::event`` C++ class with
647 * :class:`dpctl.SyclEvent` for the purposes of generation of
648 * Python bindings by pybind11.
649 */
650template <>
651struct type_caster<sycl::event>
652{
653public:
654 bool load(handle src, bool)
655 {
656 PyObject *source = src.ptr();
657 auto const &api = ::dpctl::detail::dpctl_capi::get();
658 if (api.PySyclEvent_Check_(source)) {
659 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
660 reinterpret_cast<PySyclEventObject *>(source));
661 value = std::make_unique<sycl::event>(
662 *(reinterpret_cast<sycl::event *>(ERef)));
663 return true;
664 }
665 else {
666 throw py::type_error(
667 "Input is of unexpected type, expected dpctl.SyclEvent");
668 }
669 }
670
671 static handle cast(sycl::event src, return_value_policy, handle)
672 {
673 auto const &api = ::dpctl::detail::dpctl_capi::get();
674 auto tmp =
675 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
676 return handle(reinterpret_cast<PyObject *>(tmp));
677 }
678
679 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
680};
681
682/* This type caster associates ``sycl::kernel`` C++ class with
683 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
684 * Python bindings by pybind11.
685 */
686template <>
687struct type_caster<sycl::kernel>
688{
689public:
690 bool load(handle src, bool)
691 {
692 PyObject *source = src.ptr();
693 auto const &api = ::dpctl::detail::dpctl_capi::get();
694 if (api.PySyclKernel_Check_(source)) {
695 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
696 reinterpret_cast<PySyclKernelObject *>(source));
697 value = std::make_unique<sycl::kernel>(
698 *(reinterpret_cast<sycl::kernel *>(KRef)));
699 return true;
700 }
701 else {
702 throw py::type_error("Input is of unexpected type, expected "
703 "dpctl.program.SyclKernel");
704 }
705 }
706
707 static handle cast(sycl::kernel src, return_value_policy, handle)
708 {
709 auto const &api = ::dpctl::detail::dpctl_capi::get();
710 auto tmp =
711 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
712 "dpctl4pybind11_kernel");
713 return handle(reinterpret_cast<PyObject *>(tmp));
714 }
715
716 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
717};
718
719/* This type caster associates
720 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
721 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
722 * Python bindings by pybind11.
723 */
724template <>
725struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
726{
727public:
728 bool load(handle src, bool)
729 {
730 PyObject *source = src.ptr();
731 auto const &api = ::dpctl::detail::dpctl_capi::get();
732 if (api.PySyclProgram_Check_(source)) {
733 DPCTLSyclKernelBundleRef KBRef =
734 api.SyclProgram_GetKernelBundleRef_(
735 reinterpret_cast<PySyclProgramObject *>(source));
736 value = std::make_unique<
737 sycl::kernel_bundle<sycl::bundle_state::executable>>(
738 *(reinterpret_cast<
739 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
740 KBRef)));
741 return true;
742 }
743 else {
744 throw py::type_error("Input is of unexpected type, expected "
745 "dpctl.program.SyclProgram");
746 }
747 }
748
749 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
750 return_value_policy,
751 handle)
752 {
753 auto const &api = ::dpctl::detail::dpctl_capi::get();
754 auto tmp = api.SyclProgram_Make_(
755 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
756 return handle(reinterpret_cast<PyObject *>(tmp));
757 }
758
759 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
760 _("dpctl.program.SyclProgram"));
761};
762
763/* This type caster associates
764 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
765 * of generation of Python bindings by pybind11.
766 */
767template <>
768struct type_caster<sycl::half>
769{
770public:
771 bool load(handle src, bool convert)
772 {
773 double py_value;
774
775 if (!src) {
776 return false;
777 }
778
779 PyObject *source = src.ptr();
780
781 if (convert || PyFloat_Check(source)) {
782 py_value = PyFloat_AsDouble(source);
783 }
784 else {
785 return false;
786 }
787
788 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
789
790 if (py_err) {
791 PyErr_Clear();
792 if (convert && (PyNumber_Check(source) != 0)) {
793 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
794 return load(tmp, false);
795 }
796 return false;
797 }
798 value = static_cast<sycl::half>(py_value);
799 return true;
800 }
801
802 static handle cast(sycl::half src, return_value_policy, handle)
803 {
804 return PyFloat_FromDouble(static_cast<double>(src));
805 }
806
807 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
808};
809} // namespace pybind11::detail
810
811namespace dpctl
812{
813namespace memory
814{
815// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
816// this allows to avoid compilation error
817using pybind11::error_already_set;
818
819class usm_memory : public py::object
820{
821public:
822 PYBIND11_OBJECT_CVT(
824 py::object,
825 [](PyObject *o) -> bool {
826 return PyObject_TypeCheck(
827 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
828 0;
829 },
830 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
831
832 usm_memory()
833 : py::object(
834 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
835 borrowed_t{})
836 {
837 if (!m_ptr)
838 throw py::error_already_set();
839 }
840
844 usm_memory(void *usm_ptr,
845 std::size_t nbytes,
846 const sycl::queue &q,
847 std::shared_ptr<void> shptr)
848 {
849 auto const &api = ::dpctl::detail::dpctl_capi::get();
850 DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
851 auto q_uptr = std::make_unique<sycl::queue>(q);
852 DPCTLSyclQueueRef QRef =
853 reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());
854
855 auto vacuous_destructor = []() {};
856 py::capsule mock_owner(vacuous_destructor);
857
858 // create memory object owned by mock_owner, it is a new reference
859 PyObject *_memory =
860 api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
861 auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };
862
863 using py_uptrT =
864 std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;
865
866 if (!_memory) {
867 throw py::error_already_set();
868 }
869
870 auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
871 std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);
872
873 Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
874 // replace mock_owner capsule as the owner
875 memobj->refobj = Py_None;
876 // set opaque ptr field, usm_memory now knowns that USM is managed
877 // by smart pointer
878 memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);
879
880 // _memory will delete created copies of sycl::queue, and
881 // std::shared_ptr and the deleter of the shared_ptr<void> is
882 // supposed to free the USM allocation
883 m_ptr = _memory;
884 q_uptr.release();
885 memory_uptr.release();
886 }
887
888 sycl::queue get_queue() const
889 {
890 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
891 auto const &api = ::dpctl::detail::dpctl_capi::get();
892 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
893 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
894 return *obj_q;
895 }
896
897 char *get_pointer() const
898 {
899 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
900 auto const &api = ::dpctl::detail::dpctl_capi::get();
901 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
902 return reinterpret_cast<char *>(MRef);
903 }
904
905 std::size_t get_nbytes() const
906 {
907 auto const &api = ::dpctl::detail::dpctl_capi::get();
908 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
909 return api.Memory_GetNumBytes_(mem_obj);
910 }
911
912 bool is_managed_by_smart_ptr() const
913 {
914 auto const &api = ::dpctl::detail::dpctl_capi::get();
915 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
916 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
917
918 return bool(opaque_ptr);
919 }
920
921 const std::shared_ptr<void> &get_smart_ptr_owner() const
922 {
923 auto const &api = ::dpctl::detail::dpctl_capi::get();
924 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
925 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
926
927 if (opaque_ptr) {
928 auto shptr_ptr =
929 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
930 return *shptr_ptr;
931 }
932 else {
933 throw std::runtime_error(
934 "Memory object does not have smart pointer "
935 "managing lifetime of USM allocation");
936 }
937 }
938
939protected:
940 static PyObject *as_usm_memory(PyObject *o)
941 {
942 if (o == nullptr) {
943 PyErr_SetString(PyExc_ValueError,
944 "cannot create a usm_memory from a nullptr");
945 return nullptr;
946 }
947
948 auto converter =
949 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
950
951 py::object res;
952 try {
953 res = converter(py::handle(o));
954 } catch (const py::error_already_set &e) {
955 return nullptr;
956 }
957 return res.ptr();
958 }
959};
960} // end namespace memory
961
962namespace tensor
963{
964inline std::vector<py::ssize_t>
965 c_contiguous_strides(int nd,
966 const py::ssize_t *shape,
967 py::ssize_t element_size = 1)
968{
969 if (nd > 0) {
970 std::vector<py::ssize_t> c_strides(nd, element_size);
971 for (int ic = nd - 1; ic > 0;) {
972 py::ssize_t next_v = c_strides[ic] * shape[ic];
973 c_strides[--ic] = next_v;
974 }
975 return c_strides;
976 }
977 else {
978 return std::vector<py::ssize_t>();
979 }
980}
981
982inline std::vector<py::ssize_t>
983 f_contiguous_strides(int nd,
984 const py::ssize_t *shape,
985 py::ssize_t element_size = 1)
986{
987 if (nd > 0) {
988 std::vector<py::ssize_t> f_strides(nd, element_size);
989 for (int i = 0; i < nd - 1;) {
990 py::ssize_t next_v = f_strides[i] * shape[i];
991 f_strides[++i] = next_v;
992 }
993 return f_strides;
994 }
995 else {
996 return std::vector<py::ssize_t>();
997 }
998}
999
1000inline std::vector<py::ssize_t>
1001 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
1002 py::ssize_t element_size = 1)
1003{
1004 return c_contiguous_strides(shape.size(), shape.data(), element_size);
1005}
1006
1007inline std::vector<py::ssize_t>
1008 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
1009 py::ssize_t element_size = 1)
1010{
1011 return f_contiguous_strides(shape.size(), shape.data(), element_size);
1012}
1013
1014class usm_ndarray : public py::object
1015{
1016public:
1017 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
1018 return PyObject_TypeCheck(
1019 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
1020 })
1021
1022 usm_ndarray()
1023 : py::object(
1024 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
1025 borrowed_t{})
1026 {
1027 if (!m_ptr)
1028 throw py::error_already_set();
1029 }
1030
1031 char *get_data() const
1032 {
1033 PyUSMArrayObject *raw_ar = usm_array_ptr();
1034
1035 auto const &api = ::dpctl::detail::dpctl_capi::get();
1036 return api.UsmNDArray_GetData_(raw_ar);
1037 }
1038
1039 template <typename T>
1040 T *get_data() const
1041 {
1042 return reinterpret_cast<T *>(get_data());
1043 }
1044
1045 int get_ndim() const
1046 {
1047 PyUSMArrayObject *raw_ar = usm_array_ptr();
1048
1049 auto const &api = ::dpctl::detail::dpctl_capi::get();
1050 return api.UsmNDArray_GetNDim_(raw_ar);
1051 }
1052
1053 const py::ssize_t *get_shape_raw() const
1054 {
1055 PyUSMArrayObject *raw_ar = usm_array_ptr();
1056
1057 auto const &api = ::dpctl::detail::dpctl_capi::get();
1058 return api.UsmNDArray_GetShape_(raw_ar);
1059 }
1060
1061 std::vector<py::ssize_t> get_shape_vector() const
1062 {
1063 auto raw_sh = get_shape_raw();
1064 auto nd = get_ndim();
1065
1066 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
1067 return shape_vector;
1068 }
1069
1070 py::ssize_t get_shape(int i) const
1071 {
1072 auto shape_ptr = get_shape_raw();
1073 return shape_ptr[i];
1074 }
1075
1076 const py::ssize_t *get_strides_raw() const
1077 {
1078 PyUSMArrayObject *raw_ar = usm_array_ptr();
1079
1080 auto const &api = ::dpctl::detail::dpctl_capi::get();
1081 return api.UsmNDArray_GetStrides_(raw_ar);
1082 }
1083
1084 std::vector<py::ssize_t> get_strides_vector() const
1085 {
1086 auto raw_st = get_strides_raw();
1087 auto nd = get_ndim();
1088
1089 if (raw_st == nullptr) {
1090 auto is_c_contig = is_c_contiguous();
1091 auto is_f_contig = is_f_contiguous();
1092 auto raw_sh = get_shape_raw();
1093 if (is_c_contig) {
1094 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
1095 return contig_strides;
1096 }
1097 else if (is_f_contig) {
1098 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
1099 return contig_strides;
1100 }
1101 else {
1102 throw std::runtime_error("Invalid array encountered when "
1103 "building strides");
1104 }
1105 }
1106 else {
1107 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
1108 return st_vec;
1109 }
1110 }
1111
1112 py::ssize_t get_size() const
1113 {
1114 PyUSMArrayObject *raw_ar = usm_array_ptr();
1115
1116 auto const &api = ::dpctl::detail::dpctl_capi::get();
1117 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
1118 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1119
1120 py::ssize_t nelems = 1;
1121 for (int i = 0; i < ndim; ++i) {
1122 nelems *= shape[i];
1123 }
1124
1125 assert(nelems >= 0);
1126 return nelems;
1127 }
1128
1129 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
1130 {
1131 PyUSMArrayObject *raw_ar = usm_array_ptr();
1132
1133 auto const &api = ::dpctl::detail::dpctl_capi::get();
1134 int nd = api.UsmNDArray_GetNDim_(raw_ar);
1135 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1136 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
1137
1138 py::ssize_t offset_min = 0;
1139 py::ssize_t offset_max = 0;
1140 if (strides == nullptr) {
1141 py::ssize_t stride(1);
1142 for (int i = 0; i < nd; ++i) {
1143 offset_max += stride * (shape[i] - 1);
1144 stride *= shape[i];
1145 }
1146 }
1147 else {
1148 for (int i = 0; i < nd; ++i) {
1149 py::ssize_t delta = strides[i] * (shape[i] - 1);
1150 if (strides[i] > 0) {
1151 offset_max += delta;
1152 }
1153 else {
1154 offset_min += delta;
1155 }
1156 }
1157 }
1158 return std::make_pair(offset_min, offset_max);
1159 }
1160
1161 sycl::queue get_queue() const
1162 {
1163 PyUSMArrayObject *raw_ar = usm_array_ptr();
1164
1165 auto const &api = ::dpctl::detail::dpctl_capi::get();
1166 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1167 return *(reinterpret_cast<sycl::queue *>(QRef));
1168 }
1169
1170 sycl::device get_device() const
1171 {
1172 PyUSMArrayObject *raw_ar = usm_array_ptr();
1173
1174 auto const &api = ::dpctl::detail::dpctl_capi::get();
1175 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1176 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1177 }
1178
1179 int get_typenum() const
1180 {
1181 PyUSMArrayObject *raw_ar = usm_array_ptr();
1182
1183 auto const &api = ::dpctl::detail::dpctl_capi::get();
1184 return api.UsmNDArray_GetTypenum_(raw_ar);
1185 }
1186
1187 int get_flags() const
1188 {
1189 PyUSMArrayObject *raw_ar = usm_array_ptr();
1190
1191 auto const &api = ::dpctl::detail::dpctl_capi::get();
1192 return api.UsmNDArray_GetFlags_(raw_ar);
1193 }
1194
1195 int get_elemsize() const
1196 {
1197 PyUSMArrayObject *raw_ar = usm_array_ptr();
1198
1199 auto const &api = ::dpctl::detail::dpctl_capi::get();
1200 return api.UsmNDArray_GetElementSize_(raw_ar);
1201 }
1202
1203 bool is_c_contiguous() const
1204 {
1205 int flags = get_flags();
1206 auto const &api = ::dpctl::detail::dpctl_capi::get();
1207 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1208 }
1209
1210 bool is_f_contiguous() const
1211 {
1212 int flags = get_flags();
1213 auto const &api = ::dpctl::detail::dpctl_capi::get();
1214 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1215 }
1216
1217 bool is_writable() const
1218 {
1219 int flags = get_flags();
1220 auto const &api = ::dpctl::detail::dpctl_capi::get();
1221 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1222 }
1223
1225 py::object get_usm_data() const
1226 {
1227 PyUSMArrayObject *raw_ar = usm_array_ptr();
1228
1229 auto const &api = ::dpctl::detail::dpctl_capi::get();
1230 // UsmNDArray_GetUSMData_ gives a new reference
1231 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1232
1233 // pass reference ownership to py::object
1234 return py::reinterpret_steal<py::object>(usm_data);
1235 }
1236
1237 bool is_managed_by_smart_ptr() const
1238 {
1239 PyUSMArrayObject *raw_ar = usm_array_ptr();
1240
1241 auto const &api = ::dpctl::detail::dpctl_capi::get();
1242 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1243
1244 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1245 Py_DECREF(usm_data);
1246 return false;
1247 }
1248
1249 Py_MemoryObject *mem_obj =
1250 reinterpret_cast<Py_MemoryObject *>(usm_data);
1251 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1252
1253 Py_DECREF(usm_data);
1254 return bool(opaque_ptr);
1255 }
1256
1257 const std::shared_ptr<void> &get_smart_ptr_owner() const
1258 {
1259 PyUSMArrayObject *raw_ar = usm_array_ptr();
1260
1261 auto const &api = ::dpctl::detail::dpctl_capi::get();
1262
1263 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1264
1265 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1266 Py_DECREF(usm_data);
1267 throw std::runtime_error(
1268 "usm_ndarray object does not have Memory object "
1269 "managing lifetime of USM allocation");
1270 }
1271
1272 Py_MemoryObject *mem_obj =
1273 reinterpret_cast<Py_MemoryObject *>(usm_data);
1274 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1275 Py_DECREF(usm_data);
1276
1277 if (opaque_ptr) {
1278 auto shptr_ptr =
1279 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1280 return *shptr_ptr;
1281 }
1282 else {
1283 throw std::runtime_error(
1284 "Memory object underlying usm_ndarray does not have "
1285 "smart pointer managing lifetime of USM allocation");
1286 }
1287 }
1288
1289private:
1290 PyUSMArrayObject *usm_array_ptr() const
1291 {
1292 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1293 }
1294};
1295} // end namespace tensor
1296
1297namespace utils
1298{
1299namespace detail
1300{
1302{
1303
1304 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1305 {
1306 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1307 const auto &usm_memory_inst =
1308 py::cast<dpctl::memory::usm_memory>(h);
1309 return usm_memory_inst.is_managed_by_smart_ptr();
1310 }
1311 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1312 const auto &usm_array_inst =
1313 py::cast<dpctl::tensor::usm_ndarray>(h);
1314 return usm_array_inst.is_managed_by_smart_ptr();
1315 }
1316
1317 return false;
1318 }
1319
1320 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1321 {
1322 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1323 const auto &usm_memory_inst =
1324 py::cast<dpctl::memory::usm_memory>(h);
1325 return usm_memory_inst.get_smart_ptr_owner();
1326 }
1327 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1328 const auto &usm_array_inst =
1329 py::cast<dpctl::tensor::usm_ndarray>(h);
1330 return usm_array_inst.get_smart_ptr_owner();
1331 }
1332
1333 throw std::runtime_error(
1334 "Attempted extraction of shared_ptr on an unrecognized type");
1335 }
1336};
1337} // end of namespace detail
1338
1339template <std::size_t num>
1340sycl::event keep_args_alive(sycl::queue &q,
1341 const py::object (&py_objs)[num],
1342 const std::vector<sycl::event> &depends = {})
1343{
1344 std::size_t n_objects_held = 0;
1345 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1346
1347 std::size_t n_usm_owners_held = 0;
1348 std::array<std::shared_ptr<void>, num> shp_usm{};
1349
1350 for (std::size_t i = 0; i < num; ++i) {
1351 const auto &py_obj_i = py_objs[i];
1352 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1353 const auto &shp =
1354 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1355 shp_usm[n_usm_owners_held] = shp;
1356 ++n_usm_owners_held;
1357 }
1358 else {
1359 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1360 shp_arr[n_objects_held]->inc_ref();
1361 ++n_objects_held;
1362 }
1363 }
1364
1365 bool use_depends = true;
1366 sycl::event host_task_ev;
1367
1368 if (n_usm_owners_held > 0) {
1369 host_task_ev = q.submit([&](sycl::handler &cgh) {
1370 if (use_depends) {
1371 cgh.depends_on(depends);
1372 use_depends = false;
1373 }
1374 else {
1375 cgh.depends_on(host_task_ev);
1376 }
1377 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1378 // no body, but shared pointers are captured in
1379 // the lambda, ensuring that USM allocation is
1380 // kept alive
1381 });
1382 });
1383 }
1384
1385 if (n_objects_held > 0) {
1386 host_task_ev = q.submit([&](sycl::handler &cgh) {
1387 if (use_depends) {
1388 cgh.depends_on(depends);
1389 use_depends = false;
1390 }
1391 else {
1392 cgh.depends_on(host_task_ev);
1393 }
1394 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1395 py::gil_scoped_acquire acquire;
1396
1397 for (std::size_t i = 0; i < n_objects_held; ++i) {
1398 shp_arr[i]->dec_ref();
1399 }
1400 });
1401 });
1402 }
1403
1404 return host_task_ev;
1405}
1406
1409template <std::size_t num>
1410bool queues_are_compatible(const sycl::queue &exec_q,
1411 const sycl::queue (&alloc_qs)[num])
1412{
1413 for (std::size_t i = 0; i < num; ++i) {
1414
1415 if (exec_q != alloc_qs[i]) {
1416 return false;
1417 }
1418 }
1419 return true;
1420}
1421
1424template <std::size_t num>
1425bool queues_are_compatible(const sycl::queue &exec_q,
1426 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1427{
1428 for (std::size_t i = 0; i < num; ++i) {
1429
1430 if (exec_q != arrs[i].get_queue()) {
1431 return false;
1432 }
1433 }
1434 return true;
1435}
1436} // end namespace utils
1437} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.