DPNP C++ backend kernel library 0.20.0dev6
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// Include dpctl SYCL interface from external dpctl package
32#include "syclinterface/dpctl_sycl_extension_interface.h"
33#include "syclinterface/dpctl_sycl_types.h"
34
35#ifdef __cplusplus
36#define CYTHON_EXTERN_C extern "C"
37#else
38#define CYTHON_EXTERN_C
39#endif
40
41// Include dpctl C-API headers (both declarations and import functions)
42#include "dpctl/_sycl_context.h"
43#include "dpctl/_sycl_context_api.h"
44#include "dpctl/_sycl_device.h"
45#include "dpctl/_sycl_device_api.h"
46#include "dpctl/_sycl_event.h"
47#include "dpctl/_sycl_event_api.h"
48#include "dpctl/_sycl_queue.h"
49#include "dpctl/_sycl_queue_api.h"
50#include "dpctl/memory/_memory.h"
51#include "dpctl/memory/_memory_api.h"
52#include "dpctl/program/_program.h"
53#include "dpctl/program/_program_api.h"
54
55// Include generated Cython headers for usm_ndarray
56// (struct definition and constants only)
57#include "dpnp/tensor/_usmarray.h"
58#include "dpnp/tensor/_usmarray_api.h"
59
60#include <array>
61#include <cassert>
62#include <complex>
63#include <cstddef> // for std::size_t for C++ linkage
64#include <cstdint>
65#include <memory>
66#include <stddef.h> // for size_t for C linkage
67#include <stdexcept>
68#include <type_traits>
69#include <utility>
70#include <vector>
71
72#include <pybind11/pybind11.h>
73
74#include <sycl/sycl.hpp>
75
76namespace py = pybind11;
77
78namespace dpctl
79{
80namespace detail
81{
// Select a NumPy typenum for ``Concrete`` by size: the candidate types
// ``T, Ts...`` are paired positionally with the integer arguments, and the
// integer paired with the first candidate whose ``sizeof`` equals
// ``sizeof(Concrete)`` is returned.

// Terminal overload: the candidate list is exhausted, nothing matched.
template <typename Concrete>
constexpr int platform_typeid_lookup()
{
    constexpr int no_match = -1;
    return no_match;
}

// Recursive overload: test the leading candidate ``T`` (paired with ``I``)
// and otherwise continue with the remaining candidates and integers.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_typeid_lookup(int I, Ints... Is)
{
    if (sizeof(Concrete) == sizeof(T)) {
        return I;
    }
    return platform_typeid_lookup<Concrete, Ts...>(Is...);
}
97
99{
100public:
101 // dpctl type objects
102 PyTypeObject *Py_SyclDeviceType_;
103 PyTypeObject *PySyclDeviceType_;
104 PyTypeObject *Py_SyclContextType_;
105 PyTypeObject *PySyclContextType_;
106 PyTypeObject *Py_SyclEventType_;
107 PyTypeObject *PySyclEventType_;
108 PyTypeObject *Py_SyclQueueType_;
109 PyTypeObject *PySyclQueueType_;
110 PyTypeObject *Py_MemoryType_;
111 PyTypeObject *PyMemoryUSMDeviceType_;
112 PyTypeObject *PyMemoryUSMSharedType_;
113 PyTypeObject *PyMemoryUSMHostType_;
114 PyTypeObject *PyUSMArrayType_;
115 PyTypeObject *PySyclProgramType_;
116 PyTypeObject *PySyclKernelType_;
117
118 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
119 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
120
121 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
122 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
123
124 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
125 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
126
127 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
128 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
129
130 // memory
131 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
132 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
133 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
134 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
135 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
136 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
137 size_t,
138 DPCTLSyclQueueRef,
139 PyObject *);
140
141 // program
142 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
143 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);
144
145 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
146 PySyclProgramObject *);
147 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
148
149 int USM_ARRAY_C_CONTIGUOUS_;
150 int USM_ARRAY_F_CONTIGUOUS_;
151 int USM_ARRAY_WRITABLE_;
152 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
153 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
154 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
155 UAR_HALF_;
156 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
157 UAR_INT64_, UAR_UINT64_;
158
159 bool PySyclDevice_Check_(PyObject *obj) const
160 {
161 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
162 }
163 bool PySyclContext_Check_(PyObject *obj) const
164 {
165 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
166 }
167 bool PySyclEvent_Check_(PyObject *obj) const
168 {
169 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
170 }
171 bool PySyclQueue_Check_(PyObject *obj) const
172 {
173 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
174 }
175 bool PySyclKernel_Check_(PyObject *obj) const
176 {
177 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
178 }
179 bool PySyclProgram_Check_(PyObject *obj) const
180 {
181 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
182 }
183
185 {
186 as_usm_memory_.reset();
187 default_usm_ndarray_.reset();
188 default_usm_memory_.reset();
189 default_sycl_queue_.reset();
190 };
191
192 static auto &get()
193 {
194 static dpctl_capi api{};
195 return api;
196 }
197
198 py::object default_sycl_queue_pyobj() { return *default_sycl_queue_; }
199 py::object default_usm_memory_pyobj() { return *default_usm_memory_; }
200 py::object default_usm_ndarray_pyobj() { return *default_usm_ndarray_; }
201 py::object as_usm_memory_pyobj() { return *as_usm_memory_; }
202
203private:
204 struct Deleter
205 {
206 void operator()(py::object *p) const
207 {
208 const bool initialized = Py_IsInitialized();
209#if PY_VERSION_HEX < 0x30d0000
210 const bool finalizing = _Py_IsFinalizing();
211#else
212 const bool finalizing = Py_IsFinalizing();
213#endif
214 const bool guard = initialized && !finalizing;
215
216 if (guard) {
217 delete p;
218 }
219 }
220 };
221
222 std::shared_ptr<py::object> default_sycl_queue_;
223 std::shared_ptr<py::object> default_usm_memory_;
224 std::shared_ptr<py::object> default_usm_ndarray_;
225 std::shared_ptr<py::object> as_usm_memory_;
226
227 dpctl_capi()
228 : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
229 Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
230 Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
231 Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
232 Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
233 PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
234 PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
235 PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
236 SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
237 SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
238 SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
239 SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
240 Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
241 Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
242 Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
243 SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
244 SyclProgram_Make_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
245 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
246 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
247 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
248 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
249 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
250 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
251 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
252 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
253 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
254
255 {
256 // Import dpctl SYCL interface modules
257 // This imports python modules and initializes pointers to Python types
258 import_dpctl___sycl_device();
259 import_dpctl___sycl_context();
260 import_dpctl___sycl_event();
261 import_dpctl___sycl_queue();
262 import_dpctl__memory___memory();
263 import_dpctl__program___program();
264 // Import dpnp tensor module for PyUSMArrayType
265 import_dpnp__tensor___usmarray();
266
267 // Python type objects for classes implemented by dpctl
268 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
269 this->PySyclDeviceType_ = &PySyclDeviceType;
270 this->Py_SyclContextType_ = &Py_SyclContextType;
271 this->PySyclContextType_ = &PySyclContextType;
272 this->Py_SyclEventType_ = &Py_SyclEventType;
273 this->PySyclEventType_ = &PySyclEventType;
274 this->Py_SyclQueueType_ = &Py_SyclQueueType;
275 this->PySyclQueueType_ = &PySyclQueueType;
276 this->Py_MemoryType_ = &Py_MemoryType;
277 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
278 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
279 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
280 this->PyUSMArrayType_ = &PyUSMArrayType;
281 this->PySyclProgramType_ = &PySyclProgramType;
282 this->PySyclKernelType_ = &PySyclKernelType;
283
284 // SyclDevice API
285 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
286 this->SyclDevice_Make_ = SyclDevice_Make;
287
288 // SyclContext API
289 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
290 this->SyclContext_Make_ = SyclContext_Make;
291
292 // SyclEvent API
293 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
294 this->SyclEvent_Make_ = SyclEvent_Make;
295
296 // SyclQueue API
297 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
298 this->SyclQueue_Make_ = SyclQueue_Make;
299
300 // dpctl.memory API
301 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
302 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
303 this->Memory_GetContextRef_ = Memory_GetContextRef;
304 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
305 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
306 this->Memory_Make_ = Memory_Make;
307
308 // dpctl.program API
309 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
310 this->SyclKernel_Make_ = SyclKernel_Make;
311 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
312 this->SyclProgram_Make_ = SyclProgram_Make;
313
314 // constants
315 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
316 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
317 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
318 this->UAR_BOOL_ = UAR_BOOL;
319 this->UAR_BYTE_ = UAR_BYTE;
320 this->UAR_UBYTE_ = UAR_UBYTE;
321 this->UAR_SHORT_ = UAR_SHORT;
322 this->UAR_USHORT_ = UAR_USHORT;
323 this->UAR_INT_ = UAR_INT;
324 this->UAR_UINT_ = UAR_UINT;
325 this->UAR_LONG_ = UAR_LONG;
326 this->UAR_ULONG_ = UAR_ULONG;
327 this->UAR_LONGLONG_ = UAR_LONGLONG;
328 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
329 this->UAR_FLOAT_ = UAR_FLOAT;
330 this->UAR_DOUBLE_ = UAR_DOUBLE;
331 this->UAR_CFLOAT_ = UAR_CFLOAT;
332 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
333 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
334 this->UAR_HALF_ = UAR_HALF;
335
336 // deduced disjoint types
337 this->UAR_INT8_ = UAR_BYTE;
338 this->UAR_UINT8_ = UAR_UBYTE;
339 this->UAR_INT16_ = UAR_SHORT;
340 this->UAR_UINT16_ = UAR_USHORT;
341 this->UAR_INT32_ =
342 platform_typeid_lookup<std::int32_t, long, int, short>(
343 UAR_LONG, UAR_INT, UAR_SHORT);
344 this->UAR_UINT32_ =
345 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
346 unsigned short>(UAR_ULONG, UAR_UINT,
347 UAR_USHORT);
348 this->UAR_INT64_ =
349 platform_typeid_lookup<std::int64_t, long, long long, int>(
350 UAR_LONG, UAR_LONGLONG, UAR_INT);
351 this->UAR_UINT64_ =
352 platform_typeid_lookup<std::uint64_t, unsigned long,
353 unsigned long long, unsigned int>(
354 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
355
356 // create shared pointers to python objects used in type-casters
357 // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
358 sycl::queue q_{};
359 PySyclQueueObject *py_q_tmp =
360 SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
361 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
362 reinterpret_cast<PyObject *>(py_q_tmp));
363
364 default_sycl_queue_ = std::shared_ptr<py::object>(
365 new py::object(py_sycl_queue), Deleter{});
366
367 py::module_ mod_memory = py::module_::import("dpctl.memory");
368 const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
369 as_usm_memory_ = std::shared_ptr<py::object>(
370 new py::object{py_as_usm_memory}, Deleter{});
371
372 auto mem_kl = mod_memory.attr("MemoryUSMHost");
373 const py::object &py_default_usm_memory =
374 mem_kl(1, py::arg("queue") = py_sycl_queue);
375 default_usm_memory_ = std::shared_ptr<py::object>(
376 new py::object{py_default_usm_memory}, Deleter{});
377
378 py::module_ mod_usmarray = py::module_::import("dpnp.tensor._usmarray");
379 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
380
381 const py::object &py_default_usm_ndarray =
382 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
383 py::arg("buffer") = py_default_usm_memory);
384
385 default_usm_ndarray_ = std::shared_ptr<py::object>(
386 new py::object{py_default_usm_ndarray}, Deleter{});
387 }
388
389 dpctl_capi(dpctl_capi const &) = default;
390 dpctl_capi &operator=(dpctl_capi const &) = default;
391 dpctl_capi &operator=(dpctl_capi &&) = default;
392
393}; // struct dpctl_capi
394} // namespace detail
395} // namespace dpctl
396
397namespace pybind11::detail
398{
/* DPCTL_TYPE_CASTER(type, py_name) injects the common members of a dpctl
 * type caster into the enclosing ``type_caster`` specialization:
 *   - ``value``: a std::unique_ptr<type> holding the loaded C++ object;
 *   - ``name``: the Python-facing type name ``py_name``;
 *   - a pointer overload of ``cast`` that returns None for a null pointer,
 *     deletes ``src`` after casting when the policy is ``take_ownership``,
 *     and otherwise forwards to the by-value ``cast`` of the specialization;
 *   - conversion operators to ``type *``, ``type &`` and ``type &&``;
 *   - ``cast_op_type`` as pybind11's ``movable_cast_op_type``.
 * The macro leaves the access level ``public``.
 */
#define DPCTL_TYPE_CASTER(type, py_name)                                       \
protected:                                                                     \
    std::unique_ptr<type> value;                                               \
                                                                               \
public:                                                                        \
    static constexpr auto name = py_name;                                      \
    template <                                                                 \
        typename T_,                                                           \
        ::pybind11::detail::enable_if_t<                                       \
            std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value,    \
            int> = 0>                                                          \
    static ::pybind11::handle cast(T_ *src,                                    \
                                   ::pybind11::return_value_policy policy,     \
                                   ::pybind11::handle parent)                  \
    {                                                                          \
        if (!src)                                                              \
            return ::pybind11::none().release();                               \
        if (policy == ::pybind11::return_value_policy::take_ownership) {       \
            auto h = cast(std::move(*src), policy, parent);                    \
            delete src;                                                        \
            return h;                                                          \
        }                                                                      \
        return cast(*src, policy, parent);                                     \
    }                                                                          \
    operator type *()                                                          \
    {                                                                          \
        return value.get();                                                    \
    } /* NOLINT(bugprone-macro-parentheses) */                                 \
    operator type &()                                                          \
    {                                                                          \
        return *value;                                                         \
    } /* NOLINT(bugprone-macro-parentheses) */                                 \
    operator type &&() &&                                                      \
    {                                                                          \
        return std::move(*value);                                              \
    } /* NOLINT(bugprone-macro-parentheses) */                                 \
    template <typename T_>                                                     \
    using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
437
438/* This type caster associates ``sycl::queue`` C++ class with
439 * :class:`dpctl.SyclQueue` for the purposes of generation of
440 * Python bindings by pybind11.
441 */
442template <>
443struct type_caster<sycl::queue>
444{
445public:
446 bool load(handle src, bool)
447 {
448 PyObject *source = src.ptr();
449 auto const &api = ::dpctl::detail::dpctl_capi::get();
450 if (api.PySyclQueue_Check_(source)) {
451 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
452 reinterpret_cast<PySyclQueueObject *>(source));
453 value = std::make_unique<sycl::queue>(
454 *(reinterpret_cast<sycl::queue *>(QRef)));
455 return true;
456 }
457 else {
458 throw py::type_error(
459 "Input is of unexpected type, expected dpctl.SyclQueue");
460 }
461 }
462
463 static handle cast(sycl::queue src, return_value_policy, handle)
464 {
465 auto const &api = ::dpctl::detail::dpctl_capi::get();
466 auto tmp =
467 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
468 return handle(reinterpret_cast<PyObject *>(tmp));
469 }
470
471 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
472};
473
474/* This type caster associates ``sycl::device`` C++ class with
475 * :class:`dpctl.SyclDevice` for the purposes of generation of
476 * Python bindings by pybind11.
477 */
478template <>
479struct type_caster<sycl::device>
480{
481public:
482 bool load(handle src, bool)
483 {
484 PyObject *source = src.ptr();
485 auto const &api = ::dpctl::detail::dpctl_capi::get();
486 if (api.PySyclDevice_Check_(source)) {
487 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
488 reinterpret_cast<PySyclDeviceObject *>(source));
489 value = std::make_unique<sycl::device>(
490 *(reinterpret_cast<sycl::device *>(DRef)));
491 return true;
492 }
493 else {
494 throw py::type_error(
495 "Input is of unexpected type, expected dpctl.SyclDevice");
496 }
497 }
498
499 static handle cast(sycl::device src, return_value_policy, handle)
500 {
501 auto const &api = ::dpctl::detail::dpctl_capi::get();
502 auto tmp =
503 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
504 return handle(reinterpret_cast<PyObject *>(tmp));
505 }
506
507 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
508};
509
510/* This type caster associates ``sycl::context`` C++ class with
511 * :class:`dpctl.SyclContext` for the purposes of generation of
512 * Python bindings by pybind11.
513 */
514template <>
515struct type_caster<sycl::context>
516{
517public:
518 bool load(handle src, bool)
519 {
520 PyObject *source = src.ptr();
521 auto const &api = ::dpctl::detail::dpctl_capi::get();
522 if (api.PySyclContext_Check_(source)) {
523 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
524 reinterpret_cast<PySyclContextObject *>(source));
525 value = std::make_unique<sycl::context>(
526 *(reinterpret_cast<sycl::context *>(CRef)));
527 return true;
528 }
529 else {
530 throw py::type_error(
531 "Input is of unexpected type, expected dpctl.SyclContext");
532 }
533 }
534
535 static handle cast(sycl::context src, return_value_policy, handle)
536 {
537 auto const &api = ::dpctl::detail::dpctl_capi::get();
538 auto tmp =
539 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
540 return handle(reinterpret_cast<PyObject *>(tmp));
541 }
542
543 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
544};
545
546/* This type caster associates ``sycl::event`` C++ class with
547 * :class:`dpctl.SyclEvent` for the purposes of generation of
548 * Python bindings by pybind11.
549 */
550template <>
551struct type_caster<sycl::event>
552{
553public:
554 bool load(handle src, bool)
555 {
556 PyObject *source = src.ptr();
557 auto const &api = ::dpctl::detail::dpctl_capi::get();
558 if (api.PySyclEvent_Check_(source)) {
559 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
560 reinterpret_cast<PySyclEventObject *>(source));
561 value = std::make_unique<sycl::event>(
562 *(reinterpret_cast<sycl::event *>(ERef)));
563 return true;
564 }
565 else {
566 throw py::type_error(
567 "Input is of unexpected type, expected dpctl.SyclEvent");
568 }
569 }
570
571 static handle cast(sycl::event src, return_value_policy, handle)
572 {
573 auto const &api = ::dpctl::detail::dpctl_capi::get();
574 auto tmp =
575 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
576 return handle(reinterpret_cast<PyObject *>(tmp));
577 }
578
579 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
580};
581
582/* This type caster associates ``sycl::kernel`` C++ class with
583 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
584 * Python bindings by pybind11.
585 */
586template <>
587struct type_caster<sycl::kernel>
588{
589public:
590 bool load(handle src, bool)
591 {
592 PyObject *source = src.ptr();
593 auto const &api = ::dpctl::detail::dpctl_capi::get();
594 if (api.PySyclKernel_Check_(source)) {
595 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
596 reinterpret_cast<PySyclKernelObject *>(source));
597 value = std::make_unique<sycl::kernel>(
598 *(reinterpret_cast<sycl::kernel *>(KRef)));
599 return true;
600 }
601 else {
602 throw py::type_error("Input is of unexpected type, expected "
603 "dpctl.program.SyclKernel");
604 }
605 }
606
607 static handle cast(sycl::kernel src, return_value_policy, handle)
608 {
609 auto const &api = ::dpctl::detail::dpctl_capi::get();
610 auto tmp =
611 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
612 "dpctl4pybind11_kernel");
613 return handle(reinterpret_cast<PyObject *>(tmp));
614 }
615
616 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
617};
618
619/* This type caster associates
620 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
621 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
622 * Python bindings by pybind11.
623 */
624template <>
625struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
626{
627public:
628 bool load(handle src, bool)
629 {
630 PyObject *source = src.ptr();
631 auto const &api = ::dpctl::detail::dpctl_capi::get();
632 if (api.PySyclProgram_Check_(source)) {
633 DPCTLSyclKernelBundleRef KBRef =
634 api.SyclProgram_GetKernelBundleRef_(
635 reinterpret_cast<PySyclProgramObject *>(source));
636 value = std::make_unique<
637 sycl::kernel_bundle<sycl::bundle_state::executable>>(
638 *(reinterpret_cast<
639 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
640 KBRef)));
641 return true;
642 }
643 else {
644 throw py::type_error("Input is of unexpected type, expected "
645 "dpctl.program.SyclProgram");
646 }
647 }
648
649 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
650 return_value_policy,
651 handle)
652 {
653 auto const &api = ::dpctl::detail::dpctl_capi::get();
654 auto tmp = api.SyclProgram_Make_(
655 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
656 return handle(reinterpret_cast<PyObject *>(tmp));
657 }
658
659 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
660 _("dpctl.program.SyclProgram"));
661};
662
663/* This type caster associates
664 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
665 * of generation of Python bindings by pybind11.
666 */
667template <>
668struct type_caster<sycl::half>
669{
670public:
671 bool load(handle src, bool convert)
672 {
673 double py_value;
674
675 if (!src) {
676 return false;
677 }
678
679 PyObject *source = src.ptr();
680
681 if (convert || PyFloat_Check(source)) {
682 py_value = PyFloat_AsDouble(source);
683 }
684 else {
685 return false;
686 }
687
688 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
689
690 if (py_err) {
691 PyErr_Clear();
692 if (convert && (PyNumber_Check(source) != 0)) {
693 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
694 return load(tmp, false);
695 }
696 return false;
697 }
698 value = static_cast<sycl::half>(py_value);
699 return true;
700 }
701
702 static handle cast(sycl::half src, return_value_policy, handle)
703 {
704 return PyFloat_FromDouble(static_cast<double>(src));
705 }
706
707 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
708};
709} // namespace pybind11::detail
710
711namespace dpctl
712{
713namespace memory
714{
715// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
716// this allows to avoid compilation error
717using pybind11::error_already_set;
718
719class usm_memory : public py::object
720{
721public:
722 PYBIND11_OBJECT_CVT(
724 py::object,
725 [](PyObject *o) -> bool {
726 return PyObject_TypeCheck(
727 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
728 0;
729 },
730 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
731
732 usm_memory()
733 : py::object(
734 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
735 borrowed_t{})
736 {
737 if (!m_ptr)
738 throw py::error_already_set();
739 }
740
744 usm_memory(void *usm_ptr,
745 std::size_t nbytes,
746 const sycl::queue &q,
747 std::shared_ptr<void> shptr)
748 {
749 auto const &api = ::dpctl::detail::dpctl_capi::get();
750 DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
751 auto q_uptr = std::make_unique<sycl::queue>(q);
752 DPCTLSyclQueueRef QRef =
753 reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());
754
755 auto vacuous_destructor = []() {};
756 py::capsule mock_owner(vacuous_destructor);
757
758 // create memory object owned by mock_owner, it is a new reference
759 PyObject *_memory =
760 api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
761 auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };
762
763 using py_uptrT =
764 std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;
765
766 if (!_memory) {
767 throw py::error_already_set();
768 }
769
770 auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
771 std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);
772
773 Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
774 // replace mock_owner capsule as the owner
775 memobj->refobj = Py_None;
776 // set opaque ptr field, usm_memory now knowns that USM is managed
777 // by smart pointer
778 memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);
779
780 // _memory will delete created copies of sycl::queue, and
781 // std::shared_ptr and the deleter of the shared_ptr<void> is
782 // supposed to free the USM allocation
783 m_ptr = _memory;
784 q_uptr.release();
785 memory_uptr.release();
786 }
787
788 sycl::queue get_queue() const
789 {
790 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
791 auto const &api = ::dpctl::detail::dpctl_capi::get();
792 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
793 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
794 return *obj_q;
795 }
796
797 char *get_pointer() const
798 {
799 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
800 auto const &api = ::dpctl::detail::dpctl_capi::get();
801 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
802 return reinterpret_cast<char *>(MRef);
803 }
804
805 std::size_t get_nbytes() const
806 {
807 auto const &api = ::dpctl::detail::dpctl_capi::get();
808 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
809 return api.Memory_GetNumBytes_(mem_obj);
810 }
811
812 bool is_managed_by_smart_ptr() const
813 {
814 auto const &api = ::dpctl::detail::dpctl_capi::get();
815 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
816 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
817
818 return bool(opaque_ptr);
819 }
820
821 const std::shared_ptr<void> &get_smart_ptr_owner() const
822 {
823 auto const &api = ::dpctl::detail::dpctl_capi::get();
824 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
825 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
826
827 if (opaque_ptr) {
828 auto shptr_ptr =
829 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
830 return *shptr_ptr;
831 }
832 else {
833 throw std::runtime_error(
834 "Memory object does not have smart pointer "
835 "managing lifetime of USM allocation");
836 }
837 }
838
839protected:
840 static PyObject *as_usm_memory(PyObject *o)
841 {
842 if (o == nullptr) {
843 PyErr_SetString(PyExc_ValueError,
844 "cannot create a usm_memory from a nullptr");
845 return nullptr;
846 }
847
848 auto converter =
849 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
850
851 py::object res;
852 try {
853 res = converter(py::handle(o));
854 } catch (const py::error_already_set &e) {
855 return nullptr;
856 }
857 return res.ptr();
858 }
859};
860} // end namespace memory
861
862namespace tensor
863{
864inline std::vector<py::ssize_t>
865 c_contiguous_strides(int nd,
866 const py::ssize_t *shape,
867 py::ssize_t element_size = 1)
868{
869 if (nd > 0) {
870 std::vector<py::ssize_t> c_strides(nd, element_size);
871 for (int ic = nd - 1; ic > 0;) {
872 py::ssize_t next_v = c_strides[ic] * shape[ic];
873 c_strides[--ic] = next_v;
874 }
875 return c_strides;
876 }
877 else {
878 return std::vector<py::ssize_t>();
879 }
880}
881
882inline std::vector<py::ssize_t>
883 f_contiguous_strides(int nd,
884 const py::ssize_t *shape,
885 py::ssize_t element_size = 1)
886{
887 if (nd > 0) {
888 std::vector<py::ssize_t> f_strides(nd, element_size);
889 for (int i = 0; i < nd - 1;) {
890 py::ssize_t next_v = f_strides[i] * shape[i];
891 f_strides[++i] = next_v;
892 }
893 return f_strides;
894 }
895 else {
896 return std::vector<py::ssize_t>();
897 }
898}
899
900inline std::vector<py::ssize_t>
901 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
902 py::ssize_t element_size = 1)
903{
904 return c_contiguous_strides(shape.size(), shape.data(), element_size);
905}
906
907inline std::vector<py::ssize_t>
908 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
909 py::ssize_t element_size = 1)
910{
911 return f_contiguous_strides(shape.size(), shape.data(), element_size);
912}
913
914class usm_ndarray : public py::object
915{
916public:
917 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
918 return PyObject_TypeCheck(
919 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
920 })
921
923 : py::object(
924 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
925 borrowed_t{})
926 {
927 if (!m_ptr)
928 throw py::error_already_set();
929 }
930
931 char *get_data() const
932 {
933 PyUSMArrayObject *raw_ar = usm_array_ptr();
934 return raw_ar->data_;
935 }
936
937 template <typename T>
938 T *get_data() const
939 {
940 return reinterpret_cast<T *>(get_data());
941 }
942
943 int get_ndim() const
944 {
945 PyUSMArrayObject *raw_ar = usm_array_ptr();
946 return raw_ar->nd_;
947 }
948
949 const py::ssize_t *get_shape_raw() const
950 {
951 PyUSMArrayObject *raw_ar = usm_array_ptr();
952 return raw_ar->shape_;
953 }
954
955 std::vector<py::ssize_t> get_shape_vector() const
956 {
957 auto raw_sh = get_shape_raw();
958 auto nd = get_ndim();
959
960 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
961 return shape_vector;
962 }
963
964 py::ssize_t get_shape(int i) const
965 {
966 auto shape_ptr = get_shape_raw();
967 return shape_ptr[i];
968 }
969
970 const py::ssize_t *get_strides_raw() const
971 {
972 PyUSMArrayObject *raw_ar = usm_array_ptr();
973 return raw_ar->strides_;
974 }
975
976 std::vector<py::ssize_t> get_strides_vector() const
977 {
978 auto raw_st = get_strides_raw();
979 auto nd = get_ndim();
980
981 if (raw_st == nullptr) {
982 auto is_c_contig = is_c_contiguous();
983 auto is_f_contig = is_f_contiguous();
984 auto raw_sh = get_shape_raw();
985 if (is_c_contig) {
986 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
987 return contig_strides;
988 }
989 else if (is_f_contig) {
990 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
991 return contig_strides;
992 }
993 else {
994 throw std::runtime_error("Invalid array encountered when "
995 "building strides");
996 }
997 }
998 else {
999 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
1000 return st_vec;
1001 }
1002 }
1003
1004 py::ssize_t get_size() const
1005 {
1006 PyUSMArrayObject *raw_ar = usm_array_ptr();
1007
1008 int ndim = raw_ar->nd_;
1009 const py::ssize_t *shape = raw_ar->shape_;
1010
1011 py::ssize_t nelems = 1;
1012 for (int i = 0; i < ndim; ++i) {
1013 nelems *= shape[i];
1014 }
1015
1016 assert(nelems >= 0);
1017 return nelems;
1018 }
1019
1020 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
1021 {
1022 PyUSMArrayObject *raw_ar = usm_array_ptr();
1023
1024 int nd = raw_ar->nd_;
1025 const py::ssize_t *shape = raw_ar->shape_;
1026 const py::ssize_t *strides = raw_ar->strides_;
1027
1028 py::ssize_t offset_min = 0;
1029 py::ssize_t offset_max = 0;
1030 if (strides == nullptr) {
1031 py::ssize_t stride(1);
1032 for (int i = 0; i < nd; ++i) {
1033 offset_max += stride * (shape[i] - 1);
1034 stride *= shape[i];
1035 }
1036 }
1037 else {
1038 for (int i = 0; i < nd; ++i) {
1039 py::ssize_t delta = strides[i] * (shape[i] - 1);
1040 if (strides[i] > 0) {
1041 offset_max += delta;
1042 }
1043 else {
1044 offset_min += delta;
1045 }
1046 }
1047 }
1048 return std::make_pair(offset_min, offset_max);
1049 }
1050
1051 sycl::queue get_queue() const
1052 {
1053 PyUSMArrayObject *raw_ar = usm_array_ptr();
1054 Py_MemoryObject *mem_obj =
1055 reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
1056
1057 auto const &api = ::dpctl::detail::dpctl_capi::get();
1058 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1059 return *(reinterpret_cast<sycl::queue *>(QRef));
1060 }
1061
1062 sycl::device get_device() const
1063 {
1064 PyUSMArrayObject *raw_ar = usm_array_ptr();
1065 Py_MemoryObject *mem_obj =
1066 reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
1067
1068 auto const &api = ::dpctl::detail::dpctl_capi::get();
1069 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1070 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1071 }
1072
1073 int get_typenum() const
1074 {
1075 PyUSMArrayObject *raw_ar = usm_array_ptr();
1076 return raw_ar->typenum_;
1077 }
1078
1079 int get_flags() const
1080 {
1081 PyUSMArrayObject *raw_ar = usm_array_ptr();
1082 return raw_ar->flags_;
1083 }
1084
1085 int get_elemsize() const
1086 {
1087 int typenum = get_typenum();
1088 auto const &api = ::dpctl::detail::dpctl_capi::get();
1089
1090 // Lookup table for element sizes based on typenum
1091 if (typenum == api.UAR_BOOL_)
1092 return 1;
1093 if (typenum == api.UAR_BYTE_)
1094 return 1;
1095 if (typenum == api.UAR_UBYTE_)
1096 return 1;
1097 if (typenum == api.UAR_SHORT_)
1098 return 2;
1099 if (typenum == api.UAR_USHORT_)
1100 return 2;
1101 if (typenum == api.UAR_INT_)
1102 return 4;
1103 if (typenum == api.UAR_UINT_)
1104 return 4;
1105 if (typenum == api.UAR_LONG_)
1106 return sizeof(long);
1107 if (typenum == api.UAR_ULONG_)
1108 return sizeof(unsigned long);
1109 if (typenum == api.UAR_LONGLONG_)
1110 return 8;
1111 if (typenum == api.UAR_ULONGLONG_)
1112 return 8;
1113 if (typenum == api.UAR_FLOAT_)
1114 return 4;
1115 if (typenum == api.UAR_DOUBLE_)
1116 return 8;
1117 if (typenum == api.UAR_CFLOAT_)
1118 return 8;
1119 if (typenum == api.UAR_CDOUBLE_)
1120 return 16;
1121 if (typenum == api.UAR_HALF_)
1122 return 2;
1123
1124 return 0; // Unknown type
1125 }
1126
1127 bool is_c_contiguous() const
1128 {
1129 int flags = get_flags();
1130 auto const &api = ::dpctl::detail::dpctl_capi::get();
1131 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1132 }
1133
1134 bool is_f_contiguous() const
1135 {
1136 int flags = get_flags();
1137 auto const &api = ::dpctl::detail::dpctl_capi::get();
1138 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1139 }
1140
1141 bool is_writable() const
1142 {
1143 int flags = get_flags();
1144 auto const &api = ::dpctl::detail::dpctl_capi::get();
1145 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1146 }
1147
1149 py::object get_usm_data() const
1150 {
1151 PyUSMArrayObject *raw_ar = usm_array_ptr();
1152 // base_ is the Memory object - return new reference
1153 PyObject *usm_data = raw_ar->base_;
1154 Py_XINCREF(usm_data);
1155
1156 // pass reference ownership to py::object
1157 return py::reinterpret_steal<py::object>(usm_data);
1158 }
1159
1160 bool is_managed_by_smart_ptr() const
1161 {
1162 PyUSMArrayObject *raw_ar = usm_array_ptr();
1163 PyObject *usm_data = raw_ar->base_;
1164
1165 auto const &api = ::dpctl::detail::dpctl_capi::get();
1166 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1167 return false;
1168 }
1169
1170 Py_MemoryObject *mem_obj =
1171 reinterpret_cast<Py_MemoryObject *>(usm_data);
1172 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1173
1174 return bool(opaque_ptr);
1175 }
1176
1177 const std::shared_ptr<void> &get_smart_ptr_owner() const
1178 {
1179 PyUSMArrayObject *raw_ar = usm_array_ptr();
1180 PyObject *usm_data = raw_ar->base_;
1181
1182 auto const &api = ::dpctl::detail::dpctl_capi::get();
1183
1184 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1185 throw std::runtime_error(
1186 "usm_ndarray object does not have Memory object "
1187 "managing lifetime of USM allocation");
1188 }
1189
1190 Py_MemoryObject *mem_obj =
1191 reinterpret_cast<Py_MemoryObject *>(usm_data);
1192 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1193
1194 if (opaque_ptr) {
1195 auto shptr_ptr =
1196 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1197 return *shptr_ptr;
1198 }
1199 else {
1200 throw std::runtime_error(
1201 "Memory object underlying usm_ndarray does not have "
1202 "smart pointer managing lifetime of USM allocation");
1203 }
1204 }
1205
1206private:
1207 PyUSMArrayObject *usm_array_ptr() const
1208 {
1209 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1210 }
1211};
1212} // end namespace tensor
1213
1214namespace utils
1215{
1216namespace detail
1217{
1219{
1220
1221 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1222 {
1223 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1224 const auto &usm_memory_inst =
1225 py::cast<dpctl::memory::usm_memory>(h);
1226 return usm_memory_inst.is_managed_by_smart_ptr();
1227 }
1228 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1229 const auto &usm_array_inst =
1230 py::cast<dpctl::tensor::usm_ndarray>(h);
1231 return usm_array_inst.is_managed_by_smart_ptr();
1232 }
1233
1234 return false;
1235 }
1236
1237 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1238 {
1239 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1240 const auto &usm_memory_inst =
1241 py::cast<dpctl::memory::usm_memory>(h);
1242 return usm_memory_inst.get_smart_ptr_owner();
1243 }
1244 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1245 const auto &usm_array_inst =
1246 py::cast<dpctl::tensor::usm_ndarray>(h);
1247 return usm_array_inst.get_smart_ptr_owner();
1248 }
1249
1250 throw std::runtime_error(
1251 "Attempted extraction of shared_ptr on an unrecognized type");
1252 }
1253};
1254} // end of namespace detail
1255
1256template <std::size_t num>
1257sycl::event keep_args_alive(sycl::queue &q,
1258 const py::object (&py_objs)[num],
1259 const std::vector<sycl::event> &depends = {})
1260{
1261 std::size_t n_objects_held = 0;
1262 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1263
1264 std::size_t n_usm_owners_held = 0;
1265 std::array<std::shared_ptr<void>, num> shp_usm{};
1266
1267 for (std::size_t i = 0; i < num; ++i) {
1268 const auto &py_obj_i = py_objs[i];
1269 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1270 const auto &shp =
1271 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1272 shp_usm[n_usm_owners_held] = shp;
1273 ++n_usm_owners_held;
1274 }
1275 else {
1276 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1277 shp_arr[n_objects_held]->inc_ref();
1278 ++n_objects_held;
1279 }
1280 }
1281
1282 bool use_depends = true;
1283 sycl::event host_task_ev;
1284
1285 if (n_usm_owners_held > 0) {
1286 host_task_ev = q.submit([&](sycl::handler &cgh) {
1287 if (use_depends) {
1288 cgh.depends_on(depends);
1289 use_depends = false;
1290 }
1291 else {
1292 cgh.depends_on(host_task_ev);
1293 }
1294 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1295 // no body, but shared pointers are captured in
1296 // the lambda, ensuring that USM allocation is
1297 // kept alive
1298 });
1299 });
1300 }
1301
1302 if (n_objects_held > 0) {
1303 host_task_ev = q.submit([&](sycl::handler &cgh) {
1304 if (use_depends) {
1305 cgh.depends_on(depends);
1306 use_depends = false;
1307 }
1308 else {
1309 cgh.depends_on(host_task_ev);
1310 }
1311 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1312 py::gil_scoped_acquire acquire;
1313
1314 for (std::size_t i = 0; i < n_objects_held; ++i) {
1315 shp_arr[i]->dec_ref();
1316 }
1317 });
1318 });
1319 }
1320
1321 return host_task_ev;
1322}
1323
1326template <std::size_t num>
1327bool queues_are_compatible(const sycl::queue &exec_q,
1328 const sycl::queue (&alloc_qs)[num])
1329{
1330 for (std::size_t i = 0; i < num; ++i) {
1331
1332 if (exec_q != alloc_qs[i]) {
1333 return false;
1334 }
1335 }
1336 return true;
1337}
1338
1341template <std::size_t num>
1342bool queues_are_compatible(const sycl::queue &exec_q,
1343 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1344{
1345 for (std::size_t i = 0; i < num; ++i) {
1346
1347 if (exec_q != arrs[i].get_queue()) {
1348 return false;
1349 }
1350 }
1351 return true;
1352}
1353} // end namespace utils
1354} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.