DPNP C++ backend kernel library 0.20.0dev6
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dpnp4pybind11.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// Include dpctl C-API headers
32#include "dpctl_capi.h"
33
34// Include generated Cython headers for usm_ndarray
35// (struct definition and constants only)
36#include "dpnp/tensor/_usmarray.h"
37#include "dpnp/tensor/_usmarray_api.h"
38
39#include <array>
40#include <cassert>
41#include <complex>
42#include <cstddef> // for std::size_t for C++ linkage
43#include <cstdint>
44#include <memory>
45#include <stddef.h> // for size_t for C linkage
46#include <stdexcept>
47#include <type_traits>
48#include <utility>
49#include <vector>
50
51#include <pybind11/pybind11.h>
52
53#include <sycl/sycl.hpp>
54
55namespace py = pybind11;
56
57namespace dpctl
58{
59namespace detail
60{
// Select a typenum for `Concrete` by size: the candidate types are scanned in
// order, and the integer passed at the same position as the first candidate
// whose `sizeof` equals `sizeof(Concrete)` is returned.

// Terminal case: no candidates remain, so nothing matched.
template <typename Concrete>
constexpr int platform_typeid_lookup()
{
    return -1;
}

// Recursive case: test the leading candidate `T` (paired with code `I`);
// otherwise continue with the remaining candidates `Ts...` and codes `Is...`.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_typeid_lookup(int I, Ints... Is)
{
    if (sizeof(T) == sizeof(Concrete)) {
        return I;
    }
    return platform_typeid_lookup<Concrete, Ts...>(Is...);
}
76
78{
79public:
80 // dpctl type objects
81 PyTypeObject *Py_SyclDeviceType_;
82 PyTypeObject *PySyclDeviceType_;
83 PyTypeObject *Py_SyclContextType_;
84 PyTypeObject *PySyclContextType_;
85 PyTypeObject *Py_SyclEventType_;
86 PyTypeObject *PySyclEventType_;
87 PyTypeObject *Py_SyclQueueType_;
88 PyTypeObject *PySyclQueueType_;
89 PyTypeObject *Py_MemoryType_;
90 PyTypeObject *PyMemoryUSMDeviceType_;
91 PyTypeObject *PyMemoryUSMSharedType_;
92 PyTypeObject *PyMemoryUSMHostType_;
93 PyTypeObject *PyUSMArrayType_;
94 PyTypeObject *PySyclProgramType_;
95 PyTypeObject *PySyclKernelType_;
96
97 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
98 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
99
100 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
101 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
102
103 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
104 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
105
106 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
107 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
108
109 // memory
110 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
111 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
112 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
113 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
114 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
115 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
116 size_t,
117 DPCTLSyclQueueRef,
118 PyObject *);
119
120 // program
121 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
122 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef, const char *);
123
124 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
125 PySyclProgramObject *);
126 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
127
128 int USM_ARRAY_C_CONTIGUOUS_;
129 int USM_ARRAY_F_CONTIGUOUS_;
130 int USM_ARRAY_WRITABLE_;
131 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
132 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
133 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
134 UAR_HALF_;
135 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
136 UAR_INT64_, UAR_UINT64_;
137
138 bool PySyclDevice_Check_(PyObject *obj) const
139 {
140 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
141 }
142 bool PySyclContext_Check_(PyObject *obj) const
143 {
144 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
145 }
146 bool PySyclEvent_Check_(PyObject *obj) const
147 {
148 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
149 }
150 bool PySyclQueue_Check_(PyObject *obj) const
151 {
152 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
153 }
154 bool PySyclKernel_Check_(PyObject *obj) const
155 {
156 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
157 }
158 bool PySyclProgram_Check_(PyObject *obj) const
159 {
160 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
161 }
162
164 {
165 as_usm_memory_.reset();
166 default_usm_ndarray_.reset();
167 default_usm_memory_.reset();
168 default_sycl_queue_.reset();
169 };
170
171 static auto &get()
172 {
173 static dpctl_capi api{};
174 return api;
175 }
176
177 py::object default_sycl_queue_pyobj() { return *default_sycl_queue_; }
178 py::object default_usm_memory_pyobj() { return *default_usm_memory_; }
179 py::object default_usm_ndarray_pyobj() { return *default_usm_ndarray_; }
180 py::object as_usm_memory_pyobj() { return *as_usm_memory_; }
181
182private:
183 struct Deleter
184 {
185 void operator()(py::object *p) const
186 {
187 const bool initialized = Py_IsInitialized();
188#if PY_VERSION_HEX < 0x30d0000
189 const bool finalizing = _Py_IsFinalizing();
190#else
191 const bool finalizing = Py_IsFinalizing();
192#endif
193 const bool guard = initialized && !finalizing;
194
195 if (guard) {
196 delete p;
197 }
198 }
199 };
200
201 std::shared_ptr<py::object> default_sycl_queue_;
202 std::shared_ptr<py::object> default_usm_memory_;
203 std::shared_ptr<py::object> default_usm_ndarray_;
204 std::shared_ptr<py::object> as_usm_memory_;
205
206 dpctl_capi()
207 : Py_SyclDeviceType_(nullptr), PySyclDeviceType_(nullptr),
208 Py_SyclContextType_(nullptr), PySyclContextType_(nullptr),
209 Py_SyclEventType_(nullptr), PySyclEventType_(nullptr),
210 Py_SyclQueueType_(nullptr), PySyclQueueType_(nullptr),
211 Py_MemoryType_(nullptr), PyMemoryUSMDeviceType_(nullptr),
212 PyMemoryUSMSharedType_(nullptr), PyMemoryUSMHostType_(nullptr),
213 PyUSMArrayType_(nullptr), PySyclProgramType_(nullptr),
214 PySyclKernelType_(nullptr), SyclDevice_GetDeviceRef_(nullptr),
215 SyclDevice_Make_(nullptr), SyclContext_GetContextRef_(nullptr),
216 SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
217 SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
218 SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
219 Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
220 Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
221 Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
222 SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
223 SyclProgram_Make_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
224 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
225 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
226 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
227 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
228 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
229 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
230 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
231 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
232 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
233
234 {
235 // Import dpctl C-API
236 // (device, context, event, queue, memory, program)
237 import_dpctl();
238 // Import dpnp tensor module for PyUSMArrayType
239 import_dpnp__tensor___usmarray();
240
241 // Python type objects for classes implemented by dpctl
242 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
243 this->PySyclDeviceType_ = &PySyclDeviceType;
244 this->Py_SyclContextType_ = &Py_SyclContextType;
245 this->PySyclContextType_ = &PySyclContextType;
246 this->Py_SyclEventType_ = &Py_SyclEventType;
247 this->PySyclEventType_ = &PySyclEventType;
248 this->Py_SyclQueueType_ = &Py_SyclQueueType;
249 this->PySyclQueueType_ = &PySyclQueueType;
250 this->Py_MemoryType_ = &Py_MemoryType;
251 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
252 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
253 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
254 this->PyUSMArrayType_ = &PyUSMArrayType;
255 this->PySyclProgramType_ = &PySyclProgramType;
256 this->PySyclKernelType_ = &PySyclKernelType;
257
258 // SyclDevice API
259 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
260 this->SyclDevice_Make_ = SyclDevice_Make;
261
262 // SyclContext API
263 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
264 this->SyclContext_Make_ = SyclContext_Make;
265
266 // SyclEvent API
267 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
268 this->SyclEvent_Make_ = SyclEvent_Make;
269
270 // SyclQueue API
271 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
272 this->SyclQueue_Make_ = SyclQueue_Make;
273
274 // dpctl.memory API
275 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
276 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
277 this->Memory_GetContextRef_ = Memory_GetContextRef;
278 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
279 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
280 this->Memory_Make_ = Memory_Make;
281
282 // dpctl.program API
283 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
284 this->SyclKernel_Make_ = SyclKernel_Make;
285 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
286 this->SyclProgram_Make_ = SyclProgram_Make;
287
288 // constants
289 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
290 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
291 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
292 this->UAR_BOOL_ = UAR_BOOL;
293 this->UAR_BYTE_ = UAR_BYTE;
294 this->UAR_UBYTE_ = UAR_UBYTE;
295 this->UAR_SHORT_ = UAR_SHORT;
296 this->UAR_USHORT_ = UAR_USHORT;
297 this->UAR_INT_ = UAR_INT;
298 this->UAR_UINT_ = UAR_UINT;
299 this->UAR_LONG_ = UAR_LONG;
300 this->UAR_ULONG_ = UAR_ULONG;
301 this->UAR_LONGLONG_ = UAR_LONGLONG;
302 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
303 this->UAR_FLOAT_ = UAR_FLOAT;
304 this->UAR_DOUBLE_ = UAR_DOUBLE;
305 this->UAR_CFLOAT_ = UAR_CFLOAT;
306 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
307 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
308 this->UAR_HALF_ = UAR_HALF;
309
310 // deduced disjoint types
311 this->UAR_INT8_ = UAR_BYTE;
312 this->UAR_UINT8_ = UAR_UBYTE;
313 this->UAR_INT16_ = UAR_SHORT;
314 this->UAR_UINT16_ = UAR_USHORT;
315 this->UAR_INT32_ =
316 platform_typeid_lookup<std::int32_t, long, int, short>(
317 UAR_LONG, UAR_INT, UAR_SHORT);
318 this->UAR_UINT32_ =
319 platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int,
320 unsigned short>(UAR_ULONG, UAR_UINT,
321 UAR_USHORT);
322 this->UAR_INT64_ =
323 platform_typeid_lookup<std::int64_t, long, long long, int>(
324 UAR_LONG, UAR_LONGLONG, UAR_INT);
325 this->UAR_UINT64_ =
326 platform_typeid_lookup<std::uint64_t, unsigned long,
327 unsigned long long, unsigned int>(
328 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
329
330 // create shared pointers to python objects used in type-casters
331 // for dpctl::memory::usm_memory and dpctl::tensor::usm_ndarray
332 sycl::queue q_{};
333 PySyclQueueObject *py_q_tmp =
334 SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&q_));
335 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
336 reinterpret_cast<PyObject *>(py_q_tmp));
337
338 default_sycl_queue_ = std::shared_ptr<py::object>(
339 new py::object(py_sycl_queue), Deleter{});
340
341 py::module_ mod_memory = py::module_::import("dpctl.memory");
342 const py::object &py_as_usm_memory = mod_memory.attr("as_usm_memory");
343 as_usm_memory_ = std::shared_ptr<py::object>(
344 new py::object{py_as_usm_memory}, Deleter{});
345
346 auto mem_kl = mod_memory.attr("MemoryUSMHost");
347 const py::object &py_default_usm_memory =
348 mem_kl(1, py::arg("queue") = py_sycl_queue);
349 default_usm_memory_ = std::shared_ptr<py::object>(
350 new py::object{py_default_usm_memory}, Deleter{});
351
352 py::module_ mod_usmarray = py::module_::import("dpnp.tensor._usmarray");
353 auto tensor_kl = mod_usmarray.attr("usm_ndarray");
354
355 const py::object &py_default_usm_ndarray =
356 tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
357 py::arg("buffer") = py_default_usm_memory);
358
359 default_usm_ndarray_ = std::shared_ptr<py::object>(
360 new py::object{py_default_usm_ndarray}, Deleter{});
361 }
362
363 dpctl_capi(dpctl_capi const &) = default;
364 dpctl_capi &operator=(dpctl_capi const &) = default;
365 dpctl_capi &operator=(dpctl_capi &&) = default;
366
367}; // struct dpctl_capi
368} // namespace detail
369} // namespace dpctl
370
371namespace pybind11::detail
372{
// DPCTL_TYPE_CASTER(type, py_name) expands to the members shared by the
// type_caster specializations in this header:
//   - `value`: a std::unique_ptr<type> holding the loaded C++ object;
//   - `name`: the Python-facing type name, set to `py_name`;
//   - a pointer overload of `cast` that returns None for a null pointer,
//     moves out of `*src` and deletes `src` when the policy is
//     take_ownership, and otherwise forwards `*src` to the by-value `cast`
//     that the specialization itself defines;
//   - conversion operators to `type *`, `type &` and `type &&`, all of which
//     read from `value`;
//   - the `cast_op_type` alias, defined as pybind11's movable_cast_op_type.
373#define DPCTL_TYPE_CASTER(type, py_name) \
374protected: \
375 std::unique_ptr<type> value; \
376 \
377public: \
378 static constexpr auto name = py_name; \
379 template < \
380 typename T_, \
381 ::pybind11::detail::enable_if_t< \
382 std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
383 int> = 0> \
384 static ::pybind11::handle cast(T_ *src, \
385 ::pybind11::return_value_policy policy, \
386 ::pybind11::handle parent) \
387 { \
388 if (!src) \
389 return ::pybind11::none().release(); \
390 if (policy == ::pybind11::return_value_policy::take_ownership) { \
391 auto h = cast(std::move(*src), policy, parent); \
392 delete src; \
393 return h; \
394 } \
395 return cast(*src, policy, parent); \
396 } \
397 operator type *() \
398 { \
399 return value.get(); \
400 } /* NOLINT(bugprone-macro-parentheses) */ \
401 operator type &() \
402 { \
403 return *value; \
404 } /* NOLINT(bugprone-macro-parentheses) */ \
405 operator type &&() && \
406 { \
407 return std::move(*value); \
408 } /* NOLINT(bugprone-macro-parentheses) */ \
409 template <typename T_> \
410 using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
411
412/* This type caster associates ``sycl::queue`` C++ class with
413 * :class:`dpctl.SyclQueue` for the purposes of generation of
414 * Python bindings by pybind11.
415 */
416template <>
417struct type_caster<sycl::queue>
418{
419public:
420 bool load(handle src, bool)
421 {
422 PyObject *source = src.ptr();
423 auto const &api = ::dpctl::detail::dpctl_capi::get();
424 if (api.PySyclQueue_Check_(source)) {
425 DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
426 reinterpret_cast<PySyclQueueObject *>(source));
427 value = std::make_unique<sycl::queue>(
428 *(reinterpret_cast<sycl::queue *>(QRef)));
429 return true;
430 }
431 else {
432 throw py::type_error(
433 "Input is of unexpected type, expected dpctl.SyclQueue");
434 }
435 }
436
437 static handle cast(sycl::queue src, return_value_policy, handle)
438 {
439 auto const &api = ::dpctl::detail::dpctl_capi::get();
440 auto tmp =
441 api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
442 return handle(reinterpret_cast<PyObject *>(tmp));
443 }
444
445 DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
446};
447
448/* This type caster associates ``sycl::device`` C++ class with
449 * :class:`dpctl.SyclDevice` for the purposes of generation of
450 * Python bindings by pybind11.
451 */
452template <>
453struct type_caster<sycl::device>
454{
455public:
456 bool load(handle src, bool)
457 {
458 PyObject *source = src.ptr();
459 auto const &api = ::dpctl::detail::dpctl_capi::get();
460 if (api.PySyclDevice_Check_(source)) {
461 DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
462 reinterpret_cast<PySyclDeviceObject *>(source));
463 value = std::make_unique<sycl::device>(
464 *(reinterpret_cast<sycl::device *>(DRef)));
465 return true;
466 }
467 else {
468 throw py::type_error(
469 "Input is of unexpected type, expected dpctl.SyclDevice");
470 }
471 }
472
473 static handle cast(sycl::device src, return_value_policy, handle)
474 {
475 auto const &api = ::dpctl::detail::dpctl_capi::get();
476 auto tmp =
477 api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
478 return handle(reinterpret_cast<PyObject *>(tmp));
479 }
480
481 DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
482};
483
484/* This type caster associates ``sycl::context`` C++ class with
485 * :class:`dpctl.SyclContext` for the purposes of generation of
486 * Python bindings by pybind11.
487 */
488template <>
489struct type_caster<sycl::context>
490{
491public:
492 bool load(handle src, bool)
493 {
494 PyObject *source = src.ptr();
495 auto const &api = ::dpctl::detail::dpctl_capi::get();
496 if (api.PySyclContext_Check_(source)) {
497 DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
498 reinterpret_cast<PySyclContextObject *>(source));
499 value = std::make_unique<sycl::context>(
500 *(reinterpret_cast<sycl::context *>(CRef)));
501 return true;
502 }
503 else {
504 throw py::type_error(
505 "Input is of unexpected type, expected dpctl.SyclContext");
506 }
507 }
508
509 static handle cast(sycl::context src, return_value_policy, handle)
510 {
511 auto const &api = ::dpctl::detail::dpctl_capi::get();
512 auto tmp =
513 api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
514 return handle(reinterpret_cast<PyObject *>(tmp));
515 }
516
517 DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
518};
519
520/* This type caster associates ``sycl::event`` C++ class with
521 * :class:`dpctl.SyclEvent` for the purposes of generation of
522 * Python bindings by pybind11.
523 */
524template <>
525struct type_caster<sycl::event>
526{
527public:
528 bool load(handle src, bool)
529 {
530 PyObject *source = src.ptr();
531 auto const &api = ::dpctl::detail::dpctl_capi::get();
532 if (api.PySyclEvent_Check_(source)) {
533 DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
534 reinterpret_cast<PySyclEventObject *>(source));
535 value = std::make_unique<sycl::event>(
536 *(reinterpret_cast<sycl::event *>(ERef)));
537 return true;
538 }
539 else {
540 throw py::type_error(
541 "Input is of unexpected type, expected dpctl.SyclEvent");
542 }
543 }
544
545 static handle cast(sycl::event src, return_value_policy, handle)
546 {
547 auto const &api = ::dpctl::detail::dpctl_capi::get();
548 auto tmp =
549 api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
550 return handle(reinterpret_cast<PyObject *>(tmp));
551 }
552
553 DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
554};
555
556/* This type caster associates ``sycl::kernel`` C++ class with
557 * :class:`dpctl.program.SyclKernel` for the purposes of generation of
558 * Python bindings by pybind11.
559 */
560template <>
561struct type_caster<sycl::kernel>
562{
563public:
564 bool load(handle src, bool)
565 {
566 PyObject *source = src.ptr();
567 auto const &api = ::dpctl::detail::dpctl_capi::get();
568 if (api.PySyclKernel_Check_(source)) {
569 DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
570 reinterpret_cast<PySyclKernelObject *>(source));
571 value = std::make_unique<sycl::kernel>(
572 *(reinterpret_cast<sycl::kernel *>(KRef)));
573 return true;
574 }
575 else {
576 throw py::type_error("Input is of unexpected type, expected "
577 "dpctl.program.SyclKernel");
578 }
579 }
580
581 static handle cast(sycl::kernel src, return_value_policy, handle)
582 {
583 auto const &api = ::dpctl::detail::dpctl_capi::get();
584 auto tmp =
585 api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
586 "dpctl4pybind11_kernel");
587 return handle(reinterpret_cast<PyObject *>(tmp));
588 }
589
590 DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
591};
592
593/* This type caster associates
594 * ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
595 * :class:`dpctl.program.SyclProgram` for the purposes of generation of
596 * Python bindings by pybind11.
597 */
598template <>
599struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
600{
601public:
602 bool load(handle src, bool)
603 {
604 PyObject *source = src.ptr();
605 auto const &api = ::dpctl::detail::dpctl_capi::get();
606 if (api.PySyclProgram_Check_(source)) {
607 DPCTLSyclKernelBundleRef KBRef =
608 api.SyclProgram_GetKernelBundleRef_(
609 reinterpret_cast<PySyclProgramObject *>(source));
610 value = std::make_unique<
611 sycl::kernel_bundle<sycl::bundle_state::executable>>(
612 *(reinterpret_cast<
613 sycl::kernel_bundle<sycl::bundle_state::executable> *>(
614 KBRef)));
615 return true;
616 }
617 else {
618 throw py::type_error("Input is of unexpected type, expected "
619 "dpctl.program.SyclProgram");
620 }
621 }
622
623 static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src,
624 return_value_policy,
625 handle)
626 {
627 auto const &api = ::dpctl::detail::dpctl_capi::get();
628 auto tmp = api.SyclProgram_Make_(
629 reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
630 return handle(reinterpret_cast<PyObject *>(tmp));
631 }
632
633 DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
634 _("dpctl.program.SyclProgram"));
635};
636
637/* This type caster associates
638 * ``sycl::half`` C++ class with Python :class:`float` for the purposes
639 * of generation of Python bindings by pybind11.
640 */
641template <>
642struct type_caster<sycl::half>
643{
644public:
645 bool load(handle src, bool convert)
646 {
647 double py_value;
648
649 if (!src) {
650 return false;
651 }
652
653 PyObject *source = src.ptr();
654
655 if (convert || PyFloat_Check(source)) {
656 py_value = PyFloat_AsDouble(source);
657 }
658 else {
659 return false;
660 }
661
662 bool py_err = (py_value == double(-1)) && PyErr_Occurred();
663
664 if (py_err) {
665 PyErr_Clear();
666 if (convert && (PyNumber_Check(source) != 0)) {
667 auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
668 return load(tmp, false);
669 }
670 return false;
671 }
672 value = static_cast<sycl::half>(py_value);
673 return true;
674 }
675
676 static handle cast(sycl::half src, return_value_policy, handle)
677 {
678 return PyFloat_FromDouble(static_cast<double>(src));
679 }
680
681 PYBIND11_TYPE_CASTER(sycl::half, _("float"));
682};
683} // namespace pybind11::detail
684
685namespace dpctl
686{
687namespace memory
688{
689// since PYBIND11_OBJECT_CVT uses error_already_set without namespace,
690// this allows to avoid compilation error
691using pybind11::error_already_set;
692
693class usm_memory : public py::object
694{
695public:
696 PYBIND11_OBJECT_CVT(
698 py::object,
699 [](PyObject *o) -> bool {
700 return PyObject_TypeCheck(
701 o, ::dpctl::detail::dpctl_capi::get().Py_MemoryType_) !=
702 0;
703 },
704 [](PyObject *o) -> PyObject * { return as_usm_memory(o); })
705
706 usm_memory()
707 : py::object(
708 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(),
709 borrowed_t{})
710 {
711 if (!m_ptr)
712 throw py::error_already_set();
713 }
714
718 usm_memory(void *usm_ptr,
719 std::size_t nbytes,
720 const sycl::queue &q,
721 std::shared_ptr<void> shptr)
722 {
723 auto const &api = ::dpctl::detail::dpctl_capi::get();
724 DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr);
725 auto q_uptr = std::make_unique<sycl::queue>(q);
726 DPCTLSyclQueueRef QRef =
727 reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get());
728
729 auto vacuous_destructor = []() {};
730 py::capsule mock_owner(vacuous_destructor);
731
732 // create memory object owned by mock_owner, it is a new reference
733 PyObject *_memory =
734 api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr());
735 auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); };
736
737 using py_uptrT =
738 std::unique_ptr<PyObject, decltype(ref_count_decrementer)>;
739
740 if (!_memory) {
741 throw py::error_already_set();
742 }
743
744 auto memory_uptr = py_uptrT(_memory, ref_count_decrementer);
745 std::shared_ptr<void> *opaque_ptr = new std::shared_ptr<void>(shptr);
746
747 Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory);
748 // replace mock_owner capsule as the owner
749 memobj->refobj = Py_None;
750 // set opaque ptr field, usm_memory now knowns that USM is managed
751 // by smart pointer
752 memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr);
753
754 // _memory will delete created copies of sycl::queue, and
755 // std::shared_ptr and the deleter of the shared_ptr<void> is
756 // supposed to free the USM allocation
757 m_ptr = _memory;
758 q_uptr.release();
759 memory_uptr.release();
760 }
761
762 sycl::queue get_queue() const
763 {
764 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
765 auto const &api = ::dpctl::detail::dpctl_capi::get();
766 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
767 sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
768 return *obj_q;
769 }
770
771 char *get_pointer() const
772 {
773 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
774 auto const &api = ::dpctl::detail::dpctl_capi::get();
775 DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
776 return reinterpret_cast<char *>(MRef);
777 }
778
779 std::size_t get_nbytes() const
780 {
781 auto const &api = ::dpctl::detail::dpctl_capi::get();
782 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
783 return api.Memory_GetNumBytes_(mem_obj);
784 }
785
786 bool is_managed_by_smart_ptr() const
787 {
788 auto const &api = ::dpctl::detail::dpctl_capi::get();
789 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
790 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
791
792 return bool(opaque_ptr);
793 }
794
795 const std::shared_ptr<void> &get_smart_ptr_owner() const
796 {
797 auto const &api = ::dpctl::detail::dpctl_capi::get();
798 Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
799 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
800
801 if (opaque_ptr) {
802 auto shptr_ptr =
803 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
804 return *shptr_ptr;
805 }
806 else {
807 throw std::runtime_error(
808 "Memory object does not have smart pointer "
809 "managing lifetime of USM allocation");
810 }
811 }
812
813protected:
814 static PyObject *as_usm_memory(PyObject *o)
815 {
816 if (o == nullptr) {
817 PyErr_SetString(PyExc_ValueError,
818 "cannot create a usm_memory from a nullptr");
819 return nullptr;
820 }
821
822 auto converter =
823 ::dpctl::detail::dpctl_capi::get().as_usm_memory_pyobj();
824
825 py::object res;
826 try {
827 res = converter(py::handle(o));
828 } catch (const py::error_already_set &e) {
829 return nullptr;
830 }
831 return res.ptr();
832 }
833};
834} // end namespace memory
835
836namespace tensor
837{
838inline std::vector<py::ssize_t>
839 c_contiguous_strides(int nd,
840 const py::ssize_t *shape,
841 py::ssize_t element_size = 1)
842{
843 if (nd > 0) {
844 std::vector<py::ssize_t> c_strides(nd, element_size);
845 for (int ic = nd - 1; ic > 0;) {
846 py::ssize_t next_v = c_strides[ic] * shape[ic];
847 c_strides[--ic] = next_v;
848 }
849 return c_strides;
850 }
851 else {
852 return std::vector<py::ssize_t>();
853 }
854}
855
856inline std::vector<py::ssize_t>
857 f_contiguous_strides(int nd,
858 const py::ssize_t *shape,
859 py::ssize_t element_size = 1)
860{
861 if (nd > 0) {
862 std::vector<py::ssize_t> f_strides(nd, element_size);
863 for (int i = 0; i < nd - 1;) {
864 py::ssize_t next_v = f_strides[i] * shape[i];
865 f_strides[++i] = next_v;
866 }
867 return f_strides;
868 }
869 else {
870 return std::vector<py::ssize_t>();
871 }
872}
873
874inline std::vector<py::ssize_t>
875 c_contiguous_strides(const std::vector<py::ssize_t> &shape,
876 py::ssize_t element_size = 1)
877{
878 return c_contiguous_strides(shape.size(), shape.data(), element_size);
879}
880
881inline std::vector<py::ssize_t>
882 f_contiguous_strides(const std::vector<py::ssize_t> &shape,
883 py::ssize_t element_size = 1)
884{
885 return f_contiguous_strides(shape.size(), shape.data(), element_size);
886}
887
888class usm_ndarray : public py::object
889{
890public:
891 PYBIND11_OBJECT(usm_ndarray, py::object, [](PyObject *o) -> bool {
892 return PyObject_TypeCheck(
893 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
894 })
895
897 : py::object(
898 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
899 borrowed_t{})
900 {
901 if (!m_ptr)
902 throw py::error_already_set();
903 }
904
905 char *get_data() const
906 {
907 PyUSMArrayObject *raw_ar = usm_array_ptr();
908 return raw_ar->data_;
909 }
910
911 template <typename T>
912 T *get_data() const
913 {
914 return reinterpret_cast<T *>(get_data());
915 }
916
917 int get_ndim() const
918 {
919 PyUSMArrayObject *raw_ar = usm_array_ptr();
920 return raw_ar->nd_;
921 }
922
923 const py::ssize_t *get_shape_raw() const
924 {
925 PyUSMArrayObject *raw_ar = usm_array_ptr();
926 return raw_ar->shape_;
927 }
928
929 std::vector<py::ssize_t> get_shape_vector() const
930 {
931 auto raw_sh = get_shape_raw();
932 auto nd = get_ndim();
933
934 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
935 return shape_vector;
936 }
937
938 py::ssize_t get_shape(int i) const
939 {
940 auto shape_ptr = get_shape_raw();
941 return shape_ptr[i];
942 }
943
944 const py::ssize_t *get_strides_raw() const
945 {
946 PyUSMArrayObject *raw_ar = usm_array_ptr();
947 return raw_ar->strides_;
948 }
949
950 std::vector<py::ssize_t> get_strides_vector() const
951 {
952 auto raw_st = get_strides_raw();
953 auto nd = get_ndim();
954
955 if (raw_st == nullptr) {
956 auto is_c_contig = is_c_contiguous();
957 auto is_f_contig = is_f_contiguous();
958 auto raw_sh = get_shape_raw();
959 if (is_c_contig) {
960 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
961 return contig_strides;
962 }
963 else if (is_f_contig) {
964 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
965 return contig_strides;
966 }
967 else {
968 throw std::runtime_error("Invalid array encountered when "
969 "building strides");
970 }
971 }
972 else {
973 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
974 return st_vec;
975 }
976 }
977
978 py::ssize_t get_size() const
979 {
980 PyUSMArrayObject *raw_ar = usm_array_ptr();
981
982 int ndim = raw_ar->nd_;
983 const py::ssize_t *shape = raw_ar->shape_;
984
985 py::ssize_t nelems = 1;
986 for (int i = 0; i < ndim; ++i) {
987 nelems *= shape[i];
988 }
989
990 assert(nelems >= 0);
991 return nelems;
992 }
993
994 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
995 {
996 PyUSMArrayObject *raw_ar = usm_array_ptr();
997
998 int nd = raw_ar->nd_;
999 const py::ssize_t *shape = raw_ar->shape_;
1000 const py::ssize_t *strides = raw_ar->strides_;
1001
1002 py::ssize_t offset_min = 0;
1003 py::ssize_t offset_max = 0;
1004 if (strides == nullptr) {
1005 py::ssize_t stride(1);
1006 for (int i = 0; i < nd; ++i) {
1007 offset_max += stride * (shape[i] - 1);
1008 stride *= shape[i];
1009 }
1010 }
1011 else {
1012 for (int i = 0; i < nd; ++i) {
1013 py::ssize_t delta = strides[i] * (shape[i] - 1);
1014 if (strides[i] > 0) {
1015 offset_max += delta;
1016 }
1017 else {
1018 offset_min += delta;
1019 }
1020 }
1021 }
1022 return std::make_pair(offset_min, offset_max);
1023 }
1024
1025 sycl::queue get_queue() const
1026 {
1027 PyUSMArrayObject *raw_ar = usm_array_ptr();
1028 Py_MemoryObject *mem_obj =
1029 reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
1030
1031 auto const &api = ::dpctl::detail::dpctl_capi::get();
1032 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1033 return *(reinterpret_cast<sycl::queue *>(QRef));
1034 }
1035
1036 sycl::device get_device() const
1037 {
1038 PyUSMArrayObject *raw_ar = usm_array_ptr();
1039 Py_MemoryObject *mem_obj =
1040 reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
1041
1042 auto const &api = ::dpctl::detail::dpctl_capi::get();
1043 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1044 return reinterpret_cast<sycl::queue *>(QRef)->get_device();
1045 }
1046
1047 int get_typenum() const
1048 {
1049 PyUSMArrayObject *raw_ar = usm_array_ptr();
1050 return raw_ar->typenum_;
1051 }
1052
1053 int get_flags() const
1054 {
1055 PyUSMArrayObject *raw_ar = usm_array_ptr();
1056 return raw_ar->flags_;
1057 }
1058
1059 int get_elemsize() const
1060 {
1061 int typenum = get_typenum();
1062 auto const &api = ::dpctl::detail::dpctl_capi::get();
1063
1064 // Lookup table for element sizes based on typenum
1065 if (typenum == api.UAR_BOOL_)
1066 return 1;
1067 if (typenum == api.UAR_BYTE_)
1068 return 1;
1069 if (typenum == api.UAR_UBYTE_)
1070 return 1;
1071 if (typenum == api.UAR_SHORT_)
1072 return 2;
1073 if (typenum == api.UAR_USHORT_)
1074 return 2;
1075 if (typenum == api.UAR_INT_)
1076 return 4;
1077 if (typenum == api.UAR_UINT_)
1078 return 4;
1079 if (typenum == api.UAR_LONG_)
1080 return sizeof(long);
1081 if (typenum == api.UAR_ULONG_)
1082 return sizeof(unsigned long);
1083 if (typenum == api.UAR_LONGLONG_)
1084 return 8;
1085 if (typenum == api.UAR_ULONGLONG_)
1086 return 8;
1087 if (typenum == api.UAR_FLOAT_)
1088 return 4;
1089 if (typenum == api.UAR_DOUBLE_)
1090 return 8;
1091 if (typenum == api.UAR_CFLOAT_)
1092 return 8;
1093 if (typenum == api.UAR_CDOUBLE_)
1094 return 16;
1095 if (typenum == api.UAR_HALF_)
1096 return 2;
1097
1098 return 0; // Unknown type
1099 }
1100
1101 bool is_c_contiguous() const
1102 {
1103 int flags = get_flags();
1104 auto const &api = ::dpctl::detail::dpctl_capi::get();
1105 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
1106 }
1107
1108 bool is_f_contiguous() const
1109 {
1110 int flags = get_flags();
1111 auto const &api = ::dpctl::detail::dpctl_capi::get();
1112 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
1113 }
1114
1115 bool is_writable() const
1116 {
1117 int flags = get_flags();
1118 auto const &api = ::dpctl::detail::dpctl_capi::get();
1119 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
1120 }
1121
1123 py::object get_usm_data() const
1124 {
1125 PyUSMArrayObject *raw_ar = usm_array_ptr();
1126 // base_ is the Memory object - return new reference
1127 PyObject *usm_data = raw_ar->base_;
1128 Py_XINCREF(usm_data);
1129
1130 // pass reference ownership to py::object
1131 return py::reinterpret_steal<py::object>(usm_data);
1132 }
1133
1134 bool is_managed_by_smart_ptr() const
1135 {
1136 PyUSMArrayObject *raw_ar = usm_array_ptr();
1137 PyObject *usm_data = raw_ar->base_;
1138
1139 auto const &api = ::dpctl::detail::dpctl_capi::get();
1140 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1141 return false;
1142 }
1143
1144 Py_MemoryObject *mem_obj =
1145 reinterpret_cast<Py_MemoryObject *>(usm_data);
1146 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1147
1148 return bool(opaque_ptr);
1149 }
1150
1151 const std::shared_ptr<void> &get_smart_ptr_owner() const
1152 {
1153 PyUSMArrayObject *raw_ar = usm_array_ptr();
1154 PyObject *usm_data = raw_ar->base_;
1155
1156 auto const &api = ::dpctl::detail::dpctl_capi::get();
1157
1158 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1159 throw std::runtime_error(
1160 "usm_ndarray object does not have Memory object "
1161 "managing lifetime of USM allocation");
1162 }
1163
1164 Py_MemoryObject *mem_obj =
1165 reinterpret_cast<Py_MemoryObject *>(usm_data);
1166 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1167
1168 if (opaque_ptr) {
1169 auto shptr_ptr =
1170 reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1171 return *shptr_ptr;
1172 }
1173 else {
1174 throw std::runtime_error(
1175 "Memory object underlying usm_ndarray does not have "
1176 "smart pointer managing lifetime of USM allocation");
1177 }
1178 }
1179
1180private:
1181 PyUSMArrayObject *usm_array_ptr() const
1182 {
1183 return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
1184 }
1185};
1186} // end namespace tensor
1187
1188namespace utils
1189{
1190namespace detail
1191{
1193{
1194
1195 static bool is_usm_managed_by_shared_ptr(const py::object &h)
1196 {
1197 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1198 const auto &usm_memory_inst =
1199 py::cast<dpctl::memory::usm_memory>(h);
1200 return usm_memory_inst.is_managed_by_smart_ptr();
1201 }
1202 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1203 const auto &usm_array_inst =
1204 py::cast<dpctl::tensor::usm_ndarray>(h);
1205 return usm_array_inst.is_managed_by_smart_ptr();
1206 }
1207
1208 return false;
1209 }
1210
1211 static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
1212 {
1213 if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1214 const auto &usm_memory_inst =
1215 py::cast<dpctl::memory::usm_memory>(h);
1216 return usm_memory_inst.get_smart_ptr_owner();
1217 }
1218 else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1219 const auto &usm_array_inst =
1220 py::cast<dpctl::tensor::usm_ndarray>(h);
1221 return usm_array_inst.get_smart_ptr_owner();
1222 }
1223
1224 throw std::runtime_error(
1225 "Attempted extraction of shared_ptr on an unrecognized type");
1226 }
1227};
1228} // end of namespace detail
1229
1230template <std::size_t num>
1231sycl::event keep_args_alive(sycl::queue &q,
1232 const py::object (&py_objs)[num],
1233 const std::vector<sycl::event> &depends = {})
1234{
1235 std::size_t n_objects_held = 0;
1236 std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1237
1238 std::size_t n_usm_owners_held = 0;
1239 std::array<std::shared_ptr<void>, num> shp_usm{};
1240
1241 for (std::size_t i = 0; i < num; ++i) {
1242 const auto &py_obj_i = py_objs[i];
1243 if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1244 const auto &shp =
1245 detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1246 shp_usm[n_usm_owners_held] = shp;
1247 ++n_usm_owners_held;
1248 }
1249 else {
1250 shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1251 shp_arr[n_objects_held]->inc_ref();
1252 ++n_objects_held;
1253 }
1254 }
1255
1256 bool use_depends = true;
1257 sycl::event host_task_ev;
1258
1259 if (n_usm_owners_held > 0) {
1260 host_task_ev = q.submit([&](sycl::handler &cgh) {
1261 if (use_depends) {
1262 cgh.depends_on(depends);
1263 use_depends = false;
1264 }
1265 else {
1266 cgh.depends_on(host_task_ev);
1267 }
1268 cgh.host_task([shp_usm = std::move(shp_usm)]() {
1269 // no body, but shared pointers are captured in
1270 // the lambda, ensuring that USM allocation is
1271 // kept alive
1272 });
1273 });
1274 }
1275
1276 if (n_objects_held > 0) {
1277 host_task_ev = q.submit([&](sycl::handler &cgh) {
1278 if (use_depends) {
1279 cgh.depends_on(depends);
1280 use_depends = false;
1281 }
1282 else {
1283 cgh.depends_on(host_task_ev);
1284 }
1285 cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1286 py::gil_scoped_acquire acquire;
1287
1288 for (std::size_t i = 0; i < n_objects_held; ++i) {
1289 shp_arr[i]->dec_ref();
1290 }
1291 });
1292 });
1293 }
1294
1295 return host_task_ev;
1296}
1297
1300template <std::size_t num>
1301bool queues_are_compatible(const sycl::queue &exec_q,
1302 const sycl::queue (&alloc_qs)[num])
1303{
1304 for (std::size_t i = 0; i < num; ++i) {
1305
1306 if (exec_q != alloc_qs[i]) {
1307 return false;
1308 }
1309 }
1310 return true;
1311}
1312
1315template <std::size_t num>
1316bool queues_are_compatible(const sycl::queue &exec_q,
1317 const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1318{
1319 for (std::size_t i = 0; i < num; ++i) {
1320
1321 if (exec_q != arrs[i].get_queue()) {
1322 return false;
1323 }
1324 }
1325 return true;
1326}
1327} // end namespace utils
1328} // end namespace dpctl
usm_memory(void *usm_ptr, std::size_t nbytes, const sycl::queue &q, std::shared_ptr< void > shptr)
Create usm_memory object from shared pointer that manages lifetime of the USM allocation.
py::object get_usm_data() const
Get usm_data property of array.