// Cached values describing the usm_ndarray type. The constructor of this
// class fills every field below after the usm_ndarray C-API has been imported.
// Python type object used for instance checks of usm_ndarray wrappers.
78 PyTypeObject *PyUSMArrayType_;
// Bit masks that are AND-ed with an array's flags_ word by
// is_c_contiguous(), is_f_contiguous() and is_writable().
80 int USM_ARRAY_C_CONTIGUOUS_;
81 int USM_ARRAY_F_CONTIGUOUS_;
82 int USM_ARRAY_WRITABLE_;
// Type numbers for the C-named element types; compared against an array's
// typenum_ (see get_elemsize()). Initialised to -1 until the C-API is loaded.
83 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
84 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
85 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
// NOTE(review): the declarator that ends the list above is not visible in
// this excerpt; presumably UAR_HALF_, which the constructor initialises.
// Type numbers for the fixed-width aliases (resolved per platform in the
// constructor).
87 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
88 UAR_INT64_, UAR_UINT64_;
// Destructor: drops this object's ownership of the cached default
// usm_ndarray handle by resetting the shared_ptr that holds it.
90 ~dpnp_capi() { default_usm_ndarray_.reset(); };
// Returns, by value, the py::object stored behind default_usm_ndarray_.
// The shared_ptr is dereferenced without a null check; the constructor
// assigns it before this can be reached through a fully built instance.
98 py::object default_usm_ndarray_pyobj() {
return *default_usm_ndarray_; }
// Call operator taking a heap-allocated py::object pointer; presumably this
// belongs to the Deleter type handed to the shared_ptr that owns the default
// usm_ndarray -- confirm against the enclosing struct (not visible here).
103 void operator()(py::object *p)
const
// Query the interpreter state before touching the Python object.
105 const bool initialized = Py_IsInitialized();
// For PY_VERSION_HEX below 0x030d0000 (Python 3.13) the underscore-prefixed
// _Py_IsFinalizing() is used; the other line uses Py_IsFinalizing().
// NOTE(review): the #else / #endif lines are not visible in this excerpt.
106#if PY_VERSION_HEX < 0x30d0000
107 const bool finalizing = _Py_IsFinalizing();
109 const bool finalizing = Py_IsFinalizing();
// guard is true only while the interpreter is initialised and not finalizing.
111 const bool guard = initialized && !finalizing;
// NOTE(review): the code that consumes `guard` and `p` is not visible here.
// Owning handle to a heap-allocated py::object that wraps the default
// usm_ndarray instance; assigned in the constructor with a custom deleter
// and released in the destructor.
119 std::shared_ptr<py::object> default_usm_ndarray_;
// Constructor (the header line is not visible in this excerpt).
// The member-initializer list puts every field into a "not loaded" state:
// the type pointer is nullptr, the flag masks are 0, all type numbers are -1
// and the default-array handle is empty.
122 : PyUSMArrayType_(
nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
123 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
124 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
125 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
126 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
127 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
128 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
129 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
130 UAR_INT64_(-1), UAR_UINT64_(-1), default_usm_ndarray_{}
// Presumably the generated import routine for the dpnp.tensor._usmarray
// C-API, which makes the symbols read below usable -- TODO confirm. Its
// result is not inspected on any line visible here.
134 import_dpnp__tensor___usmarray();
// Cache the address of the usm_ndarray Python type object.
136 this->PyUSMArrayType_ = &PyUSMArrayType;
// Cache the flag masks and the type numbers of the C-named element types.
139 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
140 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
141 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
142 this->UAR_BOOL_ = UAR_BOOL;
143 this->UAR_BYTE_ = UAR_BYTE;
144 this->UAR_UBYTE_ = UAR_UBYTE;
145 this->UAR_SHORT_ = UAR_SHORT;
146 this->UAR_USHORT_ = UAR_USHORT;
147 this->UAR_INT_ = UAR_INT;
148 this->UAR_UINT_ = UAR_UINT;
149 this->UAR_LONG_ = UAR_LONG;
150 this->UAR_ULONG_ = UAR_ULONG;
151 this->UAR_LONGLONG_ = UAR_LONGLONG;
152 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
153 this->UAR_FLOAT_ = UAR_FLOAT;
154 this->UAR_DOUBLE_ = UAR_DOUBLE;
155 this->UAR_CFLOAT_ = UAR_CFLOAT;
156 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
157 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
158 this->UAR_HALF_ = UAR_HALF;
// The 8- and 16-bit fixed-width aliases map directly onto the byte and short
// type numbers.
161 this->UAR_INT8_ = UAR_BYTE;
162 this->UAR_UINT8_ = UAR_UBYTE;
163 this->UAR_INT16_ = UAR_SHORT;
164 this->UAR_UINT16_ = UAR_USHORT;
// The 32- and 64-bit aliases are resolved through platform_typeid_lookup,
// given a fixed-width type, candidate C types and their type numbers.
// NOTE(review): the assignment targets (presumably UAR_INT32_, UAR_UINT32_,
// UAR_INT64_, UAR_UINT64_) and the end of the second call are not visible
// in this excerpt.
166 platform_typeid_lookup<std::int32_t, long, int, short>(
167 UAR_LONG, UAR_INT, UAR_SHORT);
169 platform_typeid_lookup<std::uint32_t,
unsigned long,
unsigned int,
170 unsigned short>(UAR_ULONG, UAR_UINT,
173 platform_typeid_lookup<std::int64_t, long, long long, int>(
174 UAR_LONG, UAR_LONGLONG, UAR_INT);
176 platform_typeid_lookup<std::uint64_t,
unsigned long,
177 unsigned long long,
unsigned int>(
178 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
// Build the default usm_ndarray: fetch dpctl's default USM memory object and
// call the Python-level usm_ndarray class with an empty shape tuple,
// dtype "u1" and that memory object as the buffer.
180 py::object py_default_usm_memory =
181 ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj();
183 py::module_ mod_usmarray = py::module_::import(
"dpnp.tensor._usmarray");
184 auto tensor_kl = mod_usmarray.attr(
"usm_ndarray");
186 const py::object &py_default_usm_ndarray =
187 tensor_kl(py::tuple(), py::arg(
"dtype") = py::str(
"u1"),
188 py::arg(
"buffer") = py_default_usm_memory);
// Store a heap-allocated copy of the handle; Deleter{} is the custom deleter
// that runs when the last shared_ptr owner goes away.
190 default_usm_ndarray_ = std::shared_ptr<py::object>(
191 new py::object{py_default_usm_ndarray}, Deleter{});
// Declares usm_ndarray as a py::object-derived wrapper via PYBIND11_OBJECT.
// The lambda is the type-check predicate: an object qualifies when
// PyObject_TypeCheck against the cached PyUSMArrayType_ succeeds.
256 PYBIND11_OBJECT(
usm_ndarray, py::object, [](PyObject *o) ->
bool {
257 return PyObject_TypeCheck(
258 o, detail::dpnp_capi::get().PyUSMArrayType_) != 0;
// Fragment of a constructor whose header is not visible in this excerpt
// (presumably the default constructor -- confirm). The py::object base is
// initialised from the cached default usm_ndarray; the remaining base
// argument and the condition guarding the throw are not visible here.
262 : py::object(detail::dpnp_capi::get().default_usm_ndarray_pyobj(),
// Propagates the pending Python error as a C++ exception.
266 throw py::error_already_set();
// Returns the data_ field of the underlying PyUSMArrayObject as a char
// pointer.
269 char *get_data()
const
271 PyUSMArrayObject *raw_ar = usm_array_ptr();
272 return raw_ar->data_;
// Typed variant of get_data(): reinterprets the char pointer as T *.
// No check of T against the array's type number is visible in this excerpt.
// NOTE(review): the function signature line is not visible here.
275 template <
typename T>
278 return reinterpret_cast<T *
>(get_data());
// Fragment: fetches the raw PyUSMArrayObject pointer. The enclosing signature
// and the return statement are not visible in this excerpt; presumably this
// is the body of get_ndim(), which other accessors call -- TODO confirm.
283 PyUSMArrayObject *raw_ar = usm_array_ptr();
// Returns the shape_ pointer of the underlying PyUSMArrayObject without
// copying; the pointee is read-only through this pointer.
287 const py::ssize_t *get_shape_raw()
const
289 PyUSMArrayObject *raw_ar = usm_array_ptr();
290 return raw_ar->shape_;
// Copies the array extents into a std::vector: the first get_ndim() entries
// starting at the raw shape pointer.
// NOTE(review): the return statement is not visible in this excerpt.
293 std::vector<py::ssize_t> get_shape_vector()
const
295 auto raw_sh = get_shape_raw();
296 auto nd = get_ndim();
298 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
// Extent lookup for a single dimension index i, going through the raw shape
// pointer. NOTE(review): the return statement is not visible in this excerpt
// (presumably shape_ptr[i]); no bounds check on i is visible either.
302 py::ssize_t get_shape(
int i)
const
304 auto shape_ptr = get_shape_raw();
// Returns the strides_ pointer of the underlying PyUSMArrayObject without
// copying. Callers in this file treat a null result as "no explicit strides
// stored".
308 const py::ssize_t *get_strides_raw()
const
310 PyUSMArrayObject *raw_ar = usm_array_ptr();
311 return raw_ar->strides_;
// Returns the strides as a std::vector.
// When the raw strides pointer is null, the strides are computed from the
// shape: c_contiguous_strides() or f_contiguous_strides() is used depending
// on the contiguity flags, and std::runtime_error is thrown otherwise.
// When the pointer is non-null, the nd stored strides are copied.
314 std::vector<py::ssize_t> get_strides_vector()
const
316 auto raw_st = get_strides_raw();
317 auto nd = get_ndim();
319 if (raw_st ==
nullptr) {
320 auto is_c_contig = is_c_contiguous();
321 auto is_f_contig = is_f_contiguous();
322 auto raw_sh = get_shape_raw();
// NOTE(review): the condition opening this first branch is not visible in
// this excerpt; presumably it tests is_c_contig.
324 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
325 return contig_strides;
327 else if (is_f_contig) {
328 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
329 return contig_strides;
// Neither contiguity flag is set although no strides are stored.
332 throw std::runtime_error(
"Invalid array encountered when "
// Explicit strides are stored: copy nd entries.
337 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
// Computes an element count from the shape. The accumulator starts at 1, so
// an array with nd_ == 0 skips the loop entirely.
// NOTE(review): the loop body and the return statement are not visible in
// this excerpt; presumably nelems is multiplied by each shape[i].
342 py::ssize_t get_size()
const
344 PyUSMArrayObject *raw_ar = usm_array_ptr();
346 int ndim = raw_ar->nd_;
347 const py::ssize_t *shape = raw_ar->shape_;
349 py::ssize_t nelems = 1;
350 for (
int i = 0; i < ndim; ++i) {
// Returns the pair (offset_min, offset_max) of displacements reachable from
// the array's shape and strides; both start at 0.
358 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets()
const
360 PyUSMArrayObject *raw_ar = usm_array_ptr();
362 int nd = raw_ar->nd_;
363 const py::ssize_t *shape = raw_ar->shape_;
364 const py::ssize_t *strides = raw_ar->strides_;
366 py::ssize_t offset_min = 0;
367 py::ssize_t offset_max = 0;
// No explicit strides stored: accumulate the maximum offset using a running
// stride that starts at 1.
// NOTE(review): the line updating `stride` inside the loop is not visible
// in this excerpt.
368 if (strides ==
nullptr) {
369 py::ssize_t stride(1);
370 for (
int i = 0; i < nd; ++i) {
371 offset_max += stride * (shape[i] - 1);
// Explicit strides: delta is the displacement of the last element along
// dimension i; the sign of the stride selects which bound it affects.
// NOTE(review): the branch bodies are not visible in this excerpt.
376 for (
int i = 0; i < nd; ++i) {
377 py::ssize_t delta = strides[i] * (shape[i] - 1);
378 if (strides[i] > 0) {
386 return std::make_pair(offset_min, offset_max);
// Returns a copy of the sycl::queue associated with the array's base_ object.
// base_ is reinterpreted as a Py_MemoryObject without a type check here, and
// the queue reference obtained through dpctl's C-API table is treated as a
// pointer to sycl::queue.
389 sycl::queue get_queue()
const
391 PyUSMArrayObject *raw_ar = usm_array_ptr();
392 Py_MemoryObject *mem_obj =
393 reinterpret_cast<Py_MemoryObject *
>(raw_ar->base_);
395 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
396 DPCTLSyclQueueRef QRef = dpctl_api.Memory_GetQueueRef_(mem_obj);
397 return *(
reinterpret_cast<sycl::queue *
>(QRef));
// Returns the sycl::device of the queue associated with the array's base_
// object. The queue is looked up the same way as in get_queue(): base_ is
// reinterpreted as a Py_MemoryObject without a type check here.
400 sycl::device get_device()
const
402 PyUSMArrayObject *raw_ar = usm_array_ptr();
403 Py_MemoryObject *mem_obj =
404 reinterpret_cast<Py_MemoryObject *
>(raw_ar->base_);
406 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
407 DPCTLSyclQueueRef QRef = dpctl_api.Memory_GetQueueRef_(mem_obj);
408 return reinterpret_cast<sycl::queue *
>(QRef)->get_device();
// Returns the typenum_ field of the underlying PyUSMArrayObject; the value is
// comparable with the UAR_*_ constants cached in dpnp_capi.
411 int get_typenum()
const
413 PyUSMArrayObject *raw_ar = usm_array_ptr();
414 return raw_ar->typenum_;
// Returns the flags_ word of the underlying PyUSMArrayObject; individual bits
// are tested by is_c_contiguous(), is_f_contiguous() and is_writable().
417 int get_flags()
const
419 PyUSMArrayObject *raw_ar = usm_array_ptr();
420 return raw_ar->flags_;
// Maps the array's type number to an element size by comparing it against
// each cached UAR_*_ constant in turn.
// NOTE(review): most return statements, and the fall-through result for an
// unrecognised type number, are not visible in this excerpt; the one visible
// return yields sizeof(unsigned long) for UAR_ULONG_, so presumably each
// branch returns sizeof of the matching C type.
423 int get_elemsize()
const
425 int typenum = get_typenum();
426 auto const &api = detail::dpnp_capi::get();
429 if (typenum == api.UAR_BOOL_)
431 if (typenum == api.UAR_BYTE_)
433 if (typenum == api.UAR_UBYTE_)
435 if (typenum == api.UAR_SHORT_)
437 if (typenum == api.UAR_USHORT_)
439 if (typenum == api.UAR_INT_)
441 if (typenum == api.UAR_UINT_)
443 if (typenum == api.UAR_LONG_)
445 if (typenum == api.UAR_ULONG_)
446 return sizeof(
unsigned long);
447 if (typenum == api.UAR_LONGLONG_)
449 if (typenum == api.UAR_ULONGLONG_)
451 if (typenum == api.UAR_FLOAT_)
453 if (typenum == api.UAR_DOUBLE_)
455 if (typenum == api.UAR_CFLOAT_)
457 if (typenum == api.UAR_CDOUBLE_)
459 if (typenum == api.UAR_HALF_)
// True when the USM_ARRAY_C_CONTIGUOUS_ bit is set in the array's flags.
465 bool is_c_contiguous()
const
467 int flags = get_flags();
468 auto const &api = detail::dpnp_capi::get();
469 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
// True when the USM_ARRAY_F_CONTIGUOUS_ bit is set in the array's flags.
472 bool is_f_contiguous()
const
474 int flags = get_flags();
475 auto const &api = detail::dpnp_capi::get();
476 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
// True when the USM_ARRAY_WRITABLE_ bit is set in the array's flags.
479 bool is_writable()
const
481 int flags = get_flags();
482 auto const &api = detail::dpnp_capi::get();
483 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
// Fragment of an accessor whose signature is not visible in this excerpt.
// It returns the array's base_ object as a py::object.
489 PyUSMArrayObject *raw_ar = usm_array_ptr();
491 PyObject *usm_data = raw_ar->base_;
// Take an extra reference first (Py_XINCREF tolerates a null pointer) ...
492 Py_XINCREF(usm_data);
// ... then hand that reference to the py::object: reinterpret_steal adopts
// the pointer without incrementing the count again.
495 return py::reinterpret_steal<py::object>(usm_data);
// Reports whether the array's base_ object carries a non-null opaque pointer.
// base_ is first checked to be an instance of dpctl's Py_MemoryType_.
// NOTE(review): the body of the failed-type-check branch is not visible in
// this excerpt (presumably it returns false).
498 bool is_managed_by_smart_ptr()
const
500 PyUSMArrayObject *raw_ar = usm_array_ptr();
501 PyObject *usm_data = raw_ar->base_;
503 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
504 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
508 Py_MemoryObject *mem_obj =
509 reinterpret_cast<Py_MemoryObject *
>(usm_data);
510 const void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
// A non-null opaque pointer converts to true.
512 return bool(opaque_ptr);
// Returns a const reference to a std::shared_ptr<void> reached through the
// opaque pointer of the array's base_ Memory object.
// Throws std::runtime_error when base_ is not an instance of dpctl's
// Py_MemoryType_, and again on a second condition that is not visible here.
515 const std::shared_ptr<void> &get_smart_ptr_owner()
const
517 PyUSMArrayObject *raw_ar = usm_array_ptr();
518 PyObject *usm_data = raw_ar->base_;
520 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
522 if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
523 throw std::runtime_error(
524 "usm_ndarray object does not have Memory object "
525 "managing lifetime of USM allocation");
528 Py_MemoryObject *mem_obj =
529 reinterpret_cast<Py_MemoryObject *
>(usm_data);
530 void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
// The opaque pointer is interpreted as a pointer to std::shared_ptr<void>.
// NOTE(review): the condition selecting between returning it and the throw
// below is not visible in this excerpt (presumably a null check).
534 reinterpret_cast<std::shared_ptr<void> *
>(opaque_ptr);
538 throw std::runtime_error(
539 "Memory object underlying usm_ndarray does not have "
540 "smart pointer managing lifetime of USM allocation");
// Helper used by the accessors above: reinterprets the wrapped PyObject
// pointer (m_ptr) as a PyUSMArrayObject pointer. No type check happens here.
545 PyUSMArrayObject *usm_array_ptr()
const
547 return reinterpret_cast<PyUSMArrayObject *
>(m_ptr);