// Cached pointers to Python type objects (device, context, event, queue,
// memory and its USM device/shared/host variants, usm array, program, kernel).
// The constructor's initializer list sets each of them to nullptr, and the
// constructor body then stores the address of the global of the same name
// without the trailing underscore (e.g. &PySyclDeviceType).
// NOTE(review): the number at the start of each line appears to be a line
// number carried over from the original file; gaps in that numbering mean
// some original lines are not present in this listing.
101 PyTypeObject *Py_SyclDeviceType_;
102 PyTypeObject *PySyclDeviceType_;
103 PyTypeObject *Py_SyclContextType_;
104 PyTypeObject *PySyclContextType_;
105 PyTypeObject *Py_SyclEventType_;
106 PyTypeObject *PySyclEventType_;
107 PyTypeObject *Py_SyclQueueType_;
108 PyTypeObject *PySyclQueueType_;
109 PyTypeObject *Py_MemoryType_;
110 PyTypeObject *PyMemoryUSMDeviceType_;
111 PyTypeObject *PyMemoryUSMSharedType_;
112 PyTypeObject *PyMemoryUSMHostType_;
113 PyTypeObject *PyUSMArrayType_;
114 PyTypeObject *PySyclProgramType_;
115 PyTypeObject *PySyclKernelType_;
// Function-pointer members. Each is set to nullptr by the initializer list
// and then assigned, in the constructor body, the function of the same name
// without the trailing underscore (e.g. SyclDevice_GetDeviceRef).
// The *_Get*Ref_ pointers take a pointer to a Python-level object struct and
// return the matching DPCTL*Ref handle; the *_Make_ pointers take a handle
// and return a pointer to a Python-level object struct.
// NOTE(review): reference-count / ownership conventions of the returned
// values are not visible here - confirm against the dpctl C-API.
117 DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
118 PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
120 DPCTLSyclContextRef (*SyclContext_GetContextRef_)(PySyclContextObject *);
121 PySyclContextObject *(*SyclContext_Make_)(DPCTLSyclContextRef);
123 DPCTLSyclEventRef (*SyclEvent_GetEventRef_)(PySyclEventObject *);
124 PySyclEventObject *(*SyclEvent_Make_)(DPCTLSyclEventRef);
126 DPCTLSyclQueueRef (*SyclQueue_GetQueueRef_)(PySyclQueueObject *);
127 PySyclQueueObject *(*SyclQueue_Make_)(DPCTLSyclQueueRef);
// Accessors on a Py_MemoryObject: USM pointer, opaque pointer, context
// handle, queue handle and size in bytes.
130 DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
131 void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
132 DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
133 DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
134 size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
// NOTE(review): the remaining parameters of Memory_Make_ are not visible in
// this listing (the declaration continues past this line).
135 PyObject *(*Memory_Make_)(DPCTLSyclUSMRef,
141 DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
142 PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef,
const char *);
144 DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(
145 PySyclProgramObject *);
146 PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
// Flag masks for usm array objects. They are initialized to 0 and later
// copied from USM_ARRAY_C_CONTIGUOUS / USM_ARRAY_F_CONTIGUOUS /
// USM_ARRAY_WRITABLE; is_c_contiguous(), is_f_contiguous() and is_writable()
// AND them with the array's flags_ field.
148 int USM_ARRAY_C_CONTIGUOUS_;
149 int USM_ARRAY_F_CONTIGUOUS_;
150 int USM_ARRAY_WRITABLE_;
// Type-number codes. They are initialized to -1 and later copied from the
// UAR_* values; get_elemsize() compares an array's typenum_ against them.
// NOTE(review): the declarator list below ends with a comma in this listing;
// its last entry (the constructor also initializes UAR_HALF_) is not visible.
151 int UAR_BOOL_, UAR_BYTE_, UAR_UBYTE_, UAR_SHORT_, UAR_USHORT_, UAR_INT_,
152 UAR_UINT_, UAR_LONG_, UAR_ULONG_, UAR_LONGLONG_, UAR_ULONGLONG_,
153 UAR_FLOAT_, UAR_DOUBLE_, UAR_CFLOAT_, UAR_CDOUBLE_, UAR_TYPE_SENTINEL_,
// Fixed-width aliases. The constructor maps the 8- and 16-bit ones to
// UAR_BYTE/UAR_UBYTE/UAR_SHORT/UAR_USHORT; the 32- and 64-bit ones involve
// platform_typeid_lookup calls.
155 int UAR_INT8_, UAR_UINT8_, UAR_INT16_, UAR_UINT16_, UAR_INT32_, UAR_UINT32_,
156 UAR_INT64_, UAR_UINT64_;
// Returns true when PyObject_TypeCheck(obj, PySyclDeviceType_) yields a
// non-zero result, i.e. `obj` passes the type check against the cached
// device type; returns false otherwise.
158 bool PySyclDevice_Check_(PyObject *obj)
const
160 return PyObject_TypeCheck(obj, PySyclDeviceType_) != 0;
// Returns true when PyObject_TypeCheck(obj, PySyclContextType_) yields a
// non-zero result, i.e. `obj` passes the type check against the cached
// context type; returns false otherwise.
162 bool PySyclContext_Check_(PyObject *obj)
const
164 return PyObject_TypeCheck(obj, PySyclContextType_) != 0;
// Returns true when PyObject_TypeCheck(obj, PySyclEventType_) yields a
// non-zero result, i.e. `obj` passes the type check against the cached
// event type; returns false otherwise.
166 bool PySyclEvent_Check_(PyObject *obj)
const
168 return PyObject_TypeCheck(obj, PySyclEventType_) != 0;
// Returns true when PyObject_TypeCheck(obj, PySyclQueueType_) yields a
// non-zero result, i.e. `obj` passes the type check against the cached
// queue type; returns false otherwise.
170 bool PySyclQueue_Check_(PyObject *obj)
const
172 return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
// Returns true when PyObject_TypeCheck(obj, PySyclKernelType_) yields a
// non-zero result, i.e. `obj` passes the type check against the cached
// kernel type; returns false otherwise.
174 bool PySyclKernel_Check_(PyObject *obj)
const
176 return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
// Returns true when PyObject_TypeCheck(obj, PySyclProgramType_) yields a
// non-zero result, i.e. `obj` passes the type check against the cached
// program type; returns false otherwise.
178 bool PySyclProgram_Check_(PyObject *obj)
const
180 return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
// Drops this object's hold on the four cached Python objects by resetting
// their shared_ptr holders.
// NOTE(review): the signature of the enclosing function is not visible in
// this listing - presumably a destructor or teardown routine; confirm.
185 as_usm_memory_.reset();
186 default_usm_ndarray_.reset();
187 default_usm_memory_.reset();
188 default_sycl_queue_.reset();
// Returns, by value, the py::object held in default_sycl_queue_.
// The holder is dereferenced without a null check here.
197 py::object default_sycl_queue_pyobj()
199 return *default_sycl_queue_;
// Returns, by value, the py::object held in default_usm_memory_.
// The holder is dereferenced without a null check here.
201 py::object default_usm_memory_pyobj()
203 return *default_usm_memory_;
// Returns, by value, the py::object held in default_usm_ndarray_.
// The holder is dereferenced without a null check here.
205 py::object default_usm_ndarray_pyobj()
207 return *default_usm_ndarray_;
// Returns, by value, the py::object held in as_usm_memory_.
// The holder is dereferenced without a null check here.
209 py::object as_usm_memory_pyobj()
211 return *as_usm_memory_;
// Call operator taking a py::object pointer; presumably this belongs to the
// Deleter functor that the constructor passes to the shared_ptr holders.
// It computes `guard`, which is true only when the Python interpreter is
// initialized and is not finalizing.
// NOTE(review): the statements that consume `guard` and `p`, and the
// #else / #endif of the version switch, are not visible in this listing.
217 void operator()(py::object *p)
const
219 const bool initialized = Py_IsInitialized();
// For PY_VERSION_HEX below 0x30d0000 (i.e. before Python 3.13) the
// underscore-prefixed _Py_IsFinalizing() is called; the other branch calls
// Py_IsFinalizing().
220#if PY_VERSION_HEX < 0x30d0000
221 const bool finalizing = _Py_IsFinalizing();
223 const bool finalizing = Py_IsFinalizing();
225 const bool guard = initialized && !finalizing;
// Holders for Python objects created once in the constructor: a queue
// object, a MemoryUSMHost instance, a usm_ndarray instance and the
// `as_usm_memory` attribute of dpctl.memory. The constructor builds each
// shared_ptr from `new py::object(...)` together with a Deleter{} instance.
233 std::shared_ptr<py::object> default_sycl_queue_;
234 std::shared_ptr<py::object> default_usm_memory_;
235 std::shared_ptr<py::object> default_usm_ndarray_;
236 std::shared_ptr<py::object> as_usm_memory_;
// Member-initializer list: every type pointer and function pointer starts
// as nullptr, the three USM_ARRAY_* flag masks start as 0, every UAR_*
// type-number code starts as -1, and the four shared_ptr holders are
// value-initialized (empty). The real values are filled in afterwards by
// the constructor body.
// NOTE(review): the line carrying the constructor's name and parameter list
// is not visible in this listing.
239 : Py_SyclDeviceType_(
nullptr), PySyclDeviceType_(
nullptr),
240 Py_SyclContextType_(
nullptr), PySyclContextType_(
nullptr),
241 Py_SyclEventType_(
nullptr), PySyclEventType_(
nullptr),
242 Py_SyclQueueType_(
nullptr), PySyclQueueType_(
nullptr),
243 Py_MemoryType_(
nullptr), PyMemoryUSMDeviceType_(
nullptr),
244 PyMemoryUSMSharedType_(
nullptr), PyMemoryUSMHostType_(
nullptr),
245 PyUSMArrayType_(
nullptr), PySyclProgramType_(
nullptr),
246 PySyclKernelType_(
nullptr), SyclDevice_GetDeviceRef_(
nullptr),
247 SyclDevice_Make_(
nullptr), SyclContext_GetContextRef_(
nullptr),
248 SyclContext_Make_(
nullptr), SyclEvent_GetEventRef_(
nullptr),
249 SyclEvent_Make_(
nullptr), SyclQueue_GetQueueRef_(
nullptr),
250 SyclQueue_Make_(
nullptr), Memory_GetUsmPointer_(
nullptr),
251 Memory_GetOpaquePointer_(
nullptr), Memory_GetContextRef_(
nullptr),
252 Memory_GetQueueRef_(
nullptr), Memory_GetNumBytes_(
nullptr),
253 Memory_Make_(
nullptr), SyclKernel_GetKernelRef_(
nullptr),
254 SyclKernel_Make_(
nullptr), SyclProgram_GetKernelBundleRef_(
nullptr),
255 SyclProgram_Make_(
nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
256 USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
257 UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
258 UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
259 UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
260 UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
261 UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
262 UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
263 UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
264 default_usm_memory_{}, default_usm_ndarray_{}, as_usm_memory_{}
// Constructor body, first part: call the import_* routines, then copy the
// global type objects, functions and constants into the members.
// NOTE(review): the import_* functions are defined outside this listing -
// presumably they load each dpctl extension module's C-API and must run
// before the globals referenced below are usable; their return values are
// ignored here. Confirm whether failures need handling.
269 import_dpctl___sycl_device();
270 import_dpctl___sycl_context();
271 import_dpctl___sycl_event();
272 import_dpctl___sycl_queue();
273 import_dpctl__memory___memory();
274 import_dpctl__program___program();
276 import_dpctl_ext__tensor___usmarray();
// Store the addresses of the Python type objects.
279 this->Py_SyclDeviceType_ = &Py_SyclDeviceType;
280 this->PySyclDeviceType_ = &PySyclDeviceType;
281 this->Py_SyclContextType_ = &Py_SyclContextType;
282 this->PySyclContextType_ = &PySyclContextType;
283 this->Py_SyclEventType_ = &Py_SyclEventType;
284 this->PySyclEventType_ = &PySyclEventType;
285 this->Py_SyclQueueType_ = &Py_SyclQueueType;
286 this->PySyclQueueType_ = &PySyclQueueType;
287 this->Py_MemoryType_ = &Py_MemoryType;
288 this->PyMemoryUSMDeviceType_ = &PyMemoryUSMDeviceType;
289 this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
290 this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
291 this->PyUSMArrayType_ = &PyUSMArrayType;
292 this->PySyclProgramType_ = &PySyclProgramType;
293 this->PySyclKernelType_ = &PySyclKernelType;
// Store the C-API functions for device, context, event, queue, memory,
// kernel and program objects.
296 this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
297 this->SyclDevice_Make_ = SyclDevice_Make;
300 this->SyclContext_GetContextRef_ = SyclContext_GetContextRef;
301 this->SyclContext_Make_ = SyclContext_Make;
304 this->SyclEvent_GetEventRef_ = SyclEvent_GetEventRef;
305 this->SyclEvent_Make_ = SyclEvent_Make;
308 this->SyclQueue_GetQueueRef_ = SyclQueue_GetQueueRef;
309 this->SyclQueue_Make_ = SyclQueue_Make;
312 this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
313 this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
314 this->Memory_GetContextRef_ = Memory_GetContextRef;
315 this->Memory_GetQueueRef_ = Memory_GetQueueRef;
316 this->Memory_GetNumBytes_ = Memory_GetNumBytes;
317 this->Memory_Make_ = Memory_Make;
320 this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
321 this->SyclKernel_Make_ = SyclKernel_Make;
322 this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
323 this->SyclProgram_Make_ = SyclProgram_Make;
// Copy the array flag masks and the type-number codes.
326 this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
327 this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
328 this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE;
329 this->UAR_BOOL_ = UAR_BOOL;
330 this->UAR_BYTE_ = UAR_BYTE;
331 this->UAR_UBYTE_ = UAR_UBYTE;
332 this->UAR_SHORT_ = UAR_SHORT;
333 this->UAR_USHORT_ = UAR_USHORT;
334 this->UAR_INT_ = UAR_INT;
335 this->UAR_UINT_ = UAR_UINT;
336 this->UAR_LONG_ = UAR_LONG;
337 this->UAR_ULONG_ = UAR_ULONG;
338 this->UAR_LONGLONG_ = UAR_LONGLONG;
339 this->UAR_ULONGLONG_ = UAR_ULONGLONG;
340 this->UAR_FLOAT_ = UAR_FLOAT;
341 this->UAR_DOUBLE_ = UAR_DOUBLE;
342 this->UAR_CFLOAT_ = UAR_CFLOAT;
343 this->UAR_CDOUBLE_ = UAR_CDOUBLE;
344 this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL;
345 this->UAR_HALF_ = UAR_HALF;
// The 8- and 16-bit fixed-width codes alias the byte and short codes.
348 this->UAR_INT8_ = UAR_BYTE;
349 this->UAR_UINT8_ = UAR_UBYTE;
350 this->UAR_INT16_ = UAR_SHORT;
351 this->UAR_UINT16_ = UAR_USHORT;
// The 32- and 64-bit codes are obtained via platform_typeid_lookup, which is
// given a fixed-width type, three candidate native types and the three
// matching UAR_* codes.
// NOTE(review): the assignment targets of these four calls (and the end of
// the second call's argument list) are not visible in this listing;
// presumably UAR_INT32_, UAR_UINT32_, UAR_INT64_ and UAR_UINT64_ - confirm.
353 platform_typeid_lookup<std::int32_t, long, int, short>(
354 UAR_LONG, UAR_INT, UAR_SHORT);
356 platform_typeid_lookup<std::uint32_t,
unsigned long,
unsigned int,
357 unsigned short>(UAR_ULONG, UAR_UINT,
360 platform_typeid_lookup<std::int64_t, long, long long, int>(
361 UAR_LONG, UAR_LONGLONG, UAR_INT);
363 platform_typeid_lookup<std::uint64_t,
unsigned long,
364 unsigned long long,
unsigned int>(
365 UAR_ULONG, UAR_ULONGLONG, UAR_UINT);
// Constructor body, second part: build the cached default Python objects.
// A Python queue object is made from the address of `q_` (reinterpreted as
// a DPCTLSyclQueueRef); the returned pointer is adopted with
// reinterpret_steal, so no extra reference is taken here.
// NOTE(review): the declaration of `q_` is not visible in this listing, and
// the result of SyclQueue_Make is not checked for null at this point.
370 PySyclQueueObject *py_q_tmp =
371 SyclQueue_Make(
reinterpret_cast<DPCTLSyclQueueRef
>(&q_));
372 const py::object &py_sycl_queue = py::reinterpret_steal<py::object>(
373 reinterpret_cast<PyObject *
>(py_q_tmp));
// Each holder owns a heap-allocated copy of the py::object and is destroyed
// through a Deleter instance.
375 default_sycl_queue_ = std::shared_ptr<py::object>(
376 new py::object(py_sycl_queue), Deleter{});
// Cache the `as_usm_memory` attribute of the dpctl.memory module.
378 py::module_ mod_memory = py::module_::import(
"dpctl.memory");
379 const py::object &py_as_usm_memory = mod_memory.attr(
"as_usm_memory");
380 as_usm_memory_ = std::shared_ptr<py::object>(
381 new py::object{py_as_usm_memory}, Deleter{});
// Default memory object: MemoryUSMHost called with positional argument 1
// and queue=py_sycl_queue.
383 auto mem_kl = mod_memory.attr(
"MemoryUSMHost");
384 const py::object &py_default_usm_memory =
385 mem_kl(1, py::arg(
"queue") = py_sycl_queue);
386 default_usm_memory_ = std::shared_ptr<py::object>(
387 new py::object{py_default_usm_memory}, Deleter{});
// Default array object: usm_ndarray called with an empty tuple,
// dtype="u1" and buffer set to the default memory object created above.
391 py::module_ mod_usmarray =
392 py::module_::import(
"dpctl_ext.tensor._usmarray");
393 auto tensor_kl = mod_usmarray.attr(
"usm_ndarray");
395 const py::object &py_default_usm_ndarray =
396 tensor_kl(py::tuple(), py::arg(
"dtype") = py::str(
"u1"),
397 py::arg(
"buffer") = py_default_usm_memory);
399 default_usm_ndarray_ = std::shared_ptr<py::object>(
400 new py::object{py_default_usm_ndarray}, Deleter{});
// pybind11 boilerplate for the usm_ndarray wrapper, with py::object as its
// parent. The check function accepts a PyObject when PyObject_TypeCheck
// against the cached PyUSMArrayType_ returns non-zero.
// NOTE(review): the closing of the lambda and of the macro invocation is
// not visible in this listing.
931 PYBIND11_OBJECT(
usm_ndarray, py::object, [](PyObject *o) ->
bool {
932 return PyObject_TypeCheck(
933 o, ::dpctl::detail::dpctl_capi::get().PyUSMArrayType_) != 0;
// Fragment: the cached default usm_ndarray object is fetched from the
// dpctl_capi singleton and a py::error_already_set is thrown on some
// failure condition.
// NOTE(review): the surrounding declaration and the condition guarding the
// throw are not visible in this listing - presumably this is the default
// constructor of usm_ndarray; confirm.
938 ::dpctl::detail::dpctl_capi::get().default_usm_ndarray_pyobj(),
942 throw py::error_already_set();
// Returns the data_ field of the underlying PyUSMArrayObject as a char
// pointer.
945 char *get_data()
const
947 PyUSMArrayObject *raw_ar = usm_array_ptr();
948 return raw_ar->data_;
// Typed variant: returns the pointer from get_data() reinterpreted as T*.
// No check is made here that T matches the array's element type.
// NOTE(review): the function's signature line is not visible in this
// listing.
951 template <
typename T>
954 return reinterpret_cast<T *
>(get_data());
// Fragment of an accessor: fetches the underlying PyUSMArrayObject.
// NOTE(review): the signature and the return statement are not visible in
// this listing - presumably this is get_ndim(), which other methods call.
959 PyUSMArrayObject *raw_ar = usm_array_ptr();
// Returns the shape_ field of the underlying PyUSMArrayObject as a pointer
// to const py::ssize_t; the pointer refers to the array object's own field
// (nothing is copied).
963 const py::ssize_t *get_shape_raw()
const
965 PyUSMArrayObject *raw_ar = usm_array_ptr();
966 return raw_ar->shape_;
// Builds a std::vector from the first get_ndim() entries of the raw shape
// pointer.
// NOTE(review): the return statement is not visible in this listing.
969 std::vector<py::ssize_t> get_shape_vector()
const
971 auto raw_sh = get_shape_raw();
972 auto nd = get_ndim();
974 std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
// Takes a dimension index `i` and reads from the raw shape pointer.
// NOTE(review): the return statement is not visible in this listing -
// presumably it returns entry `i`; whether `i` is bounds-checked cannot be
// determined from here.
978 py::ssize_t get_shape(
int i)
const
980 auto shape_ptr = get_shape_raw();
// Returns the strides_ field of the underlying PyUSMArrayObject. Callers in
// this file treat a nullptr result as a valid state and handle it
// separately (see get_strides_vector and get_minmax_offsets).
984 const py::ssize_t *get_strides_raw()
const
986 PyUSMArrayObject *raw_ar = usm_array_ptr();
987 return raw_ar->strides_;
// Returns the strides as a vector. When the raw strides pointer is null,
// the strides are derived from the shape via c_contiguous_strides or
// f_contiguous_strides according to the contiguity flags, and a
// std::runtime_error is thrown when neither applies. Otherwise the first
// `nd` raw stride entries are copied into a vector.
// NOTE(review): the `if` testing is_c_contig, the closing braces and the
// final return are not visible in this listing.
990 std::vector<py::ssize_t> get_strides_vector()
const
992 auto raw_st = get_strides_raw();
993 auto nd = get_ndim();
995 if (raw_st ==
nullptr) {
996 auto is_c_contig = is_c_contiguous();
997 auto is_f_contig = is_f_contiguous();
998 auto raw_sh = get_shape_raw();
1000 const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
1001 return contig_strides;
1003 else if (is_f_contig) {
1004 const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
1005 return contig_strides;
1008 throw std::runtime_error(
"Invalid array encountered when "
1009 "building strides");
1013 std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
// Computes an element count from the array's nd_ and shape_ fields,
// starting from 1 and iterating over every dimension; asserts the result is
// non-negative.
// NOTE(review): the loop body and the return statement are not visible in
// this listing - presumably nelems is multiplied by shape[i]; confirm.
1018 py::ssize_t get_size()
const
1020 PyUSMArrayObject *raw_ar = usm_array_ptr();
1022 int ndim = raw_ar->nd_;
1023 const py::ssize_t *shape = raw_ar->shape_;
1025 py::ssize_t nelems = 1;
1026 for (
int i = 0; i < ndim; ++i) {
1030 assert(nelems >= 0);
// Returns the pair (offset_min, offset_max), both starting at 0 and
// accumulated over the dimensions.
// With no strides array, only offset_max grows, by stride * (shape[i] - 1)
// per dimension. With strides, each dimension contributes
// delta = strides[i] * (shape[i] - 1): added to offset_max when the stride
// is positive, and to offset_min in the other visible branch.
// NOTE(review): the update of `stride` in the first loop and the else /
// closing-brace lines are not visible in this listing. The unit of the
// offsets (elements vs bytes) is not established here.
1034 std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets()
const
1036 PyUSMArrayObject *raw_ar = usm_array_ptr();
1038 int nd = raw_ar->nd_;
1039 const py::ssize_t *shape = raw_ar->shape_;
1040 const py::ssize_t *strides = raw_ar->strides_;
1042 py::ssize_t offset_min = 0;
1043 py::ssize_t offset_max = 0;
1044 if (strides ==
nullptr) {
1045 py::ssize_t stride(1);
1046 for (
int i = 0; i < nd; ++i) {
1047 offset_max += stride * (shape[i] - 1);
1052 for (
int i = 0; i < nd; ++i) {
1053 py::ssize_t delta = strides[i] * (shape[i] - 1);
1054 if (strides[i] > 0) {
1055 offset_max += delta;
1058 offset_min += delta;
1062 return std::make_pair(offset_min, offset_max);
// Returns a copy of the sycl::queue reached through the array's base_
// object: base_ is reinterpreted as a Py_MemoryObject, its queue handle is
// obtained via Memory_GetQueueRef_, and the handle is reinterpreted as a
// sycl::queue pointer and dereferenced.
// NOTE(review): unlike is_managed_by_smart_ptr(), base_ is not type-checked
// against Py_MemoryType_ before the cast here, and QRef is not checked for
// null before it is dereferenced.
1065 sycl::queue get_queue()
const
1067 PyUSMArrayObject *raw_ar = usm_array_ptr();
1068 Py_MemoryObject *mem_obj =
1069 reinterpret_cast<Py_MemoryObject *
>(raw_ar->base_);
1071 auto const &api = ::dpctl::detail::dpctl_capi::get();
1072 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1073 return *(
reinterpret_cast<sycl::queue *
>(QRef));
// Returns the device of the queue reached through the array's base_ object:
// base_ is reinterpreted as a Py_MemoryObject, its queue handle is obtained
// via Memory_GetQueueRef_, and get_device() is called on that handle
// reinterpreted as a sycl::queue pointer.
// NOTE(review): base_ is not type-checked against Py_MemoryType_ before the
// cast here, and QRef is not checked for null before use.
1076 sycl::device get_device()
const
1078 PyUSMArrayObject *raw_ar = usm_array_ptr();
1079 Py_MemoryObject *mem_obj =
1080 reinterpret_cast<Py_MemoryObject *
>(raw_ar->base_);
1082 auto const &api = ::dpctl::detail::dpctl_capi::get();
1083 DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
1084 return reinterpret_cast<sycl::queue *
>(QRef)->get_device();
// Returns the typenum_ field of the underlying PyUSMArrayObject; the value
// is compared against the dpctl_capi UAR_*_ codes in get_elemsize().
1087 int get_typenum()
const
1089 PyUSMArrayObject *raw_ar = usm_array_ptr();
1090 return raw_ar->typenum_;
// Returns the flags_ field of the underlying PyUSMArrayObject; the
// is_c_contiguous / is_f_contiguous / is_writable predicates test bits of
// this value.
1093 int get_flags()
const
1095 PyUSMArrayObject *raw_ar = usm_array_ptr();
1096 return raw_ar->flags_;
// Maps the array's type number to an int result by comparing it, in order,
// against the UAR_*_ codes cached in dpctl_capi.
// NOTE(review): only the results for UAR_LONG_ (sizeof(long)) and UAR_ULONG_
// (sizeof(unsigned long)) are visible in this listing; the other return
// statements and the fall-through result for an unmatched type number are
// missing - presumably each branch returns the element size in bytes;
// confirm.
1099 int get_elemsize()
const
1101 int typenum = get_typenum();
1102 auto const &api = ::dpctl::detail::dpctl_capi::get();
1105 if (typenum == api.UAR_BOOL_)
1107 if (typenum == api.UAR_BYTE_)
1109 if (typenum == api.UAR_UBYTE_)
1111 if (typenum == api.UAR_SHORT_)
1113 if (typenum == api.UAR_USHORT_)
1115 if (typenum == api.UAR_INT_)
1117 if (typenum == api.UAR_UINT_)
1119 if (typenum == api.UAR_LONG_)
1120 return sizeof(long);
1121 if (typenum == api.UAR_ULONG_)
1122 return sizeof(
unsigned long);
1123 if (typenum == api.UAR_LONGLONG_)
1125 if (typenum == api.UAR_ULONGLONG_)
1127 if (typenum == api.UAR_FLOAT_)
1129 if (typenum == api.UAR_DOUBLE_)
1131 if (typenum == api.UAR_CFLOAT_)
1133 if (typenum == api.UAR_CDOUBLE_)
1135 if (typenum == api.UAR_HALF_)
// Returns true when the array's flags have any bit of the cached
// USM_ARRAY_C_CONTIGUOUS_ mask set.
1141 bool is_c_contiguous()
const
1143 int flags = get_flags();
1144 auto const &api = ::dpctl::detail::dpctl_capi::get();
1145 return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
// Returns true when the array's flags have any bit of the cached
// USM_ARRAY_F_CONTIGUOUS_ mask set.
1148 bool is_f_contiguous()
const
1150 int flags = get_flags();
1151 auto const &api = ::dpctl::detail::dpctl_capi::get();
1152 return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
// Returns true when the array's flags have any bit of the cached
// USM_ARRAY_WRITABLE_ mask set.
1155 bool is_writable()
const
1157 int flags = get_flags();
1158 auto const &api = ::dpctl::detail::dpctl_capi::get();
1159 return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
// Returns the array's base_ object wrapped in a py::object. The reference
// count is incremented first (Py_XINCREF tolerates a null pointer), so the
// reinterpret_steal below adopts that new reference and the array keeps its
// own.
// NOTE(review): the signature of this accessor is not visible in this
// listing.
1165 PyUSMArrayObject *raw_ar = usm_array_ptr();
1167 PyObject *usm_data = raw_ar->base_;
1168 Py_XINCREF(usm_data);
1171 return py::reinterpret_steal<py::object>(usm_data);
// Reports whether the Memory object behind this array exposes a non-null
// opaque pointer: base_ is first type-checked against Py_MemoryType_, then
// Memory_GetOpaquePointer_ is queried and its result converted to bool.
// NOTE(review): the statement executed when the type check fails is not
// visible in this listing - presumably `return false;`.
1174 bool is_managed_by_smart_ptr()
const
1176 PyUSMArrayObject *raw_ar = usm_array_ptr();
1177 PyObject *usm_data = raw_ar->base_;
1179 auto const &api = ::dpctl::detail::dpctl_capi::get();
1180 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1184 Py_MemoryObject *mem_obj =
1185 reinterpret_cast<Py_MemoryObject *
>(usm_data);
1186 const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1188 return bool(opaque_ptr);
// Returns a const reference to a std::shared_ptr<void> reached through the
// opaque pointer of the Memory object behind this array.
// Throws std::runtime_error when base_ is not an instance of
// Py_MemoryType_; a second std::runtime_error with a different message is
// also thrown further down.
// NOTE(review): the condition guarding the second throw, the variable that
// receives the reinterpret_cast result and the return statement are not
// visible in this listing - presumably the second throw covers a null
// opaque pointer. The returned reference's validity would depend on the
// lifetime of the Memory object; confirm.
1191 const std::shared_ptr<void> &get_smart_ptr_owner()
const
1193 PyUSMArrayObject *raw_ar = usm_array_ptr();
1194 PyObject *usm_data = raw_ar->base_;
1196 auto const &api = ::dpctl::detail::dpctl_capi::get();
1198 if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1199 throw std::runtime_error(
1200 "usm_ndarray object does not have Memory object "
1201 "managing lifetime of USM allocation");
1204 Py_MemoryObject *mem_obj =
1205 reinterpret_cast<Py_MemoryObject *
>(usm_data);
1206 void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1210 reinterpret_cast<std::shared_ptr<void> *
>(opaque_ptr);
1214 throw std::runtime_error(
1215 "Memory object underlying usm_ndarray does not have "
1216 "smart pointer managing lifetime of USM allocation");
// Returns m_ptr reinterpreted as a PyUSMArrayObject pointer. No type check
// is performed in this helper.
1221 PyUSMArrayObject *usm_array_ptr()
const
1223 return reinterpret_cast<PyUSMArrayObject *
>(m_ptr);