79 PyTypeObject *PyUSMArrayType_;
81 char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
82 int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
83 py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
84 py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
85 int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
86 int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
87 int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
88 DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
89 py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
90 PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
91 void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
92 PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
98 PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
103 PyObject *(*UsmNDArray_MakeFromPtr_)(int,
112 int USM_ARRAY_C_CONTIGUOUS_;
113 int USM_ARRAY_F_CONTIGUOUS_;
114 int USM_ARRAY_WRITABLE_;
115 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
116 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
117 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
119 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
120 UAR_INT64_, UAR_UINT64_;
122 ~dpnp_capi() { default_usm_ndarray_.reset(); };
130 py::object default_usm_ndarray_pyobj() {
return *default_usm_ndarray_; }
135 void operator()(py::object *p)
const
137 const bool initialized = Py_IsInitialized();
138#if PY_VERSION_HEX < 0x30d0000
139 const bool finalizing = _Py_IsFinalizing();
141 const bool finalizing = Py_IsFinalizing();
143 const bool guard = initialized && !finalizing;
151 std::shared_ptr<py::object> default_usm_ndarray_;
154 : PyUSMArrayType_(
nullptr), UsmNDArray_GetData_(
nullptr),
155 UsmNDArray_GetNDim_(
nullptr), UsmNDArray_GetShape_(
nullptr),
156 UsmNDArray_GetStrides_(
nullptr), UsmNDArray_GetTypenum_(
nullptr),
157 UsmNDArray_GetElementSize_(
nullptr), UsmNDArray_GetFlags_(
nullptr),
158 UsmNDArray_GetQueueRef_(
nullptr), UsmNDArray_GetOffset_(
nullptr),
159 UsmNDArray_GetUSMData_(
nullptr), UsmNDArray_SetWritableFlag_(
nullptr),
160 UsmNDArray_MakeSimpleFromMemory_(
nullptr),
161 UsmNDArray_MakeSimpleFromPtr_(
nullptr),
162 UsmNDArray_MakeFromPtr_(
nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
163 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
164 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
165 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
166 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
167 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
168 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
169 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
170 UAR_INT64_(-1), UAR_UINT64_(-1), default_usm_ndarray_{}
174 import_dpnp__tensor___usmarray();
176 this->PyUSMArrayType_ = &PyUSMArrayType;
179 this->UsmNDArray_GetData_ = UsmNDArray_GetData;
180 this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
181 this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
182 this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
183 this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
184 this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
185 this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
186 this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
187 this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
188 this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
189 this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
190 this->UsmNDArray_MakeSimpleFromMemory_ =
191 UsmNDArray_MakeSimpleFromMemory;
192 this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
193 this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
196 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS_VALUE;
197 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS_VALUE;
198 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE_VALUE;
199 this->UAR_BOOL_ = UAR_BOOL_VALUE;
200 this->UAR_BYTE_ = UAR_BYTE_VALUE;
201 this->UAR_UBYTE_ = UAR_UBYTE_VALUE;
202 this->UAR_SHORT_ = UAR_SHORT_VALUE;
203 this->UAR_USHORT_ = UAR_USHORT_VALUE;
204 this->UAR_INT_ = UAR_INT_VALUE;
205 this->UAR_UINT_ = UAR_UINT_VALUE;
206 this->UAR_LONG_ = UAR_LONG_VALUE;
207 this->UAR_ULONG_ = UAR_ULONG_VALUE;
208 this->UAR_LONGLONG_ = UAR_LONGLONG_VALUE;
209 this->UAR_ULONGLONG_ = UAR_ULONGLONG_VALUE;
210 this->UAR_FLOAT_ = UAR_FLOAT_VALUE;
211 this->UAR_DOUBLE_ = UAR_DOUBLE_VALUE;
212 this->UAR_CFLOAT_ = UAR_CFLOAT_VALUE;
213 this->UAR_CDOUBLE_ = UAR_CDOUBLE_VALUE;
214 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL_VALUE;
215 this->UAR_HALF_ = UAR_HALF_VALUE;
218 this->UAR_INT8_ = UAR_BYTE_VALUE;
219 this->UAR_UINT8_ = UAR_UBYTE_VALUE;
220 this->UAR_INT16_ = UAR_SHORT_VALUE;
221 this->UAR_UINT16_ = UAR_USHORT_VALUE;
223 platform_typeid_lookup<std::int32_t, long, int, short>(
224 UAR_LONG_VALUE, UAR_INT_VALUE, UAR_SHORT_VALUE);
226 platform_typeid_lookup<std::uint32_t,
unsigned long,
unsigned int,
228 UAR_ULONG_VALUE, UAR_UINT_VALUE, UAR_USHORT_VALUE);
230 platform_typeid_lookup<std::int64_t, long, long long, int>(
231 UAR_LONG_VALUE, UAR_LONGLONG_VALUE, UAR_INT_VALUE);
233 platform_typeid_lookup<std::uint64_t,
unsigned long,
234 unsigned long long,
unsigned int>(
235 UAR_ULONG_VALUE, UAR_ULONGLONG_VALUE, UAR_UINT_VALUE);
237 py::object py_default_usm_memory =
238 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj();
240 py::module_ mod_usmarray = py::module_::import(
"dpnp.tensor._usmarray");
241 auto tensor_kl = mod_usmarray.attr(
"usm_ndarray");
243 const py::object &py_default_usm_ndarray =
244 tensor_kl(py::tuple(), py::arg(
"dtype") = py::str(
"u1"),
245 py::arg(
"buffer") = py_default_usm_memory);
247 default_usm_ndarray_ = std::shared_ptr<py::object>(
248 new py::object{py_default_usm_ndarray}, Deleter{});
313 PYBIND11_OBJECT(
usm_ndarray, py::object, [](PyObject *o) ->
bool {
314 return PyObject_TypeCheck(
315 o, detail::dpnp_capi::get().PyUSMArrayType_) != 0;
319 : py::object(detail::dpnp_capi::get().default_usm_ndarray_pyobj(),
323 throw py::error_already_set();
326 char *get_data()
const
328 PyUSMArrayObject *raw_ar = usm_array_ptr();
330 auto const &api = detail::dpnp_capi::get();
331 return api.UsmNDArray_GetData_(raw_ar);
334 template <
typename T>
337 return reinterpret_cast<T *
>(get_data());
342 PyUSMArrayObject *raw_ar = usm_array_ptr();
344 auto const &api = detail::dpnp_capi::get();
345 return api.UsmNDArray_GetNDim_(raw_ar);
348 const py::ssize_t *get_shape_raw()
const
350 PyUSMArrayObject *raw_ar = usm_array_ptr();
352 auto const &api = detail::dpnp_capi::get();
353 return api.UsmNDArray_GetShape_(raw_ar);
356 std::vector<py::ssize_t> get_shape_vector()
const
358 auto raw_sh = get_shape_raw();
359 auto nd = get_ndim();
361 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
365 py::ssize_t get_shape(
int i)
const
367 auto shape_ptr = get_shape_raw();
371 const py::ssize_t *get_strides_raw()
const
373 PyUSMArrayObject *raw_ar = usm_array_ptr();
375 auto const &api = detail::dpnp_capi::get();
376 return api.UsmNDArray_GetStrides_(raw_ar);
379 std::vector<py::ssize_t> get_strides_vector()
const
381 auto raw_st = get_strides_raw();
382 auto nd = get_ndim();
384 if (raw_st ==
nullptr) {
385 auto is_c_contig = is_c_contiguous();
386 auto is_f_contig = is_f_contiguous();
387 auto raw_sh = get_shape_raw();
389 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
390 return contig_strides;
392 else if (is_f_contig) {
393 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
394 return contig_strides;
397 throw std::runtime_error(
"Invalid array encountered when "
402 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
407 py::ssize_t get_size()
const
409 PyUSMArrayObject *raw_ar = usm_array_ptr();
411 auto const &api = detail::dpnp_capi::get();
412 int ndim = api.UsmNDArray_GetNDim_(raw_ar);
413 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
415 py::ssize_t nelems = 1;
416 for (
int i = 0; i < ndim; ++i) {
424 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets()
const
426 PyUSMArrayObject *raw_ar = usm_array_ptr();
428 auto const &api = detail::dpnp_capi::get();
429 int nd = api.UsmNDArray_GetNDim_(raw_ar);
430 const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
431 const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
433 py::ssize_t offset_min = 0;
434 py::ssize_t offset_max = 0;
435 if (strides ==
nullptr) {
436 py::ssize_t stride(1);
437 for (
int i = 0; i < nd; ++i) {
438 offset_max += stride * (shape[i] - 1);
443 for (
int i = 0; i < nd; ++i) {
444 py::ssize_t delta = strides[i] * (shape[i] - 1);
445 if (strides[i] > 0) {
453 return std::make_pair(offset_min, offset_max);
456 sycl::queue get_queue()
const
458 PyUSMArrayObject *raw_ar = usm_array_ptr();
460 auto const &api = detail::dpnp_capi::get();
461 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
462 return *(
reinterpret_cast<sycl::queue *
>(QRef));
465 sycl::device get_device()
const
467 PyUSMArrayObject *raw_ar = usm_array_ptr();
469 auto const &api = detail::dpnp_capi::get();
470 DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
471 return reinterpret_cast<sycl::queue *
>(QRef)->get_device();
474 int get_typenum()
const
476 PyUSMArrayObject *raw_ar = usm_array_ptr();
478 auto const &api = detail::dpnp_capi::get();
479 return api.UsmNDArray_GetTypenum_(raw_ar);
482 int get_flags()
const
484 PyUSMArrayObject *raw_ar = usm_array_ptr();
486 auto const &api = detail::dpnp_capi::get();
487 return api.UsmNDArray_GetFlags_(raw_ar);
490 int get_elemsize()
const
492 PyUSMArrayObject *raw_ar = usm_array_ptr();
494 auto const &api = detail::dpnp_capi::get();
495 return api.UsmNDArray_GetElementSize_(raw_ar);
498 bool is_c_contiguous()
const
500 int flags = get_flags();
501 auto const &api = detail::dpnp_capi::get();
502 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
505 bool is_f_contiguous()
const
507 int flags = get_flags();
508 auto const &api = detail::dpnp_capi::get();
509 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
512 bool is_writable()
const
514 int flags = get_flags();
515 auto const &api = detail::dpnp_capi::get();
516 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
522 PyUSMArrayObject *raw_ar = usm_array_ptr();
524 auto const &api = detail::dpnp_capi::get();
526 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
529 return py::reinterpret_steal<py::object>(usm_data);
532 bool is_managed_by_smart_ptr()
const
534 PyUSMArrayObject *raw_ar = usm_array_ptr();
536 auto const &api = detail::dpnp_capi::get();
537 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
539 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
540 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
545 Py_MemoryObject *mem_obj =
546 reinterpret_cast<Py_MemoryObject *
>(usm_data);
547 const void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
550 return bool(opaque_ptr);
553 const std::shared_ptr<void> &get_smart_ptr_owner()
const
555 PyUSMArrayObject *raw_ar = usm_array_ptr();
557 auto const &api = detail::dpnp_capi::get();
558 PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
560 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
561 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
563 throw std::runtime_error(
564 "usm_ndarray object does not have Memory object "
565 "managing lifetime of USM allocation");
568 Py_MemoryObject *mem_obj =
569 reinterpret_cast<Py_MemoryObject *
>(usm_data);
570 void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
575 reinterpret_cast<std::shared_ptr<void> *
>(opaque_ptr);
579 throw std::runtime_error(
580 "Memory object underlying usm_ndarray does not have "
581 "smart pointer managing lifetime of USM allocation");
586 PyUSMArrayObject *usm_array_ptr()
const
588 return reinterpret_cast<PyUSMArrayObject *
>(m_ptr);