28#include <oneapi/mkl.hpp>
29#include <sycl/sycl.hpp>
31#include <dpctl4pybind11.hpp>
32#include <pybind11/pybind11.h>
35#include "utils/memory_overlap.hpp"
36#include "utils/type_dispatch.hpp"
38#include "dpnp_utils.hpp"
40static_assert(INTEL_MKL_VERSION >= __INTEL_MKL_2023_2_0_VERSION_REQUIRED,
41 "OneMKL does not meet minimum version requirement");
43namespace py = pybind11;
44namespace td_ns = dpctl::tensor::type_dispatch;
46namespace dpnp::extensions::vm::py_internal
48template <
typename output_typesT,
typename contig_dispatchT>
49bool need_to_call_unary_ufunc(sycl::queue &exec_q,
50 const dpctl::tensor::usm_ndarray &src,
51 const dpctl::tensor::usm_ndarray &dst,
52 const output_typesT &output_type_vec,
53 const contig_dispatchT &contig_dispatch_vector)
56 int src_typenum = src.get_typenum();
57 int dst_typenum = dst.get_typenum();
59 auto array_types = td_ns::usm_ndarray_types();
60 int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
61 int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
64 int func_output_typeid = output_type_vec[src_typeid];
65 if (dst_typeid != func_output_typeid) {
70 if (!exec_q.get_device().has(sycl::aspect::fp64)) {
75 if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
80 int dst_nd = dst.get_ndim();
81 if (dst_nd != src.get_ndim()) {
84 else if (dst_nd == 0) {
90 const py::ssize_t *src_shape = src.get_shape_raw();
91 const py::ssize_t *dst_shape = dst.get_shape_raw();
92 bool shapes_equal(
true);
95 for (
int i = 0; i < dst_nd; ++i) {
96 src_nelems *=
static_cast<size_t>(src_shape[i]);
97 shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
104 if (src_nelems == 0) {
109 auto dst_offsets = dst.get_minmax_offsets();
113 static_cast<size_t>(dst_offsets.second - dst_offsets.first);
114 if (range + 1 < src_nelems) {
120 auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
121 if (overlap(src, dst)) {
126 bool is_src_c_contig = src.is_c_contiguous();
127 bool is_dst_c_contig = dst.is_c_contiguous();
129 bool all_c_contig = (is_src_c_contig && is_dst_c_contig);
135 if (contig_dispatch_vector[src_typeid] ==
nullptr) {
141template <
typename output_typesT,
typename contig_dispatchT>
142bool need_to_call_binary_ufunc(sycl::queue &exec_q,
143 const dpctl::tensor::usm_ndarray &src1,
144 const dpctl::tensor::usm_ndarray &src2,
145 const dpctl::tensor::usm_ndarray &dst,
146 const output_typesT &output_type_table,
147 const contig_dispatchT &contig_dispatch_table)
150 int src1_typenum = src1.get_typenum();
151 int src2_typenum = src2.get_typenum();
152 int dst_typenum = dst.get_typenum();
154 auto array_types = td_ns::usm_ndarray_types();
155 int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
156 int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
157 int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
160 int output_typeid = output_type_table[src1_typeid][src2_typeid];
161 if (output_typeid != dst_typeid) {
166 if (src1_typeid != src2_typeid) {
171 if (!exec_q.get_device().has(sycl::aspect::fp64)) {
176 if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
181 int dst_nd = dst.get_ndim();
182 if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
185 else if (dst_nd == 0) {
191 const py::ssize_t *src1_shape = src1.get_shape_raw();
192 const py::ssize_t *src2_shape = src2.get_shape_raw();
193 const py::ssize_t *dst_shape = dst.get_shape_raw();
194 bool shapes_equal(
true);
195 size_t src_nelems(1);
197 for (
int i = 0; i < dst_nd; ++i) {
198 src_nelems *=
static_cast<size_t>(src1_shape[i]);
199 shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
200 src2_shape[i] == dst_shape[i]);
207 if (src_nelems == 0) {
212 auto dst_offsets = dst.get_minmax_offsets();
216 static_cast<size_t>(dst_offsets.second - dst_offsets.first);
217 if (range + 1 < src_nelems) {
223 auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
224 if (overlap(src1, dst) || overlap(src2, dst)) {
229 bool is_src1_c_contig = src1.is_c_contiguous();
230 bool is_src2_c_contig = src2.is_c_contiguous();
231 bool is_dst_c_contig = dst.is_c_contiguous();
234 (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
240 if (contig_dispatch_table[src1_typeid] ==
nullptr) {
/*
 * MACRO_POPULATE_DISPATCH_VECTORS(NAME) defines, at the point of expansion:
 *   - ContigFactory<fnT, T>: get() returns nullptr when
 *     OutputType<T>::value_type is void, otherwise NAME##_contig_impl<T>;
 *   - TypeMapFactory<fnT, T>: get() returns the type id of
 *     OutputType<T>::value_type (only usable with fnT = int);
 *   - populate_dispatch_vectors(): fills output_typeid_vector and
 *     contig_dispatch_vector through py_internal::init_ufunc_dispatch_vector.
 * The expansion site must already declare OutputType, NAME##_contig_impl,
 * unary_contig_impl_fn_ptr_t, output_typeid_vector and
 * contig_dispatch_vector.
 */
#define MACRO_POPULATE_DISPATCH_VECTORS(NAME)                                  \
    template <typename fnT, typename T>                                        \
    struct ContigFactory                                                       \
    {                                                                          \
        fnT get()                                                              \
        {                                                                      \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type,   \
                                         void>) {                              \
                return nullptr;                                                \
            }                                                                  \
            else {                                                             \
                return NAME##_contig_impl<T>;                                  \
            }                                                                  \
        }                                                                      \
    };                                                                         \
                                                                               \
    template <typename fnT, typename T>                                        \
    struct TypeMapFactory                                                      \
    {                                                                          \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get()             \
        {                                                                      \
            using rT = typename OutputType<T>::value_type;                     \
            return td_ns::GetTypeid<rT>{}.get();                               \
        }                                                                      \
    };                                                                         \
                                                                               \
    static void populate_dispatch_vectors()                                    \
    {                                                                          \
        py_internal::init_ufunc_dispatch_vector<int, TypeMapFactory>(          \
            output_typeid_vector);                                             \
        py_internal::init_ufunc_dispatch_vector<unary_contig_impl_fn_ptr_t,    \
                                                ContigFactory>(                \
            contig_dispatch_vector);                                           \
    }
/*
 * MACRO_POPULATE_DISPATCH_TABLES(NAME) defines, at the point of expansion:
 *   - ContigFactory<fnT, T1, T2>: get() returns nullptr when
 *     OutputType<T1, T2>::value_type is void, otherwise
 *     NAME##_contig_impl<T1, T2>;
 *   - TypeMapFactory<fnT, T1, T2>: get() returns the type id of
 *     OutputType<T1, T2>::value_type (only usable with fnT = int);
 *   - populate_dispatch_tables(): fills the 2-D output_typeid_vector and
 *     contig_dispatch_vector through py_internal::init_ufunc_dispatch_table.
 * Despite their "_vector" names, both targets must be 2-D tables. The
 * expansion site must already declare OutputType, NAME##_contig_impl,
 * binary_contig_impl_fn_ptr_t, output_typeid_vector and
 * contig_dispatch_vector.
 */
#define MACRO_POPULATE_DISPATCH_TABLES(NAME)                                   \
    template <typename fnT, typename T1, typename T2>                          \
    struct ContigFactory                                                       \
    {                                                                          \
        fnT get()                                                              \
        {                                                                      \
            if constexpr (std::is_same_v<                                      \
                              typename OutputType<T1, T2>::value_type, void>)  \
            {                                                                  \
                return nullptr;                                                \
            }                                                                  \
            else {                                                             \
                return NAME##_contig_impl<T1, T2>;                             \
            }                                                                  \
        }                                                                      \
    };                                                                         \
                                                                               \
    template <typename fnT, typename T1, typename T2>                          \
    struct TypeMapFactory                                                      \
    {                                                                          \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get()             \
        {                                                                      \
            using rT = typename OutputType<T1, T2>::value_type;                \
            return td_ns::GetTypeid<rT>{}.get();                               \
        }                                                                      \
    };                                                                         \
                                                                               \
    static void populate_dispatch_tables()                                     \
    {                                                                          \
        py_internal::init_ufunc_dispatch_table<int, TypeMapFactory>(           \
            output_typeid_vector);                                             \
        py_internal::init_ufunc_dispatch_table<binary_contig_impl_fn_ptr_t,    \
                                               ContigFactory>(                 \
            contig_dispatch_vector);                                           \
    }
327template <
typename dispatchT,
328 template <
typename fnT,
typename T>
330 int _num_types = td_ns::num_types>
331void init_ufunc_dispatch_vector(dispatchT dispatch_vector[])
333 td_ns::DispatchVectorBuilder<dispatchT, factoryT, _num_types> dvb;
334 dvb.populate_dispatch_vector(dispatch_vector);
337template <
typename dispatchT,
338 template <
typename fnT,
typename D,
typename S>
340 int _num_types = td_ns::num_types>
341void init_ufunc_dispatch_table(dispatchT dispatch_table[][_num_types])
343 td_ns::DispatchTableBuilder<dispatchT, factoryT, _num_types> dtb;
344 dtb.populate_dispatch_table(dispatch_table);