DPNP C++ backend kernel library 0.20.0dev3
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// Include dpctl SYCL interface from external dpctl package
32#include "syclinterface/dpctl_sycl_extension_interface.h"
33#include "syclinterface/dpctl_sycl_types.h"
34
35#ifdef __cplusplus
36#define CYTHON_EXTERN_C extern "C"
37#else
38#define CYTHON_EXTERN_C
39#endif
40
41// Include dpctl C-API headers (both declarations and import functions)
42#include "dpctl/_sycl_context.h"
43#include "dpctl/_sycl_context_api.h"
44#include "dpctl/_sycl_device.h"
45#include "dpctl/_sycl_device_api.h"
46#include "dpctl/_sycl_event.h"
47#include "dpctl/_sycl_event_api.h"
48#include "dpctl/_sycl_queue.h"
49#include "dpctl/_sycl_queue_api.h"
50#include "dpctl/memory/_memory.h"
51#include "dpctl/memory/_memory_api.h"
52#include "dpctl/program/_program.h"
53#include "dpctl/program/_program_api.h"
54
55// Include generated Cython headers for usm_ndarray
56// (struct definition and constants only)
57#include "dpnp/tensor/_usmarray.h"
58#include "dpnp/tensor/_usmarray_api.h"
59
60#include <array>
61#include <cassert>
62#include <complex>
63#include <cstddef> // for std::size_t for C++ linkage
64#include <cstdint>
65#include <memory>
66#include <stddef.h> // for size_t for C linkage
67#include <stdexcept>
68#include <type_traits>
69#include <utility>
70#include <vector>
71
72#include <pybind11/pybind11.h>
73
74#include <sycl/sycl.hpp>
75
76namespace py = pybind11;
77
78namespace dpctl
79{
80namespace detail
81{
// Look up a type by its size and return the matching typenum.
//
// `Concrete` is compared by `sizeof` against each candidate type in order;
// the integer passed in the position of the first candidate with the same
// size is returned.

// Terminal case: no candidates are left, so nothing matched. Returns -1.
template <typename Concrete>
constexpr int platform_typeid_lookup() noexcept
{
    return -1;
}

// Recursive case: test the leading candidate `T` (paired with `I`), and
// otherwise continue with the remaining candidates `Ts...` / `Is...`.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_typeid_lookup(int I, Ints... Is) noexcept
{
    if (sizeof(Concrete) == sizeof(T)) {
        return I;
    }
    return platform_typeid_lookup<Concrete, Ts...>(Is...);
}
97
99{
100public:
101 // dpctl type objects
102 PyTypeObject *Py_SyclDeviceType_;
103 PyTypeObject *PySyclDeviceType_;
104 PyTypeObject *Py_SyclContextType_;
105 PyTypeObject *PySyclContextType_;
106 PyTypeObject *Py_SyclEventType_;
107 PyTypeObject *PySyclEventType_;
108 PyTypeObject *Py_SyclQueueType_;
109 PyTypeObject *PySyclQueueType_;
110 PyTypeObject *Py_MemoryType_;
111 PyTypeObject *PyMemoryUSMDeviceType_;
112 PyTypeObject *PyMemoryUSMSharedType_;
113 PyTypeObject *PyMemoryUSMHostType_;
114 PyTypeObject *PyUSMArrayType_;
115 PyTypeObject *PySyclProgramType_;
116 PyTypeObject *PySyclKernelType_;
117
118 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
119 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
120
121 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
122 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
123
124 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
125 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
126
127 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
128 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
129
130 // memory
131 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
132 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
133 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
134 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
135 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
136 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
137 size_t,
138 DPCTLSyclQueueRef,
139 PyObject *);
140
141 // program
142 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
143 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);
144
145 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
146 PySyclProgramObject *);
147 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
148
149 int USM_ARRAY_C_CONTIGUOUS_;
150 int USM_ARRAY_F_CONTIGUOUS_;
151 int USM_ARRAY_WRITABLE_;
152 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
153 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
154 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
155 UAR_HALF_;
156 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
157 UAR_INT64_, UAR_UINT64_;
158
159 bool PySyclDevice_Check_(PyObject *obj) const
160 {
161 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
162 }
163 bool PySyclContext_Check_(PyObject *obj) const
164 {
165 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
166 }
167 bool PySyclEvent_Check_(PyObject *obj) const
168 {
169 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
170 }
171 bool PySyclQueue_Check_(PyObject *obj) const
172 {
173 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
174 }
175 bool PySyclKernel_Check_(PyObject *obj) const
176 {
177 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
178 }
179 bool PySyclProgram_Check_(PyObject *obj) const
180 {
181 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
182 }
183
185 {
186 as_usm_memory_.reset();
187 default_usm_ndarray_.reset();
188 default_usm_memory_.reset();
189 default_sycl_queue_.reset();
190 };
191
192 static auto &get()
193 {
194 static dpctl_capi api{};
195 return api;
196 }
197
198 py::object default_sycl_queue_pyobj()
199 {
200 return *default_sycl_queue_;
201 }
202 py::object default_usm_memory_pyobj()
203 {
204 return *default_usm_memory_;
205 }
206 py::object default_usm_ndarray_pyobj()
207 {
208 return *default_usm_ndarray_;
209 }
210 py::object as_usm_memory_pyobj()
211 {
212 return *as_usm_memory_;
213 }
214
215private:
216 struct Deleter
217 {
218 void operator()(py::object *p) const
219 {
220 const bool initialized = Py_IsInitialized();
221#if PY_VERSION_HEX < 0x30d0000
222 const bool finalizing = _Py_IsFinalizing();
223#else
224 const bool finalizing = Py_IsFinalizing();
225#endif
226 const bool guard = initialized && !finalizing;
227
228 if (guard) {
229 delete p;
230 }
231 }
232 };
233
234 std::shared_ptr<py::object> default_sycl_queue_;
235 std::shared_ptr<py::object> default_usm_memory_;
236 std::shared_ptr<py::object> default_usm_ndarray_;
237 std::shared_ptr<py::object> as_usm_memory_;
238
239 dpctl_capi()
240 : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
241 Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
242 Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
243 Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
244 Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
245 PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
246 PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
247 PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
248 SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
249 SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
250 SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
251 SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
252 Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
253 Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
254 Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
255 SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
256 SyclProgram_Make_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
257 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
258 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
259 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
260 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
261 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
262 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
263 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
264 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
265 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
266
267 {
268 // Import dpctl SYCL interface modules
269 // This imports python modules and initializes pointers to Python types
270 import_dpctl___sycl_device();
271 import_dpctl___sycl_context();
272 import_dpctl___sycl_event();
273 import_dpctl___sycl_queue();
274 import_dpctl__memory___memory();
275 import_dpctl__program___program();
276 // Import dpnp tensor module for PyUSMArrayType
277 import_dpnp__tensor___usmarray();
278
279 // Python type objects for classes implemented by dpctl
280 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
281 this->PySyclDeviceType_ = &PySyclDeviceType;
282 this->Py_SyclContextType_ = &Py_SyclContextType;
283 this->PySyclContextType_ = &PySyclContextType;
284 this->Py_SyclEventType_ = &Py_SyclEventType;
285 this->PySyclEventType_ = &PySyclEventType;
286 this->Py_SyclQueueType_ = &Py_SyclQueueType;
287 this->PySyclQueueType_ = &PySyclQueueType;
288 this->Py_MemoryType_ = &Py_MemoryType;
289 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
290 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
291 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
292 this->PyUSMArrayType_ = &PyUSMArrayType;
293 this->PySyclProgramType_ = &PySyclProgramType;
294 this->PySyclKernelType_ = &PySyclKernelType;
295
296 // SyclDevice API
297 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
298 this->SyclDevice_Make_ = SyclDevice_Make;
299
300 // SyclContext API
301 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
302 this->SyclContext_Make_ = SyclContext_Make;
303
304 // SyclEvent API
305 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
306 this->SyclEvent_Make_ = SyclEvent_Make;
307
308 // SyclQueue API
309 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
310 this->SyclQueue_Make_ = SyclQueue_Make;
311
312 // dpctl.memory API
313 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
314 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
315 this->Memory_GetContextRef_ = Memory_GetContextRef;
316 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
317 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
318 this->Memory_Make_ = Memory_Make;
319
320 // dpctl.program API
321 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
322 this->SyclKernel_Make_ = SyclKernel_Make;
323 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
324 this->SyclProgram_Make_ = SyclProgram_Make;
325
326 // constants
327 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
328 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
329 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
330 this->UAR_BOOL_ = UAR_BOOL;
331 this->UAR_BYTE_ = UAR_BYTE;
332 this->UAR_UBYTE_ = UAR_UBYTE;
333 this->UAR_SHORT_ = UAR_SHORT;
334 this->UAR_USHORT_ = UAR_USHORT;
335 this->UAR_INT_ = UAR_INT;
336 this->UAR_UINT_ = UAR_UINT;
337 this->UAR_LONG_ = UAR_LONG;
338 this->UAR_ULONG_ = UAR_ULONG;
339 this->UAR_LONGLONG_ = UAR_LONGLONG;
340 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
341 this->UAR_FLOAT_ = UAR_FLOAT;
342 this->UAR_DOUBLE_ = UAR_DOUBLE;
343 this->UAR_CFLOAT_ = UAR_CFLOAT;
344 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
345 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
346 this->UAR_HALF_ = UAR_HALF;
347
348 // deduced disjoint types
349 this->UAR_INT8_ = UAR_BYTE;
350 this->UAR_UINT8_ = UAR_UBYTE;
351 this->UAR_INT16_ = UAR_SHORT;
352 this->UAR_UINT16_ = UAR_USHORT;
353 this->UAR_INT32_ =
354 platform_typeid_lookup<std::int32_t, long, int, short>(
355 UAR_LONG, UAR_INT, UAR_SHORT);
356 this->UAR_UINT32_ =
357 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
358 unsigned short>(UAR_ULONG, UAR_UINT,
359 UAR_USHORT);
360 this->UAR_INT64_ =
361 platform_typeid_lookup<std::int64_t, long, long long, int>(
362 UAR_LONG, UAR_LONGLONG, UAR_INT);
363 this->UAR_UINT64_ =
364 platform_typeid_lookup<std::uint64_t, unsigned long,
365 unsigned long long, unsigned int>(
366 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
367
368 // create shared pointers to python objects used in type-casters
369 // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
370 sycl::queue q_{};
371 PySyclQueueObject *py_q_tmp =
372 SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
373 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
374 reinterpret_cast<PyObject *>(py_q_tmp));
375
376 default_sycl_queue_ = std::shared_ptr<py::object>(
377 new py::object(py_sycl_queue), Deleter{});
378
379 py::module_ mod_memory = py::module_::import("dpctl.memory");
380 const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
381 as_usm_memory_ = std::shared_ptr<py::object>(
382 new py::object{py_as_usm_memory}, Deleter{});
383
384 auto mem_kl = mod_memory.attr("MemoryUSMHost");
385 const py::object &py_default_usm_memory =
386 mem_kl(1, py::arg("queue") = py_sycl_queue);
387 default_usm_memory_ = std::shared_ptr<py::object>(
388 new py::object{py_default_usm_memory}, Deleter{});
389
390 py::module_ mod_usmarray = py::module_::import("dpnp.tensor._usmarray");
391 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
392
393 const py::object &py_default_usm_ndarray =
394 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
395 py::arg("buffer") = py_default_usm_memory);
396
397 default_usm_ndarray_ = std::shared_ptr<py::object>(
398 new py::object{py_default_usm_ndarray}, Deleter{});
399 }
400
401 dpctl_capi(dpctl_capi const &) = default;
402 dpctl_capi &operator=(dpctl_capi const &) = default;
403 dpctl_capi &operator=(dpctl_capi &&) = default;
404
405}; // struct dpctl_capi
406} // namespace detail
407} // namespace dpctl
408
409namespace pybind11::detail
410{
// Boilerplate shared by the type casters below. It declares the
// `std::unique_ptr<type> value` storage filled by `load`, the `name`
// descriptor, a pointer overload of `cast` that forwards to the by-value
// `cast` (deleting `src` when the policy is take_ownership), and the
// conversion operators used to hand the loaded value to the bound function.
#define DPCTL_TYPE_CASTER(type, py_name)                                       \
protected:                                                                     \
    std::unique_ptr<type> value;                                               \
                                                                               \
public:                                                                        \
    static constexpr auto name = py_name;                                      \
    template <                                                                 \
        typename T_,                                                           \
        ::pybind11::detail::enable_if_t<                                       \
            std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value,    \
            int> = 0>                                                          \
    static ::pybind11::handle cast(T_ *src,                                    \
                                   ::pybind11::return_value_policy policy,     \
                                   ::pybind11::handle parent)                  \
    {                                                                          \
        if (!src)                                                              \
            return ::pybind11::none().release();                               \
        if (policy == ::pybind11::return_value_policy::take_ownership) {       \
            auto h = cast(std::move(*src), policy, parent);                    \
            delete src;                                                        \
            return h;                                                          \
        }                                                                      \
        return cast(*src, policy, parent);                                     \
    }                                                                          \
    operator type *()                                                          \
    {                                                                          \
        return value.get();                                                    \
    } /* NOLINT(bugprone-macro-parentheses) */                                 \
    operator type &()                                                          \
    {                                                                          \
        return *value;                                                         \
    } /* NOLINT(bugprone-macro-parentheses) */                                 \
    operator type &&() &&                                                      \
    {                                                                          \
        return std::move(*value);                                              \
    } /* NOLINT(bugprone-macro-parentheses) */                                 \
    template <typename T_>                                                     \
    using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
449
450/* This type caster associates ``sycl::queue`` C++ class with
451 * :class:`dpctl.SyclQueue` for the purposes of generation of
452 * Python bindings by pybind11.
453 */
454template <>
455struct type_caster<sycl::queue>
456{
457public:
458 bool load(handle src, bool)
459 {
460 PyObject *source = src.ptr();
461 auto const &api = ::dpctl::detail::dpctl_capi::get();
462 if (api.PySyclQueue_Check_(source)) {
463 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
464 reinterpret_cast<PySyclQueueObject *>(source));
465 value = std::make_unique<sycl::queue>(
466 *(reinterpret_cast<sycl::queue *>(QRef)));
467 return true;
468 }
469 else {
470 throw py::type_error(
471 "Input is of unexpected type, expected dpctl.SyclQueue");
472 }
473 }
474
475 static handle cast(sycl::queue src, return_value_policy, handle)
476 {
477 auto const &api = ::dpctl::detail::dpctl_capi::get();
478 auto tmp =
479 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
480 return handle(reinterpret_cast<PyObject *>(tmp));
481 }
482
483 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
484};
485
486/* This type caster associates ``sycl::device`` C++ class with
487 * :class:`dpctl.SyclDevice` for the purposes of generation of
488 * Python bindings by pybind11.
489 */
490template <>
491struct type_caster<sycl::device>
492{
493public:
494 bool load(handle src, bool)
495 {
496 PyObject *source = src.ptr();
497 auto const &api = ::dpctl::detail::dpctl_capi::get();
498 if (api.PySyclDevice_Check_(source)) {
499 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
500 reinterpret_cast<PySyclDeviceObject *>(source));
501 value = std::make_unique<sycl::device>(
502 *(reinterpret_cast<sycl::device *>(DRef)));
503 return true;
504 }
505 else {
506 throw py::type_error(
507 "Input is of unexpected type, expected dpctl.SyclDevice");
508 }
509 }
510
511 static handle cast(sycl::device src, return_value_policy, handle)
512 {
513 auto const &api = ::dpctl::detail::dpctl_capi::get();
514 auto tmp =
515 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
516 return handle(reinterpret_cast<PyObject *>(tmp));
517 }
518
519 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
520};
521
522/* This type caster associates ``sycl::context`` C++ class with
523 * :class:`dpctl.SyclContext` for the purposes of generation of
524 * Python bindings by pybind11.
525 */
526template <>
527struct type_caster<sycl::context>
528{
529public:
530 bool load(handle src, bool)
531 {
532 PyObject *source = src.ptr();
533 auto const &api = ::dpctl::detail::dpctl_capi::get();
534 if (api.PySyclContext_Check_(source)) {
535 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
536 reinterpret_cast<PySyclContextObject *>(source));
537 value = std::make_unique<sycl::context>(
538 *(reinterpret_cast<sycl::context *>(CRef)));
539 return true;
540 }
541 else {
542 throw py::type_error(
543 "Input is of unexpected type, expected dpctl.SyclContext");
544 }
545 }
546
547 static handle cast(sycl::context src, return_value_policy, handle)
548 {
549 auto const &api = ::dpctl::detail::dpctl_capi::get();
550 auto tmp =
551 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
552 return handle(reinterpret_cast<PyObject *>(tmp));
553 }
554
555 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
556};
557
558/* This type caster associates ``sycl::event`` C++ class with
559 * :class:`dpctl.SyclEvent` for the purposes of generation of
560 * Python bindings by pybind11.
561 */
562template <>
563struct type_caster<sycl::event>
564{
565public:
566 bool load(handle src, bool)
567 {
568 PyObject *source = src.ptr();
569 auto const &api = ::dpctl::detail::dpctl_capi::get();
570 if (api.PySyclEvent_Check_(source)) {
571 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
572 reinterpret_cast<PySyclEventObject *>(source));
573 value = std::make_unique<sycl::event>(
574 *(reinterpret_cast<sycl::event *>(ERef)));
575 return true;
576 }
577 else {
578 throw py::type_error(
579 "Input is of unexpected type, expected dpctl.SyclEvent");
580 }
581 }
582
583 static handle cast(sycl::event src, return_value_policy, handle)
584 {
585 auto const &api = ::dpctl::detail::dpctl_capi::get();
586 auto tmp =
587 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
588 return handle(reinterpret_cast<PyObject *>(tmp));
589 }
590
591 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
592};
593
594/* This type caster associates ``sycl::kernel`` C++ class with
595 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
596 * Python bindings by pybind11.
597 */
598template <>
599struct type_caster<sycl::kernel>
600{
601public:
602 bool load(handle src, bool)
603 {
604 PyObject *source = src.ptr();
605 auto const &api = ::dpctl::detail::dpctl_capi::get();
606 if (api.PySyclKernel_Check_(source)) {
607 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
608 reinterpret_cast<PySyclKernelObject *>(source));
609 value = std::make_unique<sycl::kernel>(
610 *(reinterpret_cast<sycl::kernel *>(KRef)));
611 return true;
612 }
613 else {
614 throw py::type_error("Input is of unexpected type, expected "
615 "dpctl.program.SyclKernel");
616 }
617 }
618
619 static handle cast(sycl::kernel src, return_value_policy, handle)
620 {
621 auto const &api = ::dpctl::detail::dpctl_capi::get();
622 auto tmp =
623 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
624 "dpctl4pybind11_kernel");
625 return handle(reinterpret_cast<PyObject *>(tmp));
626 }
627
628 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
629};
630
631/* This type caster associates
632 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
633 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
634 * Python bindings by pybind11.
635 */
636template <>
637struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
638{
639public:
640 bool load(handle src, bool)
641 {
642 PyObject *source = src.ptr();
643 auto const &api = ::dpctl::detail::dpctl_capi::get();
644 if (api.PySyclProgram_Check_(source)) {
645 DPCTLSyclKernelBundleRef KBRef =
646 api.SyclProgram_GetKernelBundleRef_(
647 reinterpret_cast<PySyclProgramObject *>(source));
648 value = std::make_unique<
649 sycl::kernel_bundle<sycl::bundle_state::executable>>(
650 *(reinterpret_cast<
651 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
652 KBRef)));
653 return true;
654 }
655 else {
656 throw py::type_error("Input is of unexpected type, expected "
657 "dpctl.program.SyclProgram");
658 }
659 }
660
661 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
662 return_value_policy,
663 handle)
664 {
665 auto const &api = ::dpctl::detail::dpctl_capi::get();
666 auto tmp = api.SyclProgram_Make_(
667 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
668 return handle(reinterpret_cast<PyObject *>(tmp));
669 }
670
671 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
672 _("dpctl.program.SyclProgram"));
673};
674
675/* This type caster associates
676 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
677 * of generation of Python bindings by pybind11.
678 */
679template <>
680struct type_caster<sycl::half>
681{
682public:
683 bool load(handle src, bool convert)
684 {
685 double py_value;
686
687 if (!src) {
688 return false;
689 }
690
691 PyObject *source = src.ptr();
692
693 if (convert || PyFloat_Check(source)) {
694 py_value = PyFloat_AsDouble(source);
695 }
696 else {
697 return false;
698 }
699
700 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
701
702 if (py_err) {
703 PyErr_Clear();
704 if (convert && (PyNumber_Check(source) != 0)) {
705 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
706 return load(tmp, false);
707 }
708 return false;
709 }
710 value = static_cast<sycl::half>(py_value);
711 return true;
712 }
713
714 static handle cast(sycl::half src, return_value_policy, handle)
715 {
716 return PyFloat_FromDouble(static_cast<double>(src));
717 }
718
719 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
720};
721} // namespace pybind11::detail
722
723namespace dpctl
724{
725namespace memory
726{
727// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
728// this allows to avoid compilation error
729using pybind11::error_already_set;
730
731class usm_memory : public py::object
732{
733public:
734 PYBIND11_OBJECT_CVT(
736 py::object,
737 [](PyObject *o) -> bool {
738 return PyObject_TypeCheck(
739 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
740 0;
741 },
742 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
743
744 usm_memory()
745 : py::object(
746 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
747 borrowed_t{})
748 {
749 if (!m_ptr)
750 throw py::error_already_set();
751 }
752
756 usm_memory(void *usm_ptr,
757 std::size_t nbytes,
758 const sycl::queue &q,
759 std::shared_ptr<void> shptr)
760 {
761 auto const &api = ::dpctl::detail::dpctl_capi::get();
762 DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
763 auto q_uptr = std::make_unique<sycl::queue>(q);
764 DPCTLSyclQueueRef QRef =
765 reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());
766
767 auto vacuous_destructor = []() {};
768 py::capsule mock_owner(vacuous_destructor);
769
770 // create memory object owned by mock_owner, it is a new reference
771 PyObject *_memory =
772 api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
773 auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };
774
775 using py_uptrT =
776 std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;
777
778 if (!_memory) {
779 throw py::error_already_set();
780 }
781
782 auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
783 std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);
784
785 Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
786 // replace mock_owner capsule as the owner
787 memobj->refobj = Py_None;
788 // set opaque ptr field, usm_memory now knowns that USM is managed
789 // by smart pointer
790 memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);
791
792 // _memory will delete created copies of sycl::queue, and
793 // std::shared_ptr and the deleter of the shared_ptr<void> is
794 // supposed to free the USM allocation
795 m_ptr = _memory;
796 q_uptr.release();
797 memory_uptr.release();
798 }
799
800 sycl::queue get_queue() const
801 {
802 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
803 auto const &api = ::dpctl::detail::dpctl_capi::get();
804 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
805 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
806 return *obj_q;
807 }
808
809 char *get_pointer() const
810 {
811 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
812 auto const &api = ::dpctl::detail::dpctl_capi::get();
813 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
814 return reinterpret_cast<char *>(MRef);
815 }
816
817 std::size_t get_nbytes() const
818 {
819 auto const &api = ::dpctl::detail::dpctl_capi::get();
820 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
821 return api.Memory_GetNumBytes_(mem_obj);
822 }
823
824 bool is_managed_by_smart_ptr() const
825 {
826 auto const &api = ::dpctl::detail::dpctl_capi::get();
827 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
828 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
829
830 return bool(opaque_ptr);
831 }
832
833 const std::shared_ptr<void> &get_smart_ptr_owner() const
834 {
835 auto const &api = ::dpctl::detail::dpctl_capi::get();
836 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
837 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
838
839 if (opaque_ptr) {
840 auto shptr_ptr =
841 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
842 return *shptr_ptr;
843 }
844 else {
845 throw std::runtime_error(
846 "Memory object does not have smart pointer "
847 "managing lifetime of USM allocation");
848 }
849 }
850
851protected:
852 static PyObject *as_usm_memory(PyObject *o)
853 {
854 if (o == nullptr) {
855 PyErr_SetString(PyExc_ValueError,
856 "cannot create a usm_memory from a nullptr");
857 return nullptr;
858 }
859
860 auto converter =
861 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
862
863 py::object res;
864 try {
865 res = converter(py::handle(o));
866 } catch (const py::error_already_set &e) {
867 return nullptr;
868 }
869 return res.ptr();
870 }
871};
872} // end namespace memory
873
874namespace tensor
875{
876inline std::vector<py::ssize_t>
877 c_contiguous_strides(int nd,
878 const py::ssize_t *shape,
879 py::ssize_t element_size = 1)
880{
881 if (nd > 0) {
882 std::vector<py::ssize_t> c_strides(nd, element_size);
883 for (int ic = nd - 1; ic > 0;) {
884 py::ssize_t next_v = c_strides[ic] * shape[ic];
885 c_strides[--ic] = next_v;
886 }
887 return c_strides;
888 }
889 else {
890 return std::vector<py::ssize_t>();
891 }
892}
893
894inline std::vector<py::ssize_t>
895 f_contiguous_strides(int nd,
896 const py::ssize_t *shape,
897 py::ssize_t element_size = 1)
898{
899 if (nd > 0) {
900 std::vector<py::ssize_t> f_strides(nd, element_size);
901 for (int i = 0; i < nd - 1;) {
902 py::ssize_t next_v = f_strides[i] * shape[i];
903 f_strides[++i] = next_v;
904 }
905 return f_strides;
906 }
907 else {
908 return std::vector<py::ssize_t>();
909 }
910}
911
912inline std::vector<py::ssize_t>
913 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
914 py::ssize_t element_size = 1)
915{
916 return c_contiguous_strides(shape.size(), shape.data(), element_size);
917}
918
919inline std::vector<py::ssize_t>
920 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
921 py::ssize_t element_size = 1)
922{
923 return f_contiguous_strides(shape.size(), shape.data(), element_size);
924}
925
926class usm_ndarray : public py::object
927{
928public:
929 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
930 return PyObject_TypeCheck(
931 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
932 })
933
935 : py::object(
936 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
937 borrowed_t{})
938 {
939 if (!m_ptr)
940 throw py::error_already_set();
941 }
942
943 char *get_data() const
944 {
945 PyUSMArrayObject *raw_ar = usm_array_ptr();
946 return raw_ar->data_;
947 }
948
949 template <typename T>
950 T *get_data() const
951 {
952 return reinterpret_cast<T *>(get_data());
953 }
954
955 int get_ndim() const
956 {
957 PyUSMArrayObject *raw_ar = usm_array_ptr();
958 return raw_ar->nd_;
959 }
960
961 const py::ssize_t *get_shape_raw() const
962 {
963 PyUSMArrayObject *raw_ar = usm_array_ptr();
964 return raw_ar->shape_;
965 }
966
967 std::vector<py::ssize_t> get_shape_vector() const
968 {
969 auto raw_sh = get_shape_raw();
970 auto nd = get_ndim();
971
972 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
973 return shape_vector;
974 }
975
976 py::ssize_t get_shape(int i) const
977 {
978 auto shape_ptr = get_shape_raw();
979 return shape_ptr[i];
980 }
981
982 const py::ssize_t *get_strides_raw() const
983 {
984 PyUSMArrayObject *raw_ar = usm_array_ptr();
985 return raw_ar->strides_;
986 }
987
988 std::vector<py::ssize_t> get_strides_vector() const
989 {
990 auto raw_st = get_strides_raw();
991 auto nd = get_ndim();
992
993 if (raw_st == nullptr) {
994 auto is_c_contig = is_c_contiguous();
995 auto is_f_contig = is_f_contiguous();
996 auto raw_sh = get_shape_raw();
997 if (is_c_contig) {
998 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
999 return contig_strides;
1000 }
1001 else if (is_f_contig) {
1002 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
1003 return contig_strides;
1004 }
1005 else {
1006 throw std::runtime_error("Invalid array encountered when "
1007 "building strides");
1008 }
1009 }
1010 else {
1011 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
1012 return st_vec;
1013 }
1014 }
1015
1016 py::ssize_t get_size() const
1017 {
1018 PyUSMArrayObject *raw_ar = usm_array_ptr();
1019
1020 int ndim = raw_ar->nd_;
1021 const py::ssize_t *shape = raw_ar->shape_;
1022
1023 py::ssize_t nelems = 1;
1024 for (int i = 0; i < ndim; ++i) {
1025 nelems *= shape[i];
1026 }
1027
1028 assert(nelems >= 0);
1029 return nelems;
1030 }
1031
1032 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
1033 {
1034 PyUSMArrayObject *raw_ar = usm_array_ptr();
1035
1036 int nd = raw_ar->nd_;
1037 const py::ssize_t *shape = raw_ar->shape_;
1038 const py::ssize_t *strides = raw_ar->strides_;
1039
1040 py::ssize_t offset_min = 0;
1041 py::ssize_t offset_max = 0;
1042 if (strides == nullptr) {
1043 py::ssize_t stride(1);
1044 for (int i = 0; i < nd; ++i) {
1045 offset_max += stride * (shape[i] - 1);
1046 stride *= shape[i];
1047 }
1048 }
1049 else {
1050 for (int i = 0; i < nd; ++i) {
1051 py::ssize_t delta = strides[i] * (shape[i] - 1);
1052 if (strides[i] > 0) {
1053 offset_max += delta;
1054 }
1055 else {
1056 offset_min += delta;
1057 }
1058 }
1059 }
1060 return std::make_pair(offset_min, offset_max);
1061 }
1062
1063 sycl::queue get_queue() const
1064 {
1065 PyUSMArrayObject *raw_ar = usm_array_ptr();
1066 Py_MemoryObject *mem_obj =
1067 reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
1068
1069 auto const &api = ::dpctl::detail::dpctl_capi::get();
1070 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1071 return *(reinterpret_cast<sycl::queue *>(QRef));
1072 }
1073
1074 sycl::device get_device() const
1075 {
1076 PyUSMArrayObject *raw_ar = usm_array_ptr();
1077 Py_MemoryObject *mem_obj =
1078 reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
1079
1080 auto const &api = ::dpctl::detail::dpctl_capi::get();
1081 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1082 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1083 }
1084
1085 int get_typenum() const
1086 {
1087 PyUSMArrayObject *raw_ar = usm_array_ptr();
1088 return raw_ar->typenum_;
1089 }
1090
1091 int get_flags() const
1092 {
1093 PyUSMArrayObject *raw_ar = usm_array_ptr();
1094 return raw_ar->flags_;
1095 }
1096
1097 int get_elemsize() const
1098 {
1099 int typenum = get_typenum();
1100 auto const &api = ::dpctl::detail::dpctl_capi::get();
1101
1102 // Lookup table for element sizes based on typenum
1103 if (typenum == api.UAR_BOOL_)
1104 return 1;
1105 if (typenum == api.UAR_BYTE_)
1106 return 1;
1107 if (typenum == api.UAR_UBYTE_)
1108 return 1;
1109 if (typenum == api.UAR_SHORT_)
1110 return 2;
1111 if (typenum == api.UAR_USHORT_)
1112 return 2;
1113 if (typenum == api.UAR_INT_)
1114 return 4;
1115 if (typenum == api.UAR_UINT_)
1116 return 4;
1117 if (typenum == api.UAR_LONG_)
1118 return sizeof(long);
1119 if (typenum == api.UAR_ULONG_)
1120 return sizeof(unsigned long);
1121 if (typenum == api.UAR_LONGLONG_)
1122 return 8;
1123 if (typenum == api.UAR_ULONGLONG_)
1124 return 8;
1125 if (typenum == api.UAR_FLOAT_)
1126 return 4;
1127 if (typenum == api.UAR_DOUBLE_)
1128 return 8;
1129 if (typenum == api.UAR_CFLOAT_)
1130 return 8;
1131 if (typenum == api.UAR_CDOUBLE_)
1132 return 16;
1133 if (typenum == api.UAR_HALF_)
1134 return 2;
1135
1136 return 0; // Unknown type
1137 }
1138
1139 bool is_c_contiguous() const
1140 {
1141 int flags = get_flags();
1142 auto const &api = ::dpctl::detail::dpctl_capi::get();
1143 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1144 }
1145
1146 bool is_f_contiguous() const
1147 {
1148 int flags = get_flags();
1149 auto const &api = ::dpctl::detail::dpctl_capi::get();
1150 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1151 }
1152
1153 bool is_writable() const
1154 {
1155 int flags = get_flags();
1156 auto const &api = ::dpctl::detail::dpctl_capi::get();
1157 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1158 }
1159
1161 py::object get_usm_data() const
1162 {
1163 PyUSMArrayObject *raw_ar = usm_array_ptr();
1164 // base_ is the Memory object - return new reference
1165 PyObject *usm_data = raw_ar->base_;
1166 Py_XINCREF(usm_data);
1167
1168 // pass reference ownership to py::object
1169 return py::reinterpret_steal<py::object>(usm_data);
1170 }
1171
1172 bool is_managed_by_smart_ptr() const
1173 {
1174 PyUSMArrayObject *raw_ar = usm_array_ptr();
1175 PyObject *usm_data = raw_ar->base_;
1176
1177 auto const &api = ::dpctl::detail::dpctl_capi::get();
1178 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1179 return false;
1180 }
1181
1182 Py_MemoryObject *mem_obj =
1183 reinterpret_cast<Py_MemoryObject *>(usm_data);
1184 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1185
1186 return bool(opaque_ptr);
1187 }
1188
1189 const std::shared_ptr<void> &get_smart_ptr_owner() const
1190 {
1191 PyUSMArrayObject *raw_ar = usm_array_ptr();
1192 PyObject *usm_data = raw_ar->base_;
1193
1194 auto const &api = ::dpctl::detail::dpctl_capi::get();
1195
1196 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1197 throw std::runtime_error(
1198 "usm_ndarray object does not have Memory object "
1199 "managing lifetime of USM allocation");
1200 }
1201
1202 Py_MemoryObject *mem_obj =
1203 reinterpret_cast<Py_MemoryObject *>(usm_data);
1204 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1205
1206 if (opaque_ptr) {
1207 auto shptr_ptr =
1208 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1209 return *shptr_ptr;
1210 }
1211 else {
1212 throw std::runtime_error(
1213 "Memory object underlying usm_ndarray does not have "
1214 "smart pointer managing lifetime of USM allocation");
1215 }
1216 }
1217
1218private:
1219 PyUSMArrayObject *usm_array_ptr() const
1220 {
1221 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1222 }
1223};
1224} // end namespace tensor
1225
1226namespace utils
1227{
1228namespace detail
1229{
1231{
1232
1233 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1234 {
1235 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1236 const auto &usm_memory_inst =
1237 py::cast<dpctl::memory::usm_memory>(h);
1238 return usm_memory_inst.is_managed_by_smart_ptr();
1239 }
1240 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1241 const auto &usm_array_inst =
1242 py::cast<dpctl::tensor::usm_ndarray>(h);
1243 return usm_array_inst.is_managed_by_smart_ptr();
1244 }
1245
1246 return false;
1247 }
1248
1249 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1250 {
1251 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1252 const auto &usm_memory_inst =
1253 py::cast<dpctl::memory::usm_memory>(h);
1254 return usm_memory_inst.get_smart_ptr_owner();
1255 }
1256 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1257 const auto &usm_array_inst =
1258 py::cast<dpctl::tensor::usm_ndarray>(h);
1259 return usm_array_inst.get_smart_ptr_owner();
1260 }
1261
1262 throw std::runtime_error(
1263 "Attempted extraction of shared_ptr on an unrecognized type");
1264 }
1265};
1266} // end of namespace detail
1267
1268template <std::size_t num>
1269sycl::event keep_args_alive(sycl::queue &q,
1270 const py::object (&py_objs)[num],
1271 const std::vector<sycl::event> &depends = {})
1272{
1273 std::size_t n_objects_held = 0;
1274 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1275
1276 std::size_t n_usm_owners_held = 0;
1277 std::array<std::shared_ptr<void>, num> shp_usm{};
1278
1279 for (std::size_t i = 0; i < num; ++i) {
1280 const auto &py_obj_i = py_objs[i];
1281 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1282 const auto &shp =
1283 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1284 shp_usm[n_usm_owners_held] = shp;
1285 ++n_usm_owners_held;
1286 }
1287 else {
1288 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1289 shp_arr[n_objects_held]->inc_ref();
1290 ++n_objects_held;
1291 }
1292 }
1293
1294 bool use_depends = true;
1295 sycl::event host_task_ev;
1296
1297 if (n_usm_owners_held > 0) {
1298 host_task_ev = q.submit([&](sycl::handler &cgh) {
1299 if (use_depends) {
1300 cgh.depends_on(depends);
1301 use_depends = false;
1302 }
1303 else {
1304 cgh.depends_on(host_task_ev);
1305 }
1306 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1307 // no body, but shared pointers are captured in
1308 // the lambda, ensuring that USM allocation is
1309 // kept alive
1310 });
1311 });
1312 }
1313
1314 if (n_objects_held > 0) {
1315 host_task_ev = q.submit([&](sycl::handler &cgh) {
1316 if (use_depends) {
1317 cgh.depends_on(depends);
1318 use_depends = false;
1319 }
1320 else {
1321 cgh.depends_on(host_task_ev);
1322 }
1323 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1324 py::gil_scoped_acquire acquire;
1325
1326 for (std::size_t i = 0; i < n_objects_held; ++i) {
1327 shp_arr[i]->dec_ref();
1328 }
1329 });
1330 });
1331 }
1332
1333 return host_task_ev;
1334}
1335
1338template <std::size_t num>
1339bool queues_are_compatible(const sycl::queue &exec_q,
1340 const sycl::queue (&alloc_qs)[num])
1341{
1342 for (std::size_t i = 0; i < num; ++i) {
1343
1344 if (exec_q != alloc_qs[i]) {
1345 return false;
1346 }
1347 }
1348 return true;
1349}
1350
1353template <std::size_t num>
1354bool queues_are_compatible(const sycl::queue &exec_q,
1355 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1356{
1357 for (std::size_t i = 0; i < num; ++i) {
1358
1359 if (exec_q != arrs[i].get_queue()) {
1360 return false;
1361 }
1362 }
1363 return true;
1364}
1365} // end namespace utils
1366} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.