DPNP C++ backend kernel library 0.20.0dev3
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// Include dpctl SYCL interface from external dpctl package
32#include "syclinterface/dpctl_sycl_extension_interface.h"
33#include "syclinterface/dpctl_sycl_types.h"
34
35#ifdef __cplusplus
36#define CYTHON_EXTERN_C extern "C"
37#else
38#define CYTHON_EXTERN_C
39#endif
40
41// Include dpctl C-API headers (both declarations and import functions)
42#include "dpctl/_sycl_context.h"
43#include "dpctl/_sycl_context_api.h"
44#include "dpctl/_sycl_device.h"
45#include "dpctl/_sycl_device_api.h"
46#include "dpctl/_sycl_event.h"
47#include "dpctl/_sycl_event_api.h"
48#include "dpctl/_sycl_queue.h"
49#include "dpctl/_sycl_queue_api.h"
50#include "dpctl/memory/_memory.h"
51#include "dpctl/memory/_memory_api.h"
52#include "dpctl/program/_program.h"
53#include "dpctl/program/_program_api.h"
54
55// Include generated Cython headers for usm_ndarray struct definition and C-API
56#include "dpctl_ext/tensor/_usmarray.h"
57#include "dpctl_ext/tensor/_usmarray_api.h"
58
59#include <array>
60#include <cassert>
61#include <complex>
62#include <cstddef> // for std::size_t for C++ linkage
63#include <cstdint>
64#include <memory>
65#include <stddef.h> // for size_t for C linkage
66#include <stdexcept>
67#include <type_traits>
68#include <utility>
69#include <vector>
70
71#include <pybind11/pybind11.h>
72
73#include <sycl/sycl.hpp>
74
75namespace py = pybind11;
76
77namespace dpctl
78{
79namespace detail
80{
81// Lookup a type according to its size, and return a value corresponding to the
82// NumPy typenum.
83template <typename Concrete>
84constexpr int platform_typeid_lookup()
85{
86 return -1;
87}
88
89template <typename Concrete, typename T, typename... Ts, typename... Ints>
90constexpr int platform_typeid_lookup(int I, Ints... Is)
91{
92 return sizeof(Concrete) == sizeof(T)
93 ? I
94 : platform_typeid_lookup<Concrete, Ts...>(Is...);
95}
96
98{
99public:
100 // dpctl type objects
101 PyTypeObject *Py_SyclDeviceType_;
102 PyTypeObject *PySyclDeviceType_;
103 PyTypeObject *Py_SyclContextType_;
104 PyTypeObject *PySyclContextType_;
105 PyTypeObject *Py_SyclEventType_;
106 PyTypeObject *PySyclEventType_;
107 PyTypeObject *Py_SyclQueueType_;
108 PyTypeObject *PySyclQueueType_;
109 PyTypeObject *Py_MemoryType_;
110 PyTypeObject *PyMemoryUSMDeviceType_;
111 PyTypeObject *PyMemoryUSMSharedType_;
112 PyTypeObject *PyMemoryUSMHostType_;
113 PyTypeObject *PyUSMArrayType_;
114 PyTypeObject *PySyclProgramType_;
115 PyTypeObject *PySyclKernelType_;
116
117 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
118 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
119
120 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
121 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
122
123 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
124 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
125
126 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
127 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
128
129 // memory
130 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
131 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
132 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
133 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
134 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
135 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
136 size_t,
137 DPCTLSyclQueueRef,
138 PyObject *);
139
140 // program
141 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
142 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);
143
144 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
145 PySyclProgramObject *);
146 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
147
148 int USM_ARRAY_C_CONTIGUOUS_;
149 int USM_ARRAY_F_CONTIGUOUS_;
150 int USM_ARRAY_WRITABLE_;
151 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
152 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
153 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
154 UAR_HALF_;
155 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
156 UAR_INT64_, UAR_UINT64_;
157
158 bool PySyclDevice_Check_(PyObject *obj) const
159 {
160 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
161 }
162 bool PySyclContext_Check_(PyObject *obj) const
163 {
164 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
165 }
166 bool PySyclEvent_Check_(PyObject *obj) const
167 {
168 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
169 }
170 bool PySyclQueue_Check_(PyObject *obj) const
171 {
172 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
173 }
174 bool PySyclKernel_Check_(PyObject *obj) const
175 {
176 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
177 }
178 bool PySyclProgram_Check_(PyObject *obj) const
179 {
180 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
181 }
182
184 {
185 as_usm_memory_.reset();
186 default_usm_ndarray_.reset();
187 default_usm_memory_.reset();
188 default_sycl_queue_.reset();
189 };
190
191 static auto &get()
192 {
193 static dpctl_capi api{};
194 return api;
195 }
196
197 py::object default_sycl_queue_pyobj()
198 {
199 return *default_sycl_queue_;
200 }
201 py::object default_usm_memory_pyobj()
202 {
203 return *default_usm_memory_;
204 }
205 py::object default_usm_ndarray_pyobj()
206 {
207 return *default_usm_ndarray_;
208 }
209 py::object as_usm_memory_pyobj()
210 {
211 return *as_usm_memory_;
212 }
213
214private:
215 struct Deleter
216 {
217 void operator()(py::object *p) const
218 {
219 const bool initialized = Py_IsInitialized();
220#if PY_VERSION_HEX < 0x30d0000
221 const bool finalizing = _Py_IsFinalizing();
222#else
223 const bool finalizing = Py_IsFinalizing();
224#endif
225 const bool guard = initialized && !finalizing;
226
227 if (guard) {
228 delete p;
229 }
230 }
231 };
232
233 std::shared_ptr<py::object> default_sycl_queue_;
234 std::shared_ptr<py::object> default_usm_memory_;
235 std::shared_ptr<py::object> default_usm_ndarray_;
236 std::shared_ptr<py::object> as_usm_memory_;
237
238 dpctl_capi()
239 : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
240 Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
241 Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
242 Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
243 Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
244 PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
245 PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
246 PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
247 SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
248 SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
249 SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
250 SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
251 Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
252 Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
253 Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
254 SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
255 SyclProgram_Make_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
256 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
257 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
258 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
259 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
260 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
261 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
262 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
263 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
264 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
265
266 {
267 // Import dpctl SYCL interface modules
268 // This imports python modules and initializes pointers to Python types
269 import_dpctl___sycl_device();
270 import_dpctl___sycl_context();
271 import_dpctl___sycl_event();
272 import_dpctl___sycl_queue();
273 import_dpctl__memory___memory();
274 import_dpctl__program___program();
275 // Import dpctl_ext tensor module for PyUSMArrayType
276 import_dpctl_ext__tensor___usmarray();
277
278 // Python type objects for classes implemented by dpctl
279 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
280 this->PySyclDeviceType_ = &PySyclDeviceType;
281 this->Py_SyclContextType_ = &Py_SyclContextType;
282 this->PySyclContextType_ = &PySyclContextType;
283 this->Py_SyclEventType_ = &Py_SyclEventType;
284 this->PySyclEventType_ = &PySyclEventType;
285 this->Py_SyclQueueType_ = &Py_SyclQueueType;
286 this->PySyclQueueType_ = &PySyclQueueType;
287 this->Py_MemoryType_ = &Py_MemoryType;
288 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
289 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
290 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
291 this->PyUSMArrayType_ = &PyUSMArrayType;
292 this->PySyclProgramType_ = &PySyclProgramType;
293 this->PySyclKernelType_ = &PySyclKernelType;
294
295 // SyclDevice API
296 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
297 this->SyclDevice_Make_ = SyclDevice_Make;
298
299 // SyclContext API
300 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
301 this->SyclContext_Make_ = SyclContext_Make;
302
303 // SyclEvent API
304 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
305 this->SyclEvent_Make_ = SyclEvent_Make;
306
307 // SyclQueue API
308 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
309 this->SyclQueue_Make_ = SyclQueue_Make;
310
311 // dpctl.memory API
312 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
313 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
314 this->Memory_GetContextRef_ = Memory_GetContextRef;
315 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
316 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
317 this->Memory_Make_ = Memory_Make;
318
319 // dpctl.program API
320 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
321 this->SyclKernel_Make_ = SyclKernel_Make;
322 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
323 this->SyclProgram_Make_ = SyclProgram_Make;
324
325 // constants
326 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
327 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
328 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
329 this->UAR_BOOL_ = UAR_BOOL;
330 this->UAR_BYTE_ = UAR_BYTE;
331 this->UAR_UBYTE_ = UAR_UBYTE;
332 this->UAR_SHORT_ = UAR_SHORT;
333 this->UAR_USHORT_ = UAR_USHORT;
334 this->UAR_INT_ = UAR_INT;
335 this->UAR_UINT_ = UAR_UINT;
336 this->UAR_LONG_ = UAR_LONG;
337 this->UAR_ULONG_ = UAR_ULONG;
338 this->UAR_LONGLONG_ = UAR_LONGLONG;
339 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
340 this->UAR_FLOAT_ = UAR_FLOAT;
341 this->UAR_DOUBLE_ = UAR_DOUBLE;
342 this->UAR_CFLOAT_ = UAR_CFLOAT;
343 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
344 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
345 this->UAR_HALF_ = UAR_HALF;
346
347 // deduced disjoint types
348 this->UAR_INT8_ = UAR_BYTE;
349 this->UAR_UINT8_ = UAR_UBYTE;
350 this->UAR_INT16_ = UAR_SHORT;
351 this->UAR_UINT16_ = UAR_USHORT;
352 this->UAR_INT32_ =
353 platform_typeid_lookup<std::int32_t, long, int, short>(
354 UAR_LONG, UAR_INT, UAR_SHORT);
355 this->UAR_UINT32_ =
356 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
357 unsigned short>(UAR_ULONG, UAR_UINT,
358 UAR_USHORT);
359 this->UAR_INT64_ =
360 platform_typeid_lookup<std::int64_t, long, long long, int>(
361 UAR_LONG, UAR_LONGLONG, UAR_INT);
362 this->UAR_UINT64_ =
363 platform_typeid_lookup<std::uint64_t, unsigned long,
364 unsigned long long, unsigned int>(
365 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
366
367 // create shared pointers to python objects used in type-casters
368 // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
369 sycl::queue q_{};
370 PySyclQueueObject *py_q_tmp =
371 SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
372 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
373 reinterpret_cast<PyObject *>(py_q_tmp));
374
375 default_sycl_queue_ = std::shared_ptr<py::object>(
376 new py::object(py_sycl_queue), Deleter{});
377
378 py::module_ mod_memory = py::module_::import("dpctl.memory");
379 const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
380 as_usm_memory_ = std::shared_ptr<py::object>(
381 new py::object{py_as_usm_memory}, Deleter{});
382
383 auto mem_kl = mod_memory.attr("MemoryUSMHost");
384 const py::object &py_default_usm_memory =
385 mem_kl(1, py::arg("queue") = py_sycl_queue);
386 default_usm_memory_ = std::shared_ptr<py::object>(
387 new py::object{py_default_usm_memory}, Deleter{});
388
389 // TODO: revert to `py::module_::import("dpctl.tensor._usmarray");`
390 // when dpnp fully migrates dpctl/tensor
391 py::module_ mod_usmarray =
392 py::module_::import("dpctl_ext.tensor._usmarray");
393 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
394
395 const py::object &py_default_usm_ndarray =
396 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
397 py::arg("buffer") = py_default_usm_memory);
398
399 default_usm_ndarray_ = std::shared_ptr<py::object>(
400 new py::object{py_default_usm_ndarray}, Deleter{});
401 }
402
403 dpctl_capi(dpctl_capi const &) = default;
404 dpctl_capi &operator=(dpctl_capi const &) = default;
405 dpctl_capi &operator=(dpctl_capi &&) = default;
406
407}; // struct dpctl_capi
408} // namespace detail
409} // namespace dpctl
410
411namespace pybind11::detail
412{
413#define DPCTL_TYPE_CASTER(type, py_name) \
414protected: \
415 std::unique_ptr<type> value; \
416 \
417public: \
418 static constexpr auto name = py_name; \
419 template < \
420 typename T_, \
421 ::pybind11::detail::enable_if_t< \
422 std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
423 int> = 0> \
424 static ::pybind11::handle cast(T_ *src, \
425 ::pybind11::return_value_policy policy, \
426 ::pybind11::handle parent) \
427 { \
428 if (!src) \
429 return ::pybind11::none().release(); \
430 if (policy == ::pybind11::return_value_policy::take_ownership) { \
431 auto h = cast(std::move(*src), policy, parent); \
432 delete src; \
433 return h; \
434 } \
435 return cast(*src, policy, parent); \
436 } \
437 operator type *() \
438 { \
439 return value.get(); \
440 } /* NOLINT(bugprone-macro-parentheses) */ \
441 operator type &() \
442 { \
443 return *value; \
444 } /* NOLINT(bugprone-macro-parentheses) */ \
445 operator type &&() && \
446 { \
447 return std::move(*value); \
448 } /* NOLINT(bugprone-macro-parentheses) */ \
449 template <typename T_> \
450 using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
451
452/* This type caster associates ``sycl::queue`` C++ class with
453 * :class:`dpctl.SyclQueue` for the purposes of generation of
454 * Python bindings by pybind11.
455 */
456template <>
457struct type_caster<sycl::queue>
458{
459public:
460 bool load(handle src, bool)
461 {
462 PyObject *source = src.ptr();
463 auto const &api = ::dpctl::detail::dpctl_capi::get();
464 if (api.PySyclQueue_Check_(source)) {
465 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
466 reinterpret_cast<PySyclQueueObject *>(source));
467 value = std::make_unique<sycl::queue>(
468 *(reinterpret_cast<sycl::queue *>(QRef)));
469 return true;
470 }
471 else {
472 throw py::type_error(
473 "Input is of unexpected type, expected dpctl.SyclQueue");
474 }
475 }
476
477 static handle cast(sycl::queue src, return_value_policy, handle)
478 {
479 auto const &api = ::dpctl::detail::dpctl_capi::get();
480 auto tmp =
481 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
482 return handle(reinterpret_cast<PyObject *>(tmp));
483 }
484
485 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
486};
487
488/* This type caster associates ``sycl::device`` C++ class with
489 * :class:`dpctl.SyclDevice` for the purposes of generation of
490 * Python bindings by pybind11.
491 */
492template <>
493struct type_caster<sycl::device>
494{
495public:
496 bool load(handle src, bool)
497 {
498 PyObject *source = src.ptr();
499 auto const &api = ::dpctl::detail::dpctl_capi::get();
500 if (api.PySyclDevice_Check_(source)) {
501 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
502 reinterpret_cast<PySyclDeviceObject *>(source));
503 value = std::make_unique<sycl::device>(
504 *(reinterpret_cast<sycl::device *>(DRef)));
505 return true;
506 }
507 else {
508 throw py::type_error(
509 "Input is of unexpected type, expected dpctl.SyclDevice");
510 }
511 }
512
513 static handle cast(sycl::device src, return_value_policy, handle)
514 {
515 auto const &api = ::dpctl::detail::dpctl_capi::get();
516 auto tmp =
517 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
518 return handle(reinterpret_cast<PyObject *>(tmp));
519 }
520
521 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
522};
523
524/* This type caster associates ``sycl::context`` C++ class with
525 * :class:`dpctl.SyclContext` for the purposes of generation of
526 * Python bindings by pybind11.
527 */
528template <>
529struct type_caster<sycl::context>
530{
531public:
532 bool load(handle src, bool)
533 {
534 PyObject *source = src.ptr();
535 auto const &api = ::dpctl::detail::dpctl_capi::get();
536 if (api.PySyclContext_Check_(source)) {
537 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
538 reinterpret_cast<PySyclContextObject *>(source));
539 value = std::make_unique<sycl::context>(
540 *(reinterpret_cast<sycl::context *>(CRef)));
541 return true;
542 }
543 else {
544 throw py::type_error(
545 "Input is of unexpected type, expected dpctl.SyclContext");
546 }
547 }
548
549 static handle cast(sycl::context src, return_value_policy, handle)
550 {
551 auto const &api = ::dpctl::detail::dpctl_capi::get();
552 auto tmp =
553 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
554 return handle(reinterpret_cast<PyObject *>(tmp));
555 }
556
557 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
558};
559
560/* This type caster associates ``sycl::event`` C++ class with
561 * :class:`dpctl.SyclEvent` for the purposes of generation of
562 * Python bindings by pybind11.
563 */
564template <>
565struct type_caster<sycl::event>
566{
567public:
568 bool load(handle src, bool)
569 {
570 PyObject *source = src.ptr();
571 auto const &api = ::dpctl::detail::dpctl_capi::get();
572 if (api.PySyclEvent_Check_(source)) {
573 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
574 reinterpret_cast<PySyclEventObject *>(source));
575 value = std::make_unique<sycl::event>(
576 *(reinterpret_cast<sycl::event *>(ERef)));
577 return true;
578 }
579 else {
580 throw py::type_error(
581 "Input is of unexpected type, expected dpctl.SyclEvent");
582 }
583 }
584
585 static handle cast(sycl::event src, return_value_policy, handle)
586 {
587 auto const &api = ::dpctl::detail::dpctl_capi::get();
588 auto tmp =
589 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
590 return handle(reinterpret_cast<PyObject *>(tmp));
591 }
592
593 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
594};
595
596/* This type caster associates ``sycl::kernel`` C++ class with
597 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
598 * Python bindings by pybind11.
599 */
600template <>
601struct type_caster<sycl::kernel>
602{
603public:
604 bool load(handle src, bool)
605 {
606 PyObject *source = src.ptr();
607 auto const &api = ::dpctl::detail::dpctl_capi::get();
608 if (api.PySyclKernel_Check_(source)) {
609 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
610 reinterpret_cast<PySyclKernelObject *>(source));
611 value = std::make_unique<sycl::kernel>(
612 *(reinterpret_cast<sycl::kernel *>(KRef)));
613 return true;
614 }
615 else {
616 throw py::type_error("Input is of unexpected type, expected "
617 "dpctl.program.SyclKernel");
618 }
619 }
620
621 static handle cast(sycl::kernel src, return_value_policy, handle)
622 {
623 auto const &api = ::dpctl::detail::dpctl_capi::get();
624 auto tmp =
625 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
626 "dpctl4pybind11_kernel");
627 return handle(reinterpret_cast<PyObject *>(tmp));
628 }
629
630 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
631};
632
633/* This type caster associates
634 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
635 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
636 * Python bindings by pybind11.
637 */
638template <>
639struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
640{
641public:
642 bool load(handle src, bool)
643 {
644 PyObject *source = src.ptr();
645 auto const &api = ::dpctl::detail::dpctl_capi::get();
646 if (api.PySyclProgram_Check_(source)) {
647 DPCTLSyclKernelBundleRef KBRef =
648 api.SyclProgram_GetKernelBundleRef_(
649 reinterpret_cast<PySyclProgramObject *>(source));
650 value = std::make_unique<
651 sycl::kernel_bundle<sycl::bundle_state::executable>>(
652 *(reinterpret_cast<
653 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
654 KBRef)));
655 return true;
656 }
657 else {
658 throw py::type_error("Input is of unexpected type, expected "
659 "dpctl.program.SyclProgram");
660 }
661 }
662
663 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
664 return_value_policy,
665 handle)
666 {
667 auto const &api = ::dpctl::detail::dpctl_capi::get();
668 auto tmp = api.SyclProgram_Make_(
669 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
670 return handle(reinterpret_cast<PyObject *>(tmp));
671 }
672
673 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
674 _("dpctl.program.SyclProgram"));
675};
676
677/* This type caster associates
678 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
679 * of generation of Python bindings by pybind11.
680 */
681template <>
682struct type_caster<sycl::half>
683{
684public:
685 bool load(handle src, bool convert)
686 {
687 double py_value;
688
689 if (!src) {
690 return false;
691 }
692
693 PyObject *source = src.ptr();
694
695 if (convert || PyFloat_Check(source)) {
696 py_value = PyFloat_AsDouble(source);
697 }
698 else {
699 return false;
700 }
701
702 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
703
704 if (py_err) {
705 PyErr_Clear();
706 if (convert && (PyNumber_Check(source) != 0)) {
707 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
708 return load(tmp, false);
709 }
710 return false;
711 }
712 value = static_cast<sycl::half>(py_value);
713 return true;
714 }
715
716 static handle cast(sycl::half src, return_value_policy, handle)
717 {
718 return PyFloat_FromDouble(static_cast<double>(src));
719 }
720
721 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
722};
723} // namespace pybind11::detail
724
725namespace dpctl
726{
727namespace memory
728{
729// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
730// this allows to avoid compilation error
731using pybind11::error_already_set;
732
733class usm_memory : public py::object
734{
735public:
736 PYBIND11_OBJECT_CVT(
738 py::object,
739 [](PyObject *o) -> bool {
740 return PyObject_TypeCheck(
741 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
742 0;
743 },
744 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
745
746 usm_memory()
747 : py::object(
748 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
749 borrowed_t{})
750 {
751 if (!m_ptr)
752 throw py::error_already_set();
753 }
754
758 usm_memory(void *usm_ptr,
759 std::size_t nbytes,
760 const sycl::queue &q,
761 std::shared_ptr<void> shptr)
762 {
763 auto const &api = ::dpctl::detail::dpctl_capi::get();
764 DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
765 auto q_uptr = std::make_unique<sycl::queue>(q);
766 DPCTLSyclQueueRef QRef =
767 reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());
768
769 auto vacuous_destructor = []() {};
770 py::capsule mock_owner(vacuous_destructor);
771
772 // create memory object owned by mock_owner, it is a new reference
773 PyObject *_memory =
774 api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
775 auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };
776
777 using py_uptrT =
778 std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;
779
780 if (!_memory) {
781 throw py::error_already_set();
782 }
783
784 auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
785 std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);
786
787 Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
788 // replace mock_owner capsule as the owner
789 memobj->refobj = Py_None;
790 // set opaque ptr field, usm_memory now knowns that USM is managed
791 // by smart pointer
792 memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);
793
794 // _memory will delete created copies of sycl::queue, and
795 // std::shared_ptr and the deleter of the shared_ptr<void> is
796 // supposed to free the USM allocation
797 m_ptr = _memory;
798 q_uptr.release();
799 memory_uptr.release();
800 }
801
802 sycl::queue get_queue() const
803 {
804 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
805 auto const &api = ::dpctl::detail::dpctl_capi::get();
806 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
807 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
808 return *obj_q;
809 }
810
811 char *get_pointer() const
812 {
813 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
814 auto const &api = ::dpctl::detail::dpctl_capi::get();
815 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
816 return reinterpret_cast<char *>(MRef);
817 }
818
819 std::size_t get_nbytes() const
820 {
821 auto const &api = ::dpctl::detail::dpctl_capi::get();
822 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
823 return api.Memory_GetNumBytes_(mem_obj);
824 }
825
826 bool is_managed_by_smart_ptr() const
827 {
828 auto const &api = ::dpctl::detail::dpctl_capi::get();
829 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
830 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
831
832 return bool(opaque_ptr);
833 }
834
835 const std::shared_ptr<void> &get_smart_ptr_owner() const
836 {
837 auto const &api = ::dpctl::detail::dpctl_capi::get();
838 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
839 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
840
841 if (opaque_ptr) {
842 auto shptr_ptr =
843 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
844 return *shptr_ptr;
845 }
846 else {
847 throw std::runtime_error(
848 "Memory object does not have smart pointer "
849 "managing lifetime of USM allocation");
850 }
851 }
852
853protected:
854 static PyObject *as_usm_memory(PyObject *o)
855 {
856 if (o == nullptr) {
857 PyErr_SetString(PyExc_ValueError,
858 "cannot create a usm_memory from a nullptr");
859 return nullptr;
860 }
861
862 auto converter =
863 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
864
865 py::object res;
866 try {
867 res = converter(py::handle(o));
868 } catch (const py::error_already_set &e) {
869 return nullptr;
870 }
871 return res.ptr();
872 }
873};
874} // end namespace memory
875
876namespace tensor
877{
878inline std::vector<py::ssize_t>
879 c_contiguous_strides(int nd,
880 const py::ssize_t *shape,
881 py::ssize_t element_size = 1)
882{
883 if (nd > 0) {
884 std::vector<py::ssize_t> c_strides(nd, element_size);
885 for (int ic = nd - 1; ic > 0;) {
886 py::ssize_t next_v = c_strides[ic] * shape[ic];
887 c_strides[--ic] = next_v;
888 }
889 return c_strides;
890 }
891 else {
892 return std::vector<py::ssize_t>();
893 }
894}
895
896inline std::vector<py::ssize_t>
897 f_contiguous_strides(int nd,
898 const py::ssize_t *shape,
899 py::ssize_t element_size = 1)
900{
901 if (nd > 0) {
902 std::vector<py::ssize_t> f_strides(nd, element_size);
903 for (int i = 0; i < nd - 1;) {
904 py::ssize_t next_v = f_strides[i] * shape[i];
905 f_strides[++i] = next_v;
906 }
907 return f_strides;
908 }
909 else {
910 return std::vector<py::ssize_t>();
911 }
912}
913
914inline std::vector<py::ssize_t>
915 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
916 py::ssize_t element_size = 1)
917{
918 return c_contiguous_strides(shape.size(), shape.data(), element_size);
919}
920
921inline std::vector<py::ssize_t>
922 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
923 py::ssize_t element_size = 1)
924{
925 return f_contiguous_strides(shape.size(), shape.data(), element_size);
926}
927
928class usm_ndarray : public py::object
929{
930public:
931 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
932 return PyObject_TypeCheck(
933 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
934 })
935
937 : py::object(
938 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
939 borrowed_t{})
940 {
941 if (!m_ptr)
942 throw py::error_already_set();
943 }
944
945 char *get_data() const
946 {
947 PyUSMArrayObject *raw_ar = usm_array_ptr();
948 return raw_ar->data_;
949 }
950
951 template <typename T>
952 T *get_data() const
953 {
954 return reinterpret_cast<T *>(get_data());
955 }
956
957 int get_ndim() const
958 {
959 PyUSMArrayObject *raw_ar = usm_array_ptr();
960 return raw_ar->nd_;
961 }
962
963 const py::ssize_t *get_shape_raw() const
964 {
965 PyUSMArrayObject *raw_ar = usm_array_ptr();
966 return raw_ar->shape_;
967 }
968
969 std::vector<py::ssize_t> get_shape_vector() const
970 {
971 auto raw_sh = get_shape_raw();
972 auto nd = get_ndim();
973
974 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
975 return shape_vector;
976 }
977
978 py::ssize_t get_shape(int i) const
979 {
980 auto shape_ptr = get_shape_raw();
981 return shape_ptr[i];
982 }
983
984 const py::ssize_t *get_strides_raw() const
985 {
986 PyUSMArrayObject *raw_ar = usm_array_ptr();
987 return raw_ar->strides_;
988 }
989
990 std::vector<py::ssize_t> get_strides_vector() const
991 {
992 auto raw_st = get_strides_raw();
993 auto nd = get_ndim();
994
995 if (raw_st == nullptr) {
996 auto is_c_contig = is_c_contiguous();
997 auto is_f_contig = is_f_contiguous();
998 auto raw_sh = get_shape_raw();
999 if (is_c_contig) {
1000 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
1001 return contig_strides;
1002 }
1003 else if (is_f_contig) {
1004 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
1005 return contig_strides;
1006 }
1007 else {
1008 throw std::runtime_error("Invalid array encountered when "
1009 "building strides");
1010 }
1011 }
1012 else {
1013 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
1014 return st_vec;
1015 }
1016 }
1017
1018 py::ssize_t get_size() const
1019 {
1020 PyUSMArrayObject *raw_ar = usm_array_ptr();
1021
1022 int ndim = raw_ar->nd_;
1023 const py::ssize_t *shape = raw_ar->shape_;
1024
1025 py::ssize_t nelems = 1;
1026 for (int i = 0; i < ndim; ++i) {
1027 nelems *= shape[i];
1028 }
1029
1030 assert(nelems >= 0);
1031 return nelems;
1032 }
1033
1034 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
1035 {
1036 PyUSMArrayObject *raw_ar = usm_array_ptr();
1037
1038 int nd = raw_ar->nd_;
1039 const py::ssize_t *shape = raw_ar->shape_;
1040 const py::ssize_t *strides = raw_ar->strides_;
1041
1042 py::ssize_t offset_min = 0;
1043 py::ssize_t offset_max = 0;
1044 if (strides == nullptr) {
1045 py::ssize_t stride(1);
1046 for (int i = 0; i < nd; ++i) {
1047 offset_max += stride * (shape[i] - 1);
1048 stride *= shape[i];
1049 }
1050 }
1051 else {
1052 for (int i = 0; i < nd; ++i) {
1053 py::ssize_t delta = strides[i] * (shape[i] - 1);
1054 if (strides[i] > 0) {
1055 offset_max += delta;
1056 }
1057 else {
1058 offset_min += delta;
1059 }
1060 }
1061 }
1062 return std::make_pair(offset_min, offset_max);
1063 }
1064
1065 sycl::queue get_queue() const
1066 {
1067 PyUSMArrayObject *raw_ar = usm_array_ptr();
1068 Py_MemoryObject *mem_obj =
1069 reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
1070
1071 auto const &api = ::dpctl::detail::dpctl_capi::get();
1072 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1073 return *(reinterpret_cast<sycl::queue *>(QRef));
1074 }
1075
1076 sycl::device get_device() const
1077 {
1078 PyUSMArrayObject *raw_ar = usm_array_ptr();
1079 Py_MemoryObject *mem_obj =
1080 reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
1081
1082 auto const &api = ::dpctl::detail::dpctl_capi::get();
1083 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1084 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1085 }
1086
1087 int get_typenum() const
1088 {
1089 PyUSMArrayObject *raw_ar = usm_array_ptr();
1090 return raw_ar->typenum_;
1091 }
1092
1093 int get_flags() const
1094 {
1095 PyUSMArrayObject *raw_ar = usm_array_ptr();
1096 return raw_ar->flags_;
1097 }
1098
1099 int get_elemsize() const
1100 {
1101 int typenum = get_typenum();
1102 auto const &api = ::dpctl::detail::dpctl_capi::get();
1103
1104 // Lookup table for element sizes based on typenum
1105 if (typenum == api.UAR_BOOL_)
1106 return 1;
1107 if (typenum == api.UAR_BYTE_)
1108 return 1;
1109 if (typenum == api.UAR_UBYTE_)
1110 return 1;
1111 if (typenum == api.UAR_SHORT_)
1112 return 2;
1113 if (typenum == api.UAR_USHORT_)
1114 return 2;
1115 if (typenum == api.UAR_INT_)
1116 return 4;
1117 if (typenum == api.UAR_UINT_)
1118 return 4;
1119 if (typenum == api.UAR_LONG_)
1120 return sizeof(long);
1121 if (typenum == api.UAR_ULONG_)
1122 return sizeof(unsigned long);
1123 if (typenum == api.UAR_LONGLONG_)
1124 return 8;
1125 if (typenum == api.UAR_ULONGLONG_)
1126 return 8;
1127 if (typenum == api.UAR_FLOAT_)
1128 return 4;
1129 if (typenum == api.UAR_DOUBLE_)
1130 return 8;
1131 if (typenum == api.UAR_CFLOAT_)
1132 return 8;
1133 if (typenum == api.UAR_CDOUBLE_)
1134 return 16;
1135 if (typenum == api.UAR_HALF_)
1136 return 2;
1137
1138 return 0; // Unknown type
1139 }
1140
1141 bool is_c_contiguous() const
1142 {
1143 int flags = get_flags();
1144 auto const &api = ::dpctl::detail::dpctl_capi::get();
1145 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1146 }
1147
1148 bool is_f_contiguous() const
1149 {
1150 int flags = get_flags();
1151 auto const &api = ::dpctl::detail::dpctl_capi::get();
1152 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1153 }
1154
1155 bool is_writable() const
1156 {
1157 int flags = get_flags();
1158 auto const &api = ::dpctl::detail::dpctl_capi::get();
1159 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1160 }
1161
1163 py::object get_usm_data() const
1164 {
1165 PyUSMArrayObject *raw_ar = usm_array_ptr();
1166 // base_ is the Memory object - return new reference
1167 PyObject *usm_data = raw_ar->base_;
1168 Py_XINCREF(usm_data);
1169
1170 // pass reference ownership to py::object
1171 return py::reinterpret_steal<py::object>(usm_data);
1172 }
1173
1174 bool is_managed_by_smart_ptr() const
1175 {
1176 PyUSMArrayObject *raw_ar = usm_array_ptr();
1177 PyObject *usm_data = raw_ar->base_;
1178
1179 auto const &api = ::dpctl::detail::dpctl_capi::get();
1180 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1181 return false;
1182 }
1183
1184 Py_MemoryObject *mem_obj =
1185 reinterpret_cast<Py_MemoryObject *>(usm_data);
1186 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1187
1188 return bool(opaque_ptr);
1189 }
1190
1191 const std::shared_ptr<void> &get_smart_ptr_owner() const
1192 {
1193 PyUSMArrayObject *raw_ar = usm_array_ptr();
1194 PyObject *usm_data = raw_ar->base_;
1195
1196 auto const &api = ::dpctl::detail::dpctl_capi::get();
1197
1198 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1199 throw std::runtime_error(
1200 "usm_ndarray object does not have Memory object "
1201 "managing lifetime of USM allocation");
1202 }
1203
1204 Py_MemoryObject *mem_obj =
1205 reinterpret_cast<Py_MemoryObject *>(usm_data);
1206 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1207
1208 if (opaque_ptr) {
1209 auto shptr_ptr =
1210 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1211 return *shptr_ptr;
1212 }
1213 else {
1214 throw std::runtime_error(
1215 "Memory object underlying usm_ndarray does not have "
1216 "smart pointer managing lifetime of USM allocation");
1217 }
1218 }
1219
1220private:
1221 PyUSMArrayObject *usm_array_ptr() const
1222 {
1223 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1224 }
1225};
1226} // end namespace tensor
1227
1228namespace utils
1229{
1230namespace detail
1231{
1233{
1234
1235 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1236 {
1237 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1238 const auto &usm_memory_inst =
1239 py::cast<dpctl::memory::usm_memory>(h);
1240 return usm_memory_inst.is_managed_by_smart_ptr();
1241 }
1242 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1243 const auto &usm_array_inst =
1244 py::cast<dpctl::tensor::usm_ndarray>(h);
1245 return usm_array_inst.is_managed_by_smart_ptr();
1246 }
1247
1248 return false;
1249 }
1250
1251 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1252 {
1253 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1254 const auto &usm_memory_inst =
1255 py::cast<dpctl::memory::usm_memory>(h);
1256 return usm_memory_inst.get_smart_ptr_owner();
1257 }
1258 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1259 const auto &usm_array_inst =
1260 py::cast<dpctl::tensor::usm_ndarray>(h);
1261 return usm_array_inst.get_smart_ptr_owner();
1262 }
1263
1264 throw std::runtime_error(
1265 "Attempted extraction of shared_ptr on an unrecognized type");
1266 }
1267};
1268} // end of namespace detail
1269
1270template <std::size_t num>
1271sycl::event keep_args_alive(sycl::queue &q,
1272 const py::object (&py_objs)[num],
1273 const std::vector<sycl::event> &depends = {})
1274{
1275 std::size_t n_objects_held = 0;
1276 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1277
1278 std::size_t n_usm_owners_held = 0;
1279 std::array<std::shared_ptr<void>, num> shp_usm{};
1280
1281 for (std::size_t i = 0; i < num; ++i) {
1282 const auto &py_obj_i = py_objs[i];
1283 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1284 const auto &shp =
1285 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1286 shp_usm[n_usm_owners_held] = shp;
1287 ++n_usm_owners_held;
1288 }
1289 else {
1290 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1291 shp_arr[n_objects_held]->inc_ref();
1292 ++n_objects_held;
1293 }
1294 }
1295
1296 bool use_depends = true;
1297 sycl::event host_task_ev;
1298
1299 if (n_usm_owners_held > 0) {
1300 host_task_ev = q.submit([&](sycl::handler &cgh) {
1301 if (use_depends) {
1302 cgh.depends_on(depends);
1303 use_depends = false;
1304 }
1305 else {
1306 cgh.depends_on(host_task_ev);
1307 }
1308 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1309 // no body, but shared pointers are captured in
1310 // the lambda, ensuring that USM allocation is
1311 // kept alive
1312 });
1313 });
1314 }
1315
1316 if (n_objects_held > 0) {
1317 host_task_ev = q.submit([&](sycl::handler &cgh) {
1318 if (use_depends) {
1319 cgh.depends_on(depends);
1320 use_depends = false;
1321 }
1322 else {
1323 cgh.depends_on(host_task_ev);
1324 }
1325 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1326 py::gil_scoped_acquire acquire;
1327
1328 for (std::size_t i = 0; i < n_objects_held; ++i) {
1329 shp_arr[i]->dec_ref();
1330 }
1331 });
1332 });
1333 }
1334
1335 return host_task_ev;
1336}
1337
1340template <std::size_t num>
1341bool queues_are_compatible(const sycl::queue &exec_q,
1342 const sycl::queue (&alloc_qs)[num])
1343{
1344 for (std::size_t i = 0; i < num; ++i) {
1345
1346 if (exec_q != alloc_qs[i]) {
1347 return false;
1348 }
1349 }
1350 return true;
1351}
1352
1355template <std::size_t num>
1356bool queues_are_compatible(const sycl::queue &exec_q,
1357 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1358{
1359 for (std::size_t i = 0; i < num; ++i) {
1360
1361 if (exec_q != arrs[i].get_queue()) {
1362 return false;
1363 }
1364 }
1365 return true;
1366}
1367} // end namespace utils
1368} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.