DPNP C++ backend kernel library 0.20.0dev3
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
// TODO: Enable dpctl_capi.h once dpctl.tensor is removed.
// Also call `import_dpctl_ext__tensor___usmarray();` right after the
// `import_dpctl()` call in the `dpctl_capi` constructor to initialize the
// dpctl_ext tensor C-API.
//
// For now, the dpctl C-API headers are included explicitly in order to
// integrate the dpctl_ext tensor C-API.
37
38// #include "dpctl_capi.h"
39
40// clang-format off
41// Ordering of includes is important here. dpctl_sycl_types and
42// dpctl_sycl_extension_interface define types used by dpctl's Python
43// C-API headers.
44#include "syclinterface/dpctl_sycl_types.h"
45#include "syclinterface/dpctl_sycl_extension_interface.h"
46#ifdef __cplusplus
47#define CYTHON_EXTERN_C extern "C"
48#else
49#define CYTHON_EXTERN_C
50#endif
51#include "dpctl/_sycl_device.h"
52#include "dpctl/_sycl_device_api.h"
53#include "dpctl/_sycl_context.h"
54#include "dpctl/_sycl_context_api.h"
55#include "dpctl/_sycl_event.h"
56#include "dpctl/_sycl_event_api.h"
57#include "dpctl/_sycl_queue.h"
58#include "dpctl/_sycl_queue_api.h"
59#include "dpctl/memory/_memory.h"
60#include "dpctl/memory/_memory_api.h"
61#include "dpctl/program/_program.h"
62#include "dpctl/program/_program_api.h"
63
64// clang-format on
65
// TODO: Keep these includes once `dpctl.tensor` is removed from dpctl,
// but replace the hardcoded relative path with a proper include path.
68#include <dpctl_ext/tensor/_usmarray.h>
69#include <dpctl_ext/tensor/_usmarray_api.h>
70
71/*
72 * Function to import dpctl and make C-API functions available.
73 * C functions can use dpctl's C-API functions without linking to
74 * shared objects defining this symbols, if they call `import_dpctl()`
75 * prior to using those symbols.
76 *
77 * It is declared inline to allow multiple definitions in
78 * different translation units
79 */
80static inline void import_dpctl(void)
81{
82 import_dpctl___sycl_device();
83 import_dpctl___sycl_context();
84 import_dpctl___sycl_event();
85 import_dpctl___sycl_queue();
86 import_dpctl__memory___memory();
87 import_dpctl_ext__tensor___usmarray();
88 import_dpctl__program___program();
89 return;
90}
91
92#include <complex>
93#include <cstddef> // for std::size_t for C++ linkage
94#include <memory>
95#include <stddef.h> // for size_t for C linkage
96#include <stdexcept>
97#include <utility>
98#include <vector>
99
100#include <pybind11/pybind11.h>
101
102#include <sycl/sycl.hpp>
103
104namespace py = pybind11;
105
106namespace dpctl
107{
108namespace detail
109{
// Looks up a type by its size and returns the value paired with the first
// candidate type of matching size (used for NumPy typenum deduction).
//
// Terminal case: no candidate types are left, so the lookup yields -1.
template <typename Concrete>
constexpr int platform_typeid_lookup()
{
    return -1;
}

// Recursive case: `T` is the current candidate and `I` the value paired with
// it; `Ts...` / `Is...` are the remaining candidates and their values.
// Returns `I` when sizeof(Concrete) equals sizeof(T), otherwise continues
// with the remaining candidates.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_typeid_lookup(int I, Ints... Is)
{
    if (sizeof(Concrete) == sizeof(T)) {
        return I;
    }
    return platform_typeid_lookup<Concrete, Ts...>(Is...);
}
125
127{
128public:
// dpctl type objects
// Pointers to the Python type objects of dpctl classes. They start as
// nullptr and are assigned in the constructor after the C-API import.
PyTypeObject *Py_SyclDeviceType_;
PyTypeObject *PySyclDeviceType_;
PyTypeObject *Py_SyclContextType_;
PyTypeObject *PySyclContextType_;
PyTypeObject *Py_SyclEventType_;
PyTypeObject *PySyclEventType_;
PyTypeObject *Py_SyclQueueType_;
PyTypeObject *PySyclQueueType_;
PyTypeObject *Py_MemoryType_;
PyTypeObject *PyMemoryUSMDeviceType_;
PyTypeObject *PyMemoryUSMSharedType_;
PyTypeObject *PyMemoryUSMHostType_;
PyTypeObject *PyUSMArrayType_;
PyTypeObject *PySyclProgramType_;
PyTypeObject *PySyclKernelType_;

// Function pointers into the dpctl C-API. Each `*_Get*Ref_` pointer takes a
// Python wrapper object and returns the opaque reference it holds; each
// `*_Make_` pointer takes an opaque reference and returns a Python wrapper.

// device
DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);

// context
DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);

// event
DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);

// queue
DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);

// memory
DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
                          size_t,
                          DPCTLSyclQueueRef,
                          PyObject *);

// program
DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);

DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
    PySyclProgramObject *);
PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);

// tensor
// Accessors and factory functions for usm_ndarray objects.
char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
                                              const py::ssize_t *,
                                              int,
                                              Py_MemoryObject *,
                                              py::ssize_t,
                                              char);
PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
                                           int,
                                           DPCTLSyclUSMRef,
                                           DPCTLSyclQueueRef,
                                           PyObject *);
PyObject *(*UsmNDArray_MakeFromPtr_)(int,
                                     const py::ssize_t *,
                                     int,
                                     const py::ssize_t *,
                                     DPCTLSyclUSMRef,
                                     DPCTLSyclQueueRef,
                                     py::ssize_t,
                                     PyObject *);

// Constants exported by the usm_ndarray C-API. The flag values are
// initialized to 0 and the typenum values to -1 before the constructor
// copies the exported values into them.
int USM_ARRAY_C_CONTIGUOUS_;
int USM_ARRAY_F_CONTIGUOUS_;
int USM_ARRAY_WRITABLE_;
int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
    UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
    UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
    UAR_HALF_;
// Typenums for fixed-width integer types; the 32- and 64-bit ones are
// deduced in the constructor by matching sizes against platform types.
int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
    UAR_INT64_, UAR_UINT64_;
218
219 bool PySyclDevice_Check_(PyObject *obj) const
220 {
221 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
222 }
223 bool PySyclContext_Check_(PyObject *obj) const
224 {
225 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
226 }
227 bool PySyclEvent_Check_(PyObject *obj) const
228 {
229 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
230 }
231 bool PySyclQueue_Check_(PyObject *obj) const
232 {
233 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
234 }
235 bool PySyclKernel_Check_(PyObject *obj) const
236 {
237 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
238 }
239 bool PySyclProgram_Check_(PyObject *obj) const
240 {
241 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
242 }
243
245 {
246 as_usm_memory_.reset();
247 default_usm_ndarray_.reset();
248 default_usm_memory_.reset();
249 default_sycl_queue_.reset();
250 };
251
252 static auto &get()
253 {
254 static dpctl_capi api{};
255 return api;
256 }
257
258 py::object default_sycl_queue_pyobj()
259 {
260 return *default_sycl_queue_;
261 }
262 py::object default_usm_memory_pyobj()
263 {
264 return *default_usm_memory_;
265 }
266 py::object default_usm_ndarray_pyobj()
267 {
268 return *default_usm_ndarray_;
269 }
270 py::object as_usm_memory_pyobj()
271 {
272 return *as_usm_memory_;
273 }
274
275private:
276 struct Deleter
277 {
278 void operator()(py::object *p) const
279 {
280 const bool initialized = Py_IsInitialized();
281#if PY_VERSION_HEX < 0x30d0000
282 const bool finalizing = _Py_IsFinalizing();
283#else
284 const bool finalizing = Py_IsFinalizing();
285#endif
286 const bool guard = initialized && !finalizing;
287
288 if (guard) {
289 delete p;
290 }
291 }
292 };
293
// Python objects created once in the constructor and handed out by the
// `*_pyobj()` accessors. They are released through `Deleter`.
std::shared_ptr<py::object> default_sycl_queue_;
std::shared_ptr<py::object> default_usm_memory_;
std::shared_ptr<py::object> default_usm_ndarray_;
std::shared_ptr<py::object> as_usm_memory_;
298
// Builds the C-API table. All members are first set to null / 0 / -1, then
// the constructor:
//   1. imports the dpctl C-API via import_dpctl();
//   2. copies type-object addresses, function pointers and constants into
//      the members;
//   3. deduces the typenums of the fixed-width integer types;
//   4. creates the default Python objects (queue, USM memory, usm_ndarray)
//      and looks up `dpctl.memory.as_usm_memory`.
dpctl_capi()
    : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
      Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
      Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
      Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
      Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
      PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
      PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
      PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
      SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
      SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
      SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
      SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
      Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
      Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
      Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
      SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
      SyclProgram_Make_(nullptr), UsmNDArray_GetData_(nullptr),
      UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
      UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
      UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
      UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
      UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
      UsmNDArray_MakeSimpleFromMemory_(nullptr),
      UsmNDArray_MakeSimpleFromPtr_(nullptr),
      UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
      USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
      UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
      UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
      UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
      UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
      UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
      UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
      UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
      default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}

{
    // Import Cython-generated C-API for dpctl
    // This imports python modules and initializes
    // static variables such as function pointers for C-API,
    // e.g. SyclDevice_GetDeviceRef, etc.
    // pointers to Python types, i.e. PySyclDeviceType, etc.
    // and exported constants, i.e. USM_ARRAY_C_CONTIGUOUS, etc.
    import_dpctl();

    // Python type objects for classes implemented by dpctl
    this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
    this->PySyclDeviceType_ = &PySyclDeviceType;
    this->Py_SyclContextType_ = &Py_SyclContextType;
    this->PySyclContextType_ = &PySyclContextType;
    this->Py_SyclEventType_ = &Py_SyclEventType;
    this->PySyclEventType_ = &PySyclEventType;
    this->Py_SyclQueueType_ = &Py_SyclQueueType;
    this->PySyclQueueType_ = &PySyclQueueType;
    this->Py_MemoryType_ = &Py_MemoryType;
    this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
    this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
    this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
    this->PyUSMArrayType_ = &PyUSMArrayType;
    this->PySyclProgramType_ = &PySyclProgramType;
    this->PySyclKernelType_ = &PySyclKernelType;

    // SyclDevice API
    this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
    this->SyclDevice_Make_ = SyclDevice_Make;

    // SyclContext API
    this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
    this->SyclContext_Make_ = SyclContext_Make;

    // SyclEvent API
    this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
    this->SyclEvent_Make_ = SyclEvent_Make;

    // SyclQueue API
    this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
    this->SyclQueue_Make_ = SyclQueue_Make;

    // dpctl.memory API
    this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
    this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
    this->Memory_GetContextRef_ = Memory_GetContextRef;
    this->Memory_GetQueueRef_ = Memory_GetQueueRef;
    this->Memory_GetNumBytes_ = Memory_GetNumBytes;
    this->Memory_Make_ = Memory_Make;

    // dpctl.program API
    this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
    this->SyclKernel_Make_ = SyclKernel_Make;
    this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
    this->SyclProgram_Make_ = SyclProgram_Make;

    // dpctl.tensor.usm_ndarray API
    this->UsmNDArray_GetData_ = UsmNDArray_GetData;
    this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
    this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
    this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
    this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
    this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
    this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
    this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
    this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
    this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
    this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
    this->UsmNDArray_MakeSimpleFromMemory_ =
        UsmNDArray_MakeSimpleFromMemory;
    this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
    this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;

    // constants
    this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
    this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
    this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
    this->UAR_BOOL_ = UAR_BOOL;
    this->UAR_BYTE_ = UAR_BYTE;
    this->UAR_UBYTE_ = UAR_UBYTE;
    this->UAR_SHORT_ = UAR_SHORT;
    this->UAR_USHORT_ = UAR_USHORT;
    this->UAR_INT_ = UAR_INT;
    this->UAR_UINT_ = UAR_UINT;
    this->UAR_LONG_ = UAR_LONG;
    this->UAR_ULONG_ = UAR_ULONG;
    this->UAR_LONGLONG_ = UAR_LONGLONG;
    this->UAR_ULONGLONG_ = UAR_ULONGLONG;
    this->UAR_FLOAT_ = UAR_FLOAT;
    this->UAR_DOUBLE_ = UAR_DOUBLE;
    this->UAR_CFLOAT_ = UAR_CFLOAT;
    this->UAR_CDOUBLE_ = UAR_CDOUBLE;
    this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
    this->UAR_HALF_ = UAR_HALF;

    // deduced disjoint types
    // The 8- and 16-bit typenums are taken directly from the byte/short
    // constants. For 32 and 64 bits, platform_typeid_lookup picks the
    // typenum of the first listed platform type whose sizeof matches the
    // fixed-width type.
    this->UAR_INT8_ = UAR_BYTE;
    this->UAR_UINT8_ = UAR_UBYTE;
    this->UAR_INT16_ = UAR_SHORT;
    this->UAR_UINT16_ = UAR_USHORT;
    this->UAR_INT32_ =
        platform_typeid_lookup<std::int32_t, long, int, short>(
            UAR_LONG, UAR_INT, UAR_SHORT);
    this->UAR_UINT32_ =
        platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
                               unsigned short>(UAR_ULONG, UAR_UINT,
                                               UAR_USHORT);
    this->UAR_INT64_ =
        platform_typeid_lookup<std::int64_t, long, long long, int>(
            UAR_LONG, UAR_LONGLONG, UAR_INT);
    this->UAR_UINT64_ =
        platform_typeid_lookup<std::uint64_t, unsigned long,
                               unsigned long long, unsigned int>(
            UAR_ULONG, UAR_ULONGLONG, UAR_UINT);

    // create shared pointers to python objects used in type-casters
    // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
    // A default-constructed queue is wrapped into a Python object, and
    // reinterpret_steal adopts the returned pointer without adding a
    // reference.
    sycl::queue q_{};
    PySyclQueueObject *py_q_tmp =
        SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
    const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
        reinterpret_cast<PyObject *>(py_q_tmp));

    default_sycl_queue_ = std::shared_ptr<py::object>(
        new py::object(py_sycl_queue), Deleter{});

    // Keep a handle to the Python callable `dpctl.memory.as_usm_memory`.
    py::module_ mod_memory = py::module_::import("dpctl.memory");
    const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
    as_usm_memory_ = std::shared_ptr<py::object>(
        new py::object{py_as_usm_memory}, Deleter{});

    // Default memory object: MemoryUSMHost(1, queue=<default queue>);
    // presumably a one-byte host allocation - confirm against dpctl.memory.
    auto mem_kl = mod_memory.attr("MemoryUSMHost");
    const py::object &py_default_usm_memory =
        mem_kl(1, py::arg("queue") = py_sycl_queue);
    default_usm_memory_ = std::shared_ptr<py::object>(
        new py::object{py_default_usm_memory}, Deleter{});

    // TODO: revert to `py::module_::import("dpctl.tensor._usmarray");`
    // when dpnp fully migrates dpctl/tensor
    py::module_ mod_usmarray =
        py::module_::import("dpctl_ext.tensor._usmarray");
    auto tensor_kl = mod_usmarray.attr("usm_ndarray");

    // Default array: usm_ndarray((), dtype="u1", buffer=<default memory>).
    const py::object &py_default_usm_ndarray =
        tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
                  py::arg("buffer") = py_default_usm_memory);

    default_usm_ndarray_ = std::shared_ptr<py::object>(
        new py::object{py_default_usm_ndarray}, Deleter{});
}
485
// Copy construction and copy/move assignment are defaulted. They are
// declared in the private section, so only the class itself can use them.
dpctl_capi(dpctl_capi const &) = default;
dpctl_capi &operator=(dpctl_capi const &) = default;
dpctl_capi &operator=(dpctl_capi &&) = default;
489
490}; // struct dpctl_capi
491} // namespace detail
492} // namespace dpctl
493
494namespace pybind11::detail
495{
/* Declares the members shared by the dpctl type casters:
 *   - `value`, a std::unique_ptr<type> holding the loaded C++ object;
 *   - `name`, the Python-facing type name `py_name`;
 *   - a pointer overload of `cast` that returns None for a null pointer,
 *     moves from and deletes `src` under the take_ownership policy, and
 *     otherwise forwards `*src` to the caster's own `cast`;
 *   - conversion operators to `type *`, `type &` and `type &&`;
 *   - the `cast_op_type` alias.
 * The macro leaves the class in the `public:` section and expects the
 * invocation to supply the terminating semicolon.
 */
#define DPCTL_TYPE_CASTER(type, py_name) \
protected: \
    std::unique_ptr<type> value; \
    \
public: \
    static constexpr auto name = py_name; \
    template < \
        typename T_, \
        ::pybind11::detail::enable_if_t< \
            std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
            int> = 0> \
    static ::pybind11::handle cast(T_ *src, \
                                   ::pybind11::return_value_policy policy, \
                                   ::pybind11::handle parent) \
    { \
        if (!src) \
            return ::pybind11::none().release(); \
        if (policy == ::pybind11::return_value_policy::take_ownership) { \
            auto h = cast(std::move(*src), policy, parent); \
            delete src; \
            return h; \
        } \
        return cast(*src, policy, parent); \
    } \
    operator type *() \
    { \
        return value.get(); \
    } /* NOLINT(bugprone-macro-parentheses) */ \
    operator type &() \
    { \
        return *value; \
    } /* NOLINT(bugprone-macro-parentheses) */ \
    operator type &&() && \
    { \
        return std::move(*value); \
    } /* NOLINT(bugprone-macro-parentheses) */ \
    template <typename T_> \
    using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
534
535/* This type caster associates ``sycl::queue`` C++ class with
536 * :class:`dpctl.SyclQueue` for the purposes of generation of
537 * Python bindings by pybind11.
538 */
539template <>
540struct type_caster<sycl::queue>
541{
542public:
543 bool load(handle src, bool)
544 {
545 PyObject *source = src.ptr();
546 auto const &api = ::dpctl::detail::dpctl_capi::get();
547 if (api.PySyclQueue_Check_(source)) {
548 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
549 reinterpret_cast<PySyclQueueObject *>(source));
550 value = std::make_unique<sycl::queue>(
551 *(reinterpret_cast<sycl::queue *>(QRef)));
552 return true;
553 }
554 else {
555 throw py::type_error(
556 "Input is of unexpected type, expected dpctl.SyclQueue");
557 }
558 }
559
560 static handle cast(sycl::queue src, return_value_policy, handle)
561 {
562 auto const &api = ::dpctl::detail::dpctl_capi::get();
563 auto tmp =
564 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
565 return handle(reinterpret_cast<PyObject *>(tmp));
566 }
567
568 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
569};
570
571/* This type caster associates ``sycl::device`` C++ class with
572 * :class:`dpctl.SyclDevice` for the purposes of generation of
573 * Python bindings by pybind11.
574 */
575template <>
576struct type_caster<sycl::device>
577{
578public:
579 bool load(handle src, bool)
580 {
581 PyObject *source = src.ptr();
582 auto const &api = ::dpctl::detail::dpctl_capi::get();
583 if (api.PySyclDevice_Check_(source)) {
584 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
585 reinterpret_cast<PySyclDeviceObject *>(source));
586 value = std::make_unique<sycl::device>(
587 *(reinterpret_cast<sycl::device *>(DRef)));
588 return true;
589 }
590 else {
591 throw py::type_error(
592 "Input is of unexpected type, expected dpctl.SyclDevice");
593 }
594 }
595
596 static handle cast(sycl::device src, return_value_policy, handle)
597 {
598 auto const &api = ::dpctl::detail::dpctl_capi::get();
599 auto tmp =
600 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
601 return handle(reinterpret_cast<PyObject *>(tmp));
602 }
603
604 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
605};
606
607/* This type caster associates ``sycl::context`` C++ class with
608 * :class:`dpctl.SyclContext` for the purposes of generation of
609 * Python bindings by pybind11.
610 */
611template <>
612struct type_caster<sycl::context>
613{
614public:
615 bool load(handle src, bool)
616 {
617 PyObject *source = src.ptr();
618 auto const &api = ::dpctl::detail::dpctl_capi::get();
619 if (api.PySyclContext_Check_(source)) {
620 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
621 reinterpret_cast<PySyclContextObject *>(source));
622 value = std::make_unique<sycl::context>(
623 *(reinterpret_cast<sycl::context *>(CRef)));
624 return true;
625 }
626 else {
627 throw py::type_error(
628 "Input is of unexpected type, expected dpctl.SyclContext");
629 }
630 }
631
632 static handle cast(sycl::context src, return_value_policy, handle)
633 {
634 auto const &api = ::dpctl::detail::dpctl_capi::get();
635 auto tmp =
636 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
637 return handle(reinterpret_cast<PyObject *>(tmp));
638 }
639
640 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
641};
642
643/* This type caster associates ``sycl::event`` C++ class with
644 * :class:`dpctl.SyclEvent` for the purposes of generation of
645 * Python bindings by pybind11.
646 */
647template <>
648struct type_caster<sycl::event>
649{
650public:
651 bool load(handle src, bool)
652 {
653 PyObject *source = src.ptr();
654 auto const &api = ::dpctl::detail::dpctl_capi::get();
655 if (api.PySyclEvent_Check_(source)) {
656 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
657 reinterpret_cast<PySyclEventObject *>(source));
658 value = std::make_unique<sycl::event>(
659 *(reinterpret_cast<sycl::event *>(ERef)));
660 return true;
661 }
662 else {
663 throw py::type_error(
664 "Input is of unexpected type, expected dpctl.SyclEvent");
665 }
666 }
667
668 static handle cast(sycl::event src, return_value_policy, handle)
669 {
670 auto const &api = ::dpctl::detail::dpctl_capi::get();
671 auto tmp =
672 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
673 return handle(reinterpret_cast<PyObject *>(tmp));
674 }
675
676 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
677};
678
679/* This type caster associates ``sycl::kernel`` C++ class with
680 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
681 * Python bindings by pybind11.
682 */
683template <>
684struct type_caster<sycl::kernel>
685{
686public:
687 bool load(handle src, bool)
688 {
689 PyObject *source = src.ptr();
690 auto const &api = ::dpctl::detail::dpctl_capi::get();
691 if (api.PySyclKernel_Check_(source)) {
692 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
693 reinterpret_cast<PySyclKernelObject *>(source));
694 value = std::make_unique<sycl::kernel>(
695 *(reinterpret_cast<sycl::kernel *>(KRef)));
696 return true;
697 }
698 else {
699 throw py::type_error("Input is of unexpected type, expected "
700 "dpctl.program.SyclKernel");
701 }
702 }
703
704 static handle cast(sycl::kernel src, return_value_policy, handle)
705 {
706 auto const &api = ::dpctl::detail::dpctl_capi::get();
707 auto tmp =
708 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
709 "dpctl4pybind11_kernel");
710 return handle(reinterpret_cast<PyObject *>(tmp));
711 }
712
713 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
714};
715
716/* This type caster associates
717 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
718 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
719 * Python bindings by pybind11.
720 */
721template <>
722struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
723{
724public:
725 bool load(handle src, bool)
726 {
727 PyObject *source = src.ptr();
728 auto const &api = ::dpctl::detail::dpctl_capi::get();
729 if (api.PySyclProgram_Check_(source)) {
730 DPCTLSyclKernelBundleRef KBRef =
731 api.SyclProgram_GetKernelBundleRef_(
732 reinterpret_cast<PySyclProgramObject *>(source));
733 value = std::make_unique<
734 sycl::kernel_bundle<sycl::bundle_state::executable>>(
735 *(reinterpret_cast<
736 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
737 KBRef)));
738 return true;
739 }
740 else {
741 throw py::type_error("Input is of unexpected type, expected "
742 "dpctl.program.SyclProgram");
743 }
744 }
745
746 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
747 return_value_policy,
748 handle)
749 {
750 auto const &api = ::dpctl::detail::dpctl_capi::get();
751 auto tmp = api.SyclProgram_Make_(
752 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
753 return handle(reinterpret_cast<PyObject *>(tmp));
754 }
755
756 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
757 _("dpctl.program.SyclProgram"));
758};
759
760/* This type caster associates
761 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
762 * of generation of Python bindings by pybind11.
763 */
764template <>
765struct type_caster<sycl::half>
766{
767public:
768 bool load(handle src, bool convert)
769 {
770 double py_value;
771
772 if (!src) {
773 return false;
774 }
775
776 PyObject *source = src.ptr();
777
778 if (convert || PyFloat_Check(source)) {
779 py_value = PyFloat_AsDouble(source);
780 }
781 else {
782 return false;
783 }
784
785 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
786
787 if (py_err) {
788 PyErr_Clear();
789 if (convert && (PyNumber_Check(source) != 0)) {
790 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
791 return load(tmp, false);
792 }
793 return false;
794 }
795 value = static_cast<sycl::half>(py_value);
796 return true;
797 }
798
799 static handle cast(sycl::half src, return_value_policy, handle)
800 {
801 return PyFloat_FromDouble(static_cast<double>(src));
802 }
803
804 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
805};
806} // namespace pybind11::detail
807
808namespace dpctl
809{
810namespace memory
811{
812// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
813// this allows to avoid compilation error
814using pybind11::error_already_set;
815
816class usm_memory : public py::object
817{
818public:
819 PYBIND11_OBJECT_CVT(
821 py::object,
822 [](PyObject *o) -> bool {
823 return PyObject_TypeCheck(
824 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
825 0;
826 },
827 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
828
829 usm_memory()
830 : py::object(
831 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
832 borrowed_t{})
833 {
834 if (!m_ptr)
835 throw py::error_already_set();
836 }
837
// Creates a memory object over an existing USM allocation whose lifetime is
// managed by `shptr`.
//
//   usm_ptr - pointer to the USM allocation
//   nbytes  - size passed to Memory_Make_
//   q       - queue; a heap copy of it is passed to Memory_Make_
//   shptr   - smart pointer whose deleter is supposed to free the
//             allocation; a heap copy is stored in the object's
//             `_opaque_ptr` field
//
// Throws py::error_already_set when Memory_Make_ returns null.
usm_memory(void *usm_ptr,
           std::size_t nbytes,
           const sycl::queue &q,
           std::shared_ptr<void> shptr)
{
    auto const &api = ::dpctl::detail::dpctl_capi::get();
    DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
    // q_uptr guards the queue copy until ownership is handed over below.
    auto q_uptr = std::make_unique<sycl::queue>(q);
    DPCTLSyclQueueRef QRef =
        reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());

    // Placeholder owner passed to Memory_Make_; its destructor does nothing.
    auto vacuous_destructor = []() {};
    py::capsule mock_owner(vacuous_destructor);

    // create memory object owned by mock_owner, it is a new reference
    PyObject *_memory =
        api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
    auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };

    using py_uptrT =
        std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;

    if (!_memory) {
        throw py::error_already_set();
    }

    // memory_uptr drops the new reference if anything below throws.
    auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
    std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);

    Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
    // replace mock_owner capsule as the owner
    // NOTE(review): refobj is overwritten with Py_None with no Py_INCREF of
    // Py_None and no release of the previous owner visible in this block;
    // confirm the reference accounting against Memory_Make and the memory
    // object's deallocator.
    memobj->refobj = Py_None;
    // set opaque ptr field, usm_memory now knows that USM is managed
    // by smart pointer
    memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);

    // _memory will delete created copies of sycl::queue, and
    // std::shared_ptr and the deleter of the shared_ptr<void> is
    // supposed to free the USM allocation
    m_ptr = _memory;
    // Both guards give up ownership: this object now holds the reference,
    // and the queue copy is left to _memory (see the comment above).
    q_uptr.release();
    memory_uptr.release();
}
884
885 sycl::queue get_queue() const
886 {
887 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
888 auto const &api = ::dpctl::detail::dpctl_capi::get();
889 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
890 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
891 return *obj_q;
892 }
893
894 char *get_pointer() const
895 {
896 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
897 auto const &api = ::dpctl::detail::dpctl_capi::get();
898 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
899 return reinterpret_cast<char *>(MRef);
900 }
901
902 std::size_t get_nbytes() const
903 {
904 auto const &api = ::dpctl::detail::dpctl_capi::get();
905 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
906 return api.Memory_GetNumBytes_(mem_obj);
907 }
908
909 bool is_managed_by_smart_ptr() const
910 {
911 auto const &api = ::dpctl::detail::dpctl_capi::get();
912 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
913 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
914
915 return bool(opaque_ptr);
916 }
917
918 const std::shared_ptr<void> &get_smart_ptr_owner() const
919 {
920 auto const &api = ::dpctl::detail::dpctl_capi::get();
921 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
922 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
923
924 if (opaque_ptr) {
925 auto shptr_ptr =
926 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
927 return *shptr_ptr;
928 }
929 else {
930 throw std::runtime_error(
931 "Memory object does not have smart pointer "
932 "managing lifetime of USM allocation");
933 }
934 }
935
936protected:
937 static PyObject *as_usm_memory(PyObject *o)
938 {
939 if (o == nullptr) {
940 PyErr_SetString(PyExc_ValueError,
941 "cannot create a usm_memory from a nullptr");
942 return nullptr;
943 }
944
945 auto converter =
946 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
947
948 py::object res;
949 try {
950 res = converter(py::handle(o));
951 } catch (const py::error_already_set &e) {
952 return nullptr;
953 }
954 return res.ptr();
955 }
956};
957} // end namespace memory
958
959namespace tensor
960{
961inline std::vector<py::ssize_t>
962 c_contiguous_strides(int nd,
963 const py::ssize_t *shape,
964 py::ssize_t element_size = 1)
965{
966 if (nd > 0) {
967 std::vector<py::ssize_t> c_strides(nd, element_size);
968 for (int ic = nd - 1; ic > 0;) {
969 py::ssize_t next_v = c_strides[ic] * shape[ic];
970 c_strides[--ic] = next_v;
971 }
972 return c_strides;
973 }
974 else {
975 return std::vector<py::ssize_t>();
976 }
977}
978
979inline std::vector<py::ssize_t>
980 f_contiguous_strides(int nd,
981 const py::ssize_t *shape,
982 py::ssize_t element_size = 1)
983{
984 if (nd > 0) {
985 std::vector<py::ssize_t> f_strides(nd, element_size);
986 for (int i = 0; i < nd - 1;) {
987 py::ssize_t next_v = f_strides[i] * shape[i];
988 f_strides[++i] = next_v;
989 }
990 return f_strides;
991 }
992 else {
993 return std::vector<py::ssize_t>();
994 }
995}
996
997inline std::vector<py::ssize_t>
998 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
999 py::ssize_t element_size = 1)
1000{
1001 return c_contiguous_strides(shape.size(), shape.data(), element_size);
1002}
1003
1004inline std::vector<py::ssize_t>
1005 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
1006 py::ssize_t element_size = 1)
1007{
1008 return f_contiguous_strides(shape.size(), shape.data(), element_size);
1009}
1010
1011class usm_ndarray : public py::object
1012{
1013public:
1014 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
1015 return PyObject_TypeCheck(
1016 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
1017 })
1018
1019 usm_ndarray()
1020 : py::object(
1021 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
1022 borrowed_t{})
1023 {
1024 if (!m_ptr)
1025 throw py::error_already_set();
1026 }
1027
1028 char *get_data() const
1029 {
1030 PyUSMArrayObject *raw_ar = usm_array_ptr();
1031
1032 auto const &api = ::dpctl::detail::dpctl_capi::get();
1033 return api.UsmNDArray_GetData_(raw_ar);
1034 }
1035
1036 template <typename T>
1037 T *get_data() const
1038 {
1039 return reinterpret_cast<T *>(get_data());
1040 }
1041
1042 int get_ndim() const
1043 {
1044 PyUSMArrayObject *raw_ar = usm_array_ptr();
1045
1046 auto const &api = ::dpctl::detail::dpctl_capi::get();
1047 return api.UsmNDArray_GetNDim_(raw_ar);
1048 }
1049
1050 const py::ssize_t *get_shape_raw() const
1051 {
1052 PyUSMArrayObject *raw_ar = usm_array_ptr();
1053
1054 auto const &api = ::dpctl::detail::dpctl_capi::get();
1055 return api.UsmNDArray_GetShape_(raw_ar);
1056 }
1057
1058 std::vector<py::ssize_t> get_shape_vector() const
1059 {
1060 auto raw_sh = get_shape_raw();
1061 auto nd = get_ndim();
1062
1063 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
1064 return shape_vector;
1065 }
1066
1067 py::ssize_t get_shape(int i) const
1068 {
1069 auto shape_ptr = get_shape_raw();
1070 return shape_ptr[i];
1071 }
1072
1073 const py::ssize_t *get_strides_raw() const
1074 {
1075 PyUSMArrayObject *raw_ar = usm_array_ptr();
1076
1077 auto const &api = ::dpctl::detail::dpctl_capi::get();
1078 return api.UsmNDArray_GetStrides_(raw_ar);
1079 }
1080
1081 std::vector<py::ssize_t> get_strides_vector() const
1082 {
1083 auto raw_st = get_strides_raw();
1084 auto nd = get_ndim();
1085
1086 if (raw_st == nullptr) {
1087 auto is_c_contig = is_c_contiguous();
1088 auto is_f_contig = is_f_contiguous();
1089 auto raw_sh = get_shape_raw();
1090 if (is_c_contig) {
1091 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
1092 return contig_strides;
1093 }
1094 else if (is_f_contig) {
1095 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
1096 return contig_strides;
1097 }
1098 else {
1099 throw std::runtime_error("Invalid array encountered when "
1100 "building strides");
1101 }
1102 }
1103 else {
1104 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
1105 return st_vec;
1106 }
1107 }
1108
1109 py::ssize_t get_size() const
1110 {
1111 PyUSMArrayObject *raw_ar = usm_array_ptr();
1112
1113 auto const &api = ::dpctl::detail::dpctl_capi::get();
1114 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
1115 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1116
1117 py::ssize_t nelems = 1;
1118 for (int i = 0; i < ndim; ++i) {
1119 nelems *= shape[i];
1120 }
1121
1122 assert(nelems >= 0);
1123 return nelems;
1124 }
1125
1126 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
1127 {
1128 PyUSMArrayObject *raw_ar = usm_array_ptr();
1129
1130 auto const &api = ::dpctl::detail::dpctl_capi::get();
1131 int nd = api.UsmNDArray_GetNDim_(raw_ar);
1132 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
1133 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
1134
1135 py::ssize_t offset_min = 0;
1136 py::ssize_t offset_max = 0;
1137 if (strides == nullptr) {
1138 py::ssize_t stride(1);
1139 for (int i = 0; i < nd; ++i) {
1140 offset_max += stride * (shape[i] - 1);
1141 stride *= shape[i];
1142 }
1143 }
1144 else {
1145 for (int i = 0; i < nd; ++i) {
1146 py::ssize_t delta = strides[i] * (shape[i] - 1);
1147 if (strides[i] > 0) {
1148 offset_max += delta;
1149 }
1150 else {
1151 offset_min += delta;
1152 }
1153 }
1154 }
1155 return std::make_pair(offset_min, offset_max);
1156 }
1157
1158 sycl::queue get_queue() const
1159 {
1160 PyUSMArrayObject *raw_ar = usm_array_ptr();
1161
1162 auto const &api = ::dpctl::detail::dpctl_capi::get();
1163 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1164 return *(reinterpret_cast<sycl::queue *>(QRef));
1165 }
1166
1167 sycl::device get_device() const
1168 {
1169 PyUSMArrayObject *raw_ar = usm_array_ptr();
1170
1171 auto const &api = ::dpctl::detail::dpctl_capi::get();
1172 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
1173 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1174 }
1175
1176 int get_typenum() const
1177 {
1178 PyUSMArrayObject *raw_ar = usm_array_ptr();
1179
1180 auto const &api = ::dpctl::detail::dpctl_capi::get();
1181 return api.UsmNDArray_GetTypenum_(raw_ar);
1182 }
1183
1184 int get_flags() const
1185 {
1186 PyUSMArrayObject *raw_ar = usm_array_ptr();
1187
1188 auto const &api = ::dpctl::detail::dpctl_capi::get();
1189 return api.UsmNDArray_GetFlags_(raw_ar);
1190 }
1191
1192 int get_elemsize() const
1193 {
1194 PyUSMArrayObject *raw_ar = usm_array_ptr();
1195
1196 auto const &api = ::dpctl::detail::dpctl_capi::get();
1197 return api.UsmNDArray_GetElementSize_(raw_ar);
1198 }
1199
1200 bool is_c_contiguous() const
1201 {
1202 int flags = get_flags();
1203 auto const &api = ::dpctl::detail::dpctl_capi::get();
1204 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1205 }
1206
1207 bool is_f_contiguous() const
1208 {
1209 int flags = get_flags();
1210 auto const &api = ::dpctl::detail::dpctl_capi::get();
1211 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1212 }
1213
1214 bool is_writable() const
1215 {
1216 int flags = get_flags();
1217 auto const &api = ::dpctl::detail::dpctl_capi::get();
1218 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1219 }
1220
1222 py::object get_usm_data() const
1223 {
1224 PyUSMArrayObject *raw_ar = usm_array_ptr();
1225
1226 auto const &api = ::dpctl::detail::dpctl_capi::get();
1227 // UsmNDArray_GetUSMData_ gives a new reference
1228 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1229
1230 // pass reference ownership to py::object
1231 return py::reinterpret_steal<py::object>(usm_data);
1232 }
1233
1234 bool is_managed_by_smart_ptr() const
1235 {
1236 PyUSMArrayObject *raw_ar = usm_array_ptr();
1237
1238 auto const &api = ::dpctl::detail::dpctl_capi::get();
1239 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1240
1241 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1242 Py_DECREF(usm_data);
1243 return false;
1244 }
1245
1246 Py_MemoryObject *mem_obj =
1247 reinterpret_cast<Py_MemoryObject *>(usm_data);
1248 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1249
1250 Py_DECREF(usm_data);
1251 return bool(opaque_ptr);
1252 }
1253
1254 const std::shared_ptr<void> &get_smart_ptr_owner() const
1255 {
1256 PyUSMArrayObject *raw_ar = usm_array_ptr();
1257
1258 auto const &api = ::dpctl::detail::dpctl_capi::get();
1259
1260 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1261
1262 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1263 Py_DECREF(usm_data);
1264 throw std::runtime_error(
1265 "usm_ndarray object does not have Memory object "
1266 "managing lifetime of USM allocation");
1267 }
1268
1269 Py_MemoryObject *mem_obj =
1270 reinterpret_cast<Py_MemoryObject *>(usm_data);
1271 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1272 Py_DECREF(usm_data);
1273
1274 if (opaque_ptr) {
1275 auto shptr_ptr =
1276 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1277 return *shptr_ptr;
1278 }
1279 else {
1280 throw std::runtime_error(
1281 "Memory object underlying usm_ndarray does not have "
1282 "smart pointer managing lifetime of USM allocation");
1283 }
1284 }
1285
1286private:
1287 PyUSMArrayObject *usm_array_ptr() const
1288 {
1289 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1290 }
1291};
1292} // end namespace tensor
1293
1294namespace utils
1295{
1296namespace detail
1297{
1299{
1300
1301 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1302 {
1303 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1304 const auto &usm_memory_inst =
1305 py::cast<dpctl::memory::usm_memory>(h);
1306 return usm_memory_inst.is_managed_by_smart_ptr();
1307 }
1308 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1309 const auto &usm_array_inst =
1310 py::cast<dpctl::tensor::usm_ndarray>(h);
1311 return usm_array_inst.is_managed_by_smart_ptr();
1312 }
1313
1314 return false;
1315 }
1316
1317 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1318 {
1319 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1320 const auto &usm_memory_inst =
1321 py::cast<dpctl::memory::usm_memory>(h);
1322 return usm_memory_inst.get_smart_ptr_owner();
1323 }
1324 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1325 const auto &usm_array_inst =
1326 py::cast<dpctl::tensor::usm_ndarray>(h);
1327 return usm_array_inst.get_smart_ptr_owner();
1328 }
1329
1330 throw std::runtime_error(
1331 "Attempted extraction of shared_ptr on an unrecognized type");
1332 }
1333};
1334} // end of namespace detail
1335
1336template <std::size_t num>
1337sycl::event keep_args_alive(sycl::queue &q,
1338 const py::object (&py_objs)[num],
1339 const std::vector<sycl::event> &depends = {})
1340{
1341 std::size_t n_objects_held = 0;
1342 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1343
1344 std::size_t n_usm_owners_held = 0;
1345 std::array<std::shared_ptr<void>, num> shp_usm{};
1346
1347 for (std::size_t i = 0; i < num; ++i) {
1348 const auto &py_obj_i = py_objs[i];
1349 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1350 const auto &shp =
1351 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1352 shp_usm[n_usm_owners_held] = shp;
1353 ++n_usm_owners_held;
1354 }
1355 else {
1356 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1357 shp_arr[n_objects_held]->inc_ref();
1358 ++n_objects_held;
1359 }
1360 }
1361
1362 bool use_depends = true;
1363 sycl::event host_task_ev;
1364
1365 if (n_usm_owners_held > 0) {
1366 host_task_ev = q.submit([&](sycl::handler &cgh) {
1367 if (use_depends) {
1368 cgh.depends_on(depends);
1369 use_depends = false;
1370 }
1371 else {
1372 cgh.depends_on(host_task_ev);
1373 }
1374 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1375 // no body, but shared pointers are captured in
1376 // the lambda, ensuring that USM allocation is
1377 // kept alive
1378 });
1379 });
1380 }
1381
1382 if (n_objects_held > 0) {
1383 host_task_ev = q.submit([&](sycl::handler &cgh) {
1384 if (use_depends) {
1385 cgh.depends_on(depends);
1386 use_depends = false;
1387 }
1388 else {
1389 cgh.depends_on(host_task_ev);
1390 }
1391 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1392 py::gil_scoped_acquire acquire;
1393
1394 for (std::size_t i = 0; i < n_objects_held; ++i) {
1395 shp_arr[i]->dec_ref();
1396 }
1397 });
1398 });
1399 }
1400
1401 return host_task_ev;
1402}
1403
1406template <std::size_t num>
1407bool queues_are_compatible(const sycl::queue &exec_q,
1408 const sycl::queue (&alloc_qs)[num])
1409{
1410 for (std::size_t i = 0; i < num; ++i) {
1411
1412 if (exec_q != alloc_qs[i]) {
1413 return false;
1414 }
1415 }
1416 return true;
1417}
1418
1421template <std::size_t num>
1422bool queues_are_compatible(const sycl::queue &exec_q,
1423 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1424{
1425 for (std::size_t i = 0; i < num; ++i) {
1426
1427 if (exec_q != arrs[i].get_queue()) {
1428 return false;
1429 }
1430 }
1431 return true;
1432}
1433} // end namespace utils
1434} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.