// Data members that cache the usm_ndarray C-API: a type-object pointer,
// accessor/constructor function pointers, flag masks and type-number
// constants. Each trailing-underscore member is assigned from the same-named
// symbol without the underscore by the `this->... = ...` statements visible
// later in this file.
// NOTE(review): the embedded line numbers skip, so some original lines are
// absent from this excerpt; in particular the parameter lists of the three
// Make* pointers are cut off after their first parameter.
78 PyTypeObject *PyUSMArrayType_;
// Accessors: each takes a PyUSMArrayObject * and returns one property.
80 char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
81 int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
82 py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
83 py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
84 int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
85 int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
86 int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
87 DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
88 py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
89 PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
// Mutator: takes the array and an int argument.
90 void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
// Constructors returning PyObject * (signatures truncated in this excerpt).
91 PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
97 PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
102 PyObject *(*UsmNDArray_MakeFromPtr_)(int,
// Bit masks tested against the value returned by UsmNDArray_GetFlags_.
111 int USM_ARRAY_C_CONTIGUOUS_;
112 int USM_ARRAY_F_CONTIGUOUS_;
113 int USM_ARRAY_WRITABLE_;
// Type-number constants for the C-named element types.
114 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
115 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
116 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
// Type-number constants for the fixed-width integer names.
118 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
119 UAR_INT64_, UAR_UINT64_;
// Destructor: releases this object's ownership of the cached default
// usm_ndarray by resetting the shared_ptr. If this was the last owner, the
// deleter supplied when the pointer was created runs here.
121 ~dpnp_capi() { default_usm_ndarray_.reset(); };
// Returns, by value, the py::object that default_usm_ndarray_ points to.
// NOTE(review): the shared_ptr is dereferenced without a null check —
// presumably it is always populated during construction; confirm.
129 py::object default_usm_ndarray_pyobj() {
return *default_usm_ndarray_; }
// Call operator taking a py::object pointer; presumably this belongs to the
// Deleter type used when default_usm_ndarray_ is created — confirm.
// It computes `guard`, which is true only while the interpreter reports
// itself initialized and not finalizing.
// NOTE(review): the statements that consume `guard` and `p` are not part of
// this excerpt, nor are the #else/#endif lines of the version switch.
134 void operator()(py::object *p)
const
136 const bool initialized = Py_IsInitialized();
// The finalizing query is selected by PY_VERSION_HEX: the underscore-prefixed
// name below 0x30d0000, the unprefixed name otherwise.
137#if PY_VERSION_HEX < 0x30d0000
138 const bool finalizing = _Py_IsFinalizing();
140 const bool finalizing = Py_IsFinalizing();
142 const bool guard = initialized && !finalizing;
// Shared owner of the heap-allocated py::object for the default usm_ndarray;
// read by default_usm_ndarray_pyobj() and reset in the destructor.
150 std::shared_ptr<py::object> default_usm_ndarray_;
// Member-initializer list followed by the initialization body (the line that
// names the constructor is not part of this excerpt).
// Initial state: every pointer member is nullptr, the three flag masks are 0,
// every type-number constant is -1 and the shared_ptr is empty.
153 : PyUSMArrayType_(
nullptr), UsmNDArray_GetData_(
nullptr),
154 UsmNDArray_GetNDim_(
nullptr), UsmNDArray_GetShape_(
nullptr),
155 UsmNDArray_GetStrides_(
nullptr), UsmNDArray_GetTypenum_(
nullptr),
156 UsmNDArray_GetElementSize_(
nullptr), UsmNDArray_GetFlags_(
nullptr),
157 UsmNDArray_GetQueueRef_(
nullptr), UsmNDArray_GetOffset_(
nullptr),
158 UsmNDArray_GetUSMData_(
nullptr), UsmNDArray_SetWritableFlag_(
nullptr),
159 UsmNDArray_MakeSimpleFromMemory_(
nullptr),
160 UsmNDArray_MakeSimpleFromPtr_(
nullptr),
161 UsmNDArray_MakeFromPtr_(
nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
162 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
163 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
164 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
165 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
166 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
167 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
168 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
169 UAR_INT64_(-1), UAR_UINT64_(-1), default_usm_ndarray_{}
// Runs the import routine for dpnp.tensor._usmarray before any of the symbols
// below are read. NOTE(review): presumably this is the generated C-API import
// that makes those symbols valid — confirm.
173 import_dpnp__tensor___usmarray();
// Cache the address of the usm_ndarray type object.
175 this->PyUSMArrayType_ = &PyUSMArrayType;
// Cache the C-API functions in the matching function-pointer members.
178 this->UsmNDArray_GetData_ = UsmNDArray_GetData;
179 this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
180 this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
181 this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
182 this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
183 this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
184 this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
185 this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
186 this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
187 this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
188 this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
189 this->UsmNDArray_MakeSimpleFromMemory_ =
190 UsmNDArray_MakeSimpleFromMemory;
191 this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
192 this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
// Cache the flag masks and the type numbers of the C-named element types.
195 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
196 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
197 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
198 this->UAR_BOOL_ = UAR_BOOL;
199 this->UAR_BYTE_ = UAR_BYTE;
200 this->UAR_UBYTE_ = UAR_UBYTE;
201 this->UAR_SHORT_ = UAR_SHORT;
202 this->UAR_USHORT_ = UAR_USHORT;
203 this->UAR_INT_ = UAR_INT;
204 this->UAR_UINT_ = UAR_UINT;
205 this->UAR_LONG_ = UAR_LONG;
206 this->UAR_ULONG_ = UAR_ULONG;
207 this->UAR_LONGLONG_ = UAR_LONGLONG;
208 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
209 this->UAR_FLOAT_ = UAR_FLOAT;
210 this->UAR_DOUBLE_ = UAR_DOUBLE;
211 this->UAR_CFLOAT_ = UAR_CFLOAT;
212 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
213 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
214 this->UAR_HALF_ = UAR_HALF;
// The 8- and 16-bit fixed-width names alias the byte and short type numbers.
217 this->UAR_INT8_ = UAR_BYTE;
218 this->UAR_UINT8_ = UAR_UBYTE;
219 this->UAR_INT16_ = UAR_SHORT;
220 this->UAR_UINT16_ = UAR_USHORT;
// The 32- and 64-bit type numbers are chosen by platform_typeid_lookup from
// candidate C types and their type numbers.
// NOTE(review): the left-hand sides of these four statements (and the tail of
// the uint32 call) are not part of this excerpt — presumably they assign
// UAR_INT32_, UAR_UINT32_, UAR_INT64_ and UAR_UINT64_; confirm.
222 platform_typeid_lookup<std::int32_t, long, int, short>(
223 UAR_LONG, UAR_INT, UAR_SHORT);
225 platform_typeid_lookup<std::uint32_t,
unsigned long,
unsigned int,
226 unsigned short>(UAR_ULONG, UAR_UINT,
229 platform_typeid_lookup<std::int64_t, long, long long, int>(
230 UAR_LONG, UAR_LONGLONG, UAR_INT);
232 platform_typeid_lookup<std::uint64_t,
unsigned long,
233 unsigned long long,
unsigned int>(
234 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
// Build the default usm_ndarray: fetch the default USM memory object from
// the dpctl C-API holder, ...
236 py::object py_default_usm_memory =
237 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj();
// ... look up the usm_ndarray class in dpnp.tensor._usmarray, ...
239 py::module_ mod_usmarray = py::module_::import(
"dpnp.tensor._usmarray");
240 auto tensor_kl = mod_usmarray.attr(
"usm_ndarray");
// ... and call it with an empty tuple as the positional argument,
// dtype="u1" and the default memory object as buffer. The temporary is bound
// to a const reference, which keeps it alive until the end of the scope.
242 const py::object &py_default_usm_ndarray =
243 tensor_kl(py::tuple(), py::arg(
"dtype") = py::str(
"u1"),
244 py::arg(
"buffer") = py_default_usm_memory);
// Store a heap-allocated copy of the handle; Deleter{} is invoked instead of
// plain delete when the last shared owner goes away.
246 default_usm_ndarray_ = std::shared_ptr<py::object>(
247 new py::object{py_default_usm_ndarray}, Deleter{});
// pybind11 object boilerplate for usm_ndarray, derived from py::object.
// The check function passed to the macro accepts a PyObject when
// PyObject_TypeCheck against the cached PyUSMArrayType_ returns non-zero.
// NOTE(review): the closing lines of the lambda and macro call are not part
// of this excerpt.
312 PYBIND11_OBJECT(
usm_ndarray, py::object, [](PyObject *o) ->
bool {
313 return PyObject_TypeCheck(
314 o, detail::dpnp_capi::get().PyUSMArrayType_) != 0;
// Constructor fragment: the py::object base is initialized from the default
// usm_ndarray object held by dpnp_capi.
// NOTE(review): the constructor's signature, the second base-initializer
// argument and the condition guarding the throw are not part of this excerpt
// — presumably this is the default constructor and the throw fires when the
// resulting handle is null; confirm.
318 : py::object(detail::dpnp_capi::get().default_usm_ndarray_pyobj(),
// Rethrows the currently set Python error as a C++ exception.
322 throw py::error_already_set();
// Returns the char * produced by the cached UsmNDArray_GetData_ function for
// the PyUSMArrayObject underlying this wrapper.
325 char *get_data()
const
327 PyUSMArrayObject *raw_ar = usm_array_ptr();
329 auto const &api = detail::dpnp_capi::get();
330 return api.UsmNDArray_GetData_(raw_ar);
// Typed variant: reinterprets the pointer returned by get_data() as T *.
// No check is made here that T matches the array's element type.
// NOTE(review): the function's signature line is not part of this excerpt.
333 template <
typename T>
336 return reinterpret_cast<T *
>(get_data());
// Returns the value of the cached UsmNDArray_GetNDim_ function for the
// underlying array object.
// NOTE(review): the signature line is not part of this excerpt — presumably
// this is the get_ndim() that other methods call; confirm.
341 PyUSMArrayObject *raw_ar = usm_array_ptr();
343 auto const &api = detail::dpnp_capi::get();
344 return api.UsmNDArray_GetNDim_(raw_ar);
// Returns the pointer produced by the cached UsmNDArray_GetShape_ function,
// exposed as pointer-to-const so callers cannot modify the extents through it.
347 const py::ssize_t *get_shape_raw()
const
349 PyUSMArrayObject *raw_ar = usm_array_ptr();
351 auto const &api = detail::dpnp_capi::get();
352 return api.UsmNDArray_GetShape_(raw_ar);
// Copies the first get_ndim() entries of the raw shape pointer into a
// std::vector.
// NOTE(review): the return statement is not part of this excerpt.
355 std::vector<py::ssize_t> get_shape_vector()
const
357 auto raw_sh = get_shape_raw();
358 auto nd = get_ndim();
// Iterator-range construction over [raw_sh, raw_sh + nd).
360 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
// Shape accessor taking an int index i; obtains the raw shape pointer.
// NOTE(review): the statement that uses `i` is not part of this excerpt —
// presumably it returns shape_ptr[i]; whether i is range-checked cannot be
// determined from here.
364 py::ssize_t get_shape(
int i)
const
366 auto shape_ptr = get_shape_raw();
// Returns the pointer produced by the cached UsmNDArray_GetStrides_ function.
// Callers in this file compare the result against nullptr before using it.
370 const py::ssize_t *get_strides_raw()
const
372 PyUSMArrayObject *raw_ar = usm_array_ptr();
374 auto const &api = detail::dpnp_capi::get();
375 return api.UsmNDArray_GetStrides_(raw_ar);
// Returns the strides as a std::vector.
// When the raw strides pointer is null, the strides are computed from the
// shape according to the contiguity flags; if no contiguity branch applies, a
// std::runtime_error is thrown. Otherwise the first nd raw entries are copied.
// NOTE(review): some lines (the first branch condition, closing braces, the
// rest of the error message, the final return) are not part of this excerpt.
378 std::vector<py::ssize_t> get_strides_vector()
const
380 auto raw_st = get_strides_raw();
381 auto nd = get_ndim();
// Null strides: derive them from shape and contiguity flags.
383 if (raw_st ==
nullptr) {
384 auto is_c_contig = is_c_contiguous();
385 auto is_f_contig = is_f_contiguous();
386 auto raw_sh = get_shape_raw();
// Presumably guarded by `if (is_c_contig)` on a line not shown — confirm.
388 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
389 return contig_strides;
391 else if (is_f_contig) {
392 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
393 return contig_strides;
// Null strides on an array flagged neither C- nor F-contiguous.
396 throw std::runtime_error(
"Invalid array encountered when "
// Explicit strides: copy [raw_st, raw_st + nd).
401 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
// Size accessor: reads the dimension count and shape pointer through the
// cached C-API, starts an accumulator at 1 and iterates over the dimensions.
// NOTE(review): the loop body and the return statement are not part of this
// excerpt — presumably nelems is multiplied by each extent and returned;
// confirm.
406 py::ssize_t get_size()
const
408 PyUSMArrayObject *raw_ar = usm_array_ptr();
410 auto const &api = detail::dpnp_capi::get();
411 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
412 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
// With ndim == 0 the loop does not execute, leaving nelems at 1.
414 py::ssize_t nelems = 1;
415 for (
int i = 0; i < ndim; ++i) {
// Returns the pair (offset_min, offset_max) computed from shape and strides;
// both start at 0.
// NOTE(review): several lines (the update of `stride` in the first loop, the
// bodies of the sign branches, closing braces) are not part of this excerpt;
// the unit of the offsets (elements vs. bytes) cannot be determined from
// here.
423 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets()
const
425 PyUSMArrayObject *raw_ar = usm_array_ptr();
427 auto const &api = detail::dpnp_capi::get();
428 int nd = api.UsmNDArray_GetNDim_(raw_ar);
429 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
430 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
432 py::ssize_t offset_min = 0;
433 py::ssize_t offset_max = 0;
// Null strides: only offset_max is accumulated, using a running `stride`
// that starts at 1.
434 if (strides ==
nullptr) {
435 py::ssize_t stride(1);
436 for (
int i = 0; i < nd; ++i) {
437 offset_max += stride * (shape[i] - 1);
// Explicit strides: delta is the displacement to the last index along axis i;
// the sign of the stride selects which branch handles it.
442 for (
int i = 0; i < nd; ++i) {
443 py::ssize_t delta = strides[i] * (shape[i] - 1);
444 if (strides[i] > 0) {
452 return std::make_pair(offset_min, offset_max);
// Returns a copy of the sycl::queue associated with the array.
// The opaque DPCTLSyclQueueRef returned by the C-API is reinterpreted as a
// sycl::queue * and dereferenced.
455 sycl::queue get_queue()
const
457 PyUSMArrayObject *raw_ar = usm_array_ptr();
459 auto const &api = detail::dpnp_capi::get();
460 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
461 return *(
reinterpret_cast<sycl::queue *
>(QRef));
// Returns the sycl::device of the array's queue: the queue reference from
// the C-API is reinterpreted as a sycl::queue * and get_device() is called
// on it.
464 sycl::device get_device()
const
466 PyUSMArrayObject *raw_ar = usm_array_ptr();
468 auto const &api = detail::dpnp_capi::get();
469 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
470 return reinterpret_cast<sycl::queue *
>(QRef)->get_device();
// Returns the int produced by the cached UsmNDArray_GetTypenum_ function.
// NOTE(review): presumably comparable with the UAR_*_ constants cached in
// dpnp_capi — confirm.
473 int get_typenum()
const
475 PyUSMArrayObject *raw_ar = usm_array_ptr();
477 auto const &api = detail::dpnp_capi::get();
478 return api.UsmNDArray_GetTypenum_(raw_ar);
// Returns the int flag word produced by the cached UsmNDArray_GetFlags_
// function; the is_* predicates in this class mask it with the
// USM_ARRAY_*_ constants.
481 int get_flags()
const
483 PyUSMArrayObject *raw_ar = usm_array_ptr();
485 auto const &api = detail::dpnp_capi::get();
486 return api.UsmNDArray_GetFlags_(raw_ar);
// Returns the int produced by the cached UsmNDArray_GetElementSize_ function.
489 int get_elemsize()
const
491 PyUSMArrayObject *raw_ar = usm_array_ptr();
493 auto const &api = detail::dpnp_capi::get();
494 return api.UsmNDArray_GetElementSize_(raw_ar);
// True when the array's flag word has any bit of USM_ARRAY_C_CONTIGUOUS_ set.
497 bool is_c_contiguous()
const
499 int flags = get_flags();
500 auto const &api = detail::dpnp_capi::get();
501 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
// True when the array's flag word has any bit of USM_ARRAY_F_CONTIGUOUS_ set.
504 bool is_f_contiguous()
const
506 int flags = get_flags();
507 auto const &api = detail::dpnp_capi::get();
508 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
// True when the array's flag word has any bit of USM_ARRAY_WRITABLE_ set.
511 bool is_writable()
const
513 int flags = get_flags();
514 auto const &api = detail::dpnp_capi::get();
515 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
// Returns, as a py::object, the PyObject produced by the cached
// UsmNDArray_GetUSMData_ function.
// NOTE(review): the signature line is not part of this excerpt.
521 PyUSMArrayObject *raw_ar = usm_array_ptr();
523 auto const &api = detail::dpnp_capi::get();
525 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
// reinterpret_steal adopts the reference without incrementing its count, so
// this is correct only if the C-API call returned a new (owned) reference —
// presumably it does; confirm.
528 return py::reinterpret_steal<py::object>(usm_data);
// Reports whether the array's USM data object carries a non-null opaque
// pointer: the data object is type-checked against dpctl's Py_MemoryType_,
// then Memory_GetOpaquePointer_ is queried and its result converted to bool.
// NOTE(review): the body of the failed-type-check branch and any reference
// count handling for usm_data are not part of this excerpt.
531 bool is_managed_by_smart_ptr()
const
533 PyUSMArrayObject *raw_ar = usm_array_ptr();
535 auto const &api = detail::dpnp_capi::get();
536 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
538 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
// Branch taken when usm_data is not an instance of the Memory type.
539 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
544 Py_MemoryObject *mem_obj =
545 reinterpret_cast<Py_MemoryObject *
>(usm_data);
546 const void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
// Non-null opaque pointer => true.
549 return bool(opaque_ptr);
// Returns a reference to a std::shared_ptr<void> obtained from the opaque
// pointer of the array's Memory object.
// Throws std::runtime_error when the USM data object is not an instance of
// dpctl's Memory type, and a second std::runtime_error whose message says the
// Memory object has no managing smart pointer.
// NOTE(review): the condition guarding the second throw, the return statement
// and any reference count handling for usm_data are not part of this excerpt
// — presumably the second throw fires when opaque_ptr is null; confirm.
552 const std::shared_ptr<void> &get_smart_ptr_owner()
const
554 PyUSMArrayObject *raw_ar = usm_array_ptr();
556 auto const &api = detail::dpnp_capi::get();
557 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
559 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
560 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
562 throw std::runtime_error(
563 "usm_ndarray object does not have Memory object "
564 "managing lifetime of USM allocation");
567 Py_MemoryObject *mem_obj =
568 reinterpret_cast<Py_MemoryObject *
>(usm_data);
569 void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
// The opaque pointer is reinterpreted as a pointer to std::shared_ptr<void>.
574 reinterpret_cast<std::shared_ptr<void> *
>(opaque_ptr);
578 throw std::runtime_error(
579 "Memory object underlying usm_ndarray does not have "
580 "smart pointer managing lifetime of USM allocation");
// Reinterprets the wrapped m_ptr as a PyUSMArrayObject * for use with the
// C-API function pointers. No type check is performed here.
585 PyUSMArrayObject *usm_array_ptr()
const
587 return reinterpret_cast<PyUSMArrayObject *
>(m_ptr);