#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>

#include <dpctl4pybind11.hpp>
#include <pybind11/pybind11.h>

#include "ext/common.hpp"

#include "utils/memory_overlap.hpp"
#include "utils/type_dispatch.hpp"

#ifndef __INTEL_MKL_2023_2_0_VERSION_REQUIRED
#define __INTEL_MKL_2023_2_0_VERSION_REQUIRED 20230002L
#endif

static_assert(INTEL_MKL_VERSION >= __INTEL_MKL_2023_2_0_VERSION_REQUIRED,
              "OneMKL does not meet minimum version requirement");
namespace ext_ns = ext::common;
namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

namespace dpnp::extensions::vm::py_internal
{
template <typename output_typesT, typename contig_dispatchT>
bool need_to_call_unary_ufunc(sycl::queue &exec_q,
                              const dpctl::tensor::usm_ndarray &src,
                              const dpctl::tensor::usm_ndarray &dst,
                              const output_typesT &output_type_vec,
                              const contig_dispatchT &contig_dispatch_vector)
{
    // check type_nums
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
    // the destination type must match what the ufunc produces
    int func_output_typeid = output_type_vec[src_typeid];
    if (dst_typeid != func_output_typeid) {
        return false;
    }

    // don't use OneMKL on devices without fp64 support
    if (!exec_q.get_device().has(sycl::aspect::fp64)) {
        return false;
    }

    // execution queue and allocation queues must be compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
        return false;
    }
    // dimensions must be the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src.get_ndim()) {
        return false;
    }
    else if (dst_nd == 0) {
        // don't call OneMKL for 0d arrays
        return false;
    }
    // shapes must be the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        return false;
    }

    // nothing to do for empty arrays
    if (src_nelems == 0) {
        return false;
    }
    // destination must be ample enough to accommodate all elements
    auto dst_offsets = dst.get_minmax_offsets();
    size_t range =
        static_cast<size_t>(dst_offsets.second - dst_offsets.first);
    if (range + 1 < src_nelems) {
        return false;
    }
    // source and destination must not overlap in memory
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src, dst)) {
        return false;
    }
    // support only C-contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_dst_c_contig = dst.is_c_contiguous();

    bool all_c_contig = (is_src_c_contig && is_dst_c_contig);
    if (!all_c_contig) {
        return false;
    }

    // no MKL implementation is registered for this type
    if (contig_dispatch_vector[src_typeid] == nullptr) {
        return false;
    }
    return true;
}
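
// Illustrative call-site sketch (not part of this header; the names mkl_fn
// and generic_unary_impl, and the exact impl-function signature, are
// hypothetical): a unary VM binding consults need_to_call_unary_ufunc()
// and falls back to a generic element-wise implementation on false.
//
//     if (need_to_call_unary_ufunc(exec_q, src, dst, output_typeid_vector,
//                                  contig_dispatch_vector)) {
//         auto mkl_fn = contig_dispatch_vector[src_typeid];
//         sycl::event ev = mkl_fn(exec_q, src_nelems, src.get_data(),
//                                 dst.get_data(), depends);
//     }
//     else {
//         generic_unary_impl(exec_q, src, dst, depends);
//     }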
template <typename output_typesT, typename contig_dispatchT>
bool need_to_call_binary_ufunc(sycl::queue &exec_q,
                               const dpctl::tensor::usm_ndarray &src1,
                               const dpctl::tensor::usm_ndarray &src2,
                               const dpctl::tensor::usm_ndarray &dst,
                               const output_typesT &output_type_table,
                               const contig_dispatchT &contig_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
    // the destination type must match what the ufunc produces
    int output_typeid = output_type_table[src1_typeid][src2_typeid];
    if (output_typeid != dst_typeid) {
        return false;
    }

    // OneMKL VM functions support only same-typed inputs
    if (src1_typeid != src2_typeid) {
        return false;
    }

    // don't use OneMKL on devices without fp64 support
    if (!exec_q.get_device().has(sycl::aspect::fp64)) {
        return false;
    }

    // execution queue and allocation queues must be compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        return false;
    }
    // dimensions must be the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        return false;
    }
    else if (dst_nd == 0) {
        // don't call OneMKL for 0d arrays
        return false;
    }
    // shapes must be the same
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        return false;
    }

    // nothing to do for empty arrays
    if (src_nelems == 0) {
        return false;
    }
    // destination must be ample enough to accommodate all elements
    auto dst_offsets = dst.get_minmax_offsets();
    size_t range =
        static_cast<size_t>(dst_offsets.second - dst_offsets.first);
    if (range + 1 < src_nelems) {
        return false;
    }
    // neither source may overlap the destination in memory
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src1, dst) || overlap(src2, dst)) {
        return false;
    }
    // support only C-contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_dst_c_contig = dst.is_c_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    if (!all_c_contig) {
        return false;
    }

    // no MKL implementation is registered for this pair of types
    if (contig_dispatch_table[src1_typeid][src2_typeid] == nullptr) {
        return false;
    }
    return true;
}
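
// The binary analogue differs from the unary case mainly in the lookup:
// both the output-type map and the dispatch structure are 2D, indexed by
// the type ids of both inputs (a sketch under the same hypothetical names
// as the unary example above):
//
//     if (need_to_call_binary_ufunc(exec_q, src1, src2, dst,
//                                   output_typeid_vector,
//                                   contig_dispatch_vector)) {
//         auto mkl_fn = contig_dispatch_vector[src1_typeid][src2_typeid];
//         sycl::event ev = mkl_fn(exec_q, src_nelems, src1.get_data(),
//                                 src2.get_data(), dst.get_data(), depends);
//     }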
#define MACRO_POPULATE_DISPATCH_VECTORS(__name__)                             \
    template <typename fnT, typename T>                                       \
    struct ContigFactory                                                      \
    {                                                                         \
        fnT get()                                                             \
        {                                                                     \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
                                         void>) {                             \
                return nullptr;                                               \
            }                                                                 \
            else {                                                            \
                return __name__##_contig_impl<T>;                             \
            }                                                                 \
        }                                                                     \
    };                                                                        \
    template <typename fnT, typename T>                                      \
    struct TypeMapFactory                                                     \
    {                                                                         \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get()           \
        {                                                                     \
            using rT = typename OutputType<T>::value_type;                    \
            return td_ns::GetTypeid<rT>{}.get();                              \
        }                                                                     \
    };                                                                        \
    static void populate_dispatch_vectors(void)                              \
    {                                                                         \
        ext_ns::init_dispatch_vector<int, TypeMapFactory>(                    \
            output_typeid_vector);                                            \
        ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t,              \
                                     ContigFactory>(contig_dispatch_vector);  \
    };
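
// Sketch of the intended expansion site (the name "sqr" and the
// definitions it presumes are illustrative, following the pattern of the
// per-function translation units in this extension):
//
//     namespace impl
//     {
//     // ...definitions of OutputType<T>, sqr_contig_impl<T>,
//     // unary_contig_impl_fn_ptr_t, output_typeid_vector and
//     // contig_dispatch_vector...
//     MACRO_POPULATE_DISPATCH_VECTORS(sqr);
//     } // namespace impl
//
//     // during module initialization:
//     impl::populate_dispatch_vectors();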
#define MACRO_POPULATE_DISPATCH_TABLES(__name__)                              \
    template <typename fnT, typename T1, typename T2>                         \
    struct ContigFactory                                                      \
    {                                                                         \
        fnT get()                                                             \
        {                                                                     \
            if constexpr (std::is_same_v<                                     \
                              typename OutputType<T1, T2>::value_type, void>) \
            {                                                                 \
                return nullptr;                                               \
            }                                                                 \
            else {                                                            \
                return __name__##_contig_impl<T1, T2>;                        \
            }                                                                 \
        }                                                                     \
    };                                                                        \
    template <typename fnT, typename T1, typename T2>                        \
    struct TypeMapFactory                                                     \
    {                                                                         \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get()           \
        {                                                                     \
            using rT = typename OutputType<T1, T2>::value_type;               \
            return td_ns::GetTypeid<rT>{}.get();                              \
        }                                                                     \
    };                                                                        \
    static void populate_dispatch_tables(void)                                \
    {                                                                         \
        ext_ns::init_dispatch_table<int, TypeMapFactory>(                     \
            output_typeid_vector);                                            \
        ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t,              \
                                    ContigFactory>(contig_dispatch_vector);   \
    };
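
// And likewise for the binary macro (the name "add" is illustrative):
//
//     namespace impl
//     {
//     MACRO_POPULATE_DISPATCH_TABLES(add);
//     } // namespace impl
//
//     impl::populate_dispatch_tables();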