36#include <oneapi/mkl.hpp>
37#include <sycl/sycl.hpp>
39#include <dpctl4pybind11.hpp>
40#include <pybind11/pybind11.h>
43#include "ext/common.hpp"
46#include "utils/memory_overlap.hpp"
47#include "utils/type_dispatch.hpp"
55#ifndef __INTEL_MKL_2023_2_0_VERSION_REQUIRED
56#define __INTEL_MKL_2023_2_0_VERSION_REQUIRED 20230002L
59static_assert(INTEL_MKL_VERSION >= __INTEL_MKL_2023_2_0_VERSION_REQUIRED,
60 "OneMKL does not meet minimum version requirement");
62namespace ext_ns = ext::common;
63namespace py = pybind11;
64namespace td_ns = dpctl::tensor::type_dispatch;
66namespace dpnp::extensions::vm::py_internal
// Decides whether the OneMKL VM-based kernel can service a unary
// elementwise call (src -> dst) on the given queue; checks type mapping,
// device fp64 support, queue compatibility, equal shapes, destination
// capacity, memory overlap, contiguity, and kernel availability.
// NOTE(review): several interior lines (early-return bodies and braces)
// are not visible in this view; only comments were added here.
template <typename output_typesT, typename contig_dispatchT>
bool need_to_call_unary_ufunc(sycl::queue &exec_q,
                              const dpctl::tensor::usm_ndarray &src,
                              const dpctl::tensor::usm_ndarray &dst,
                              const output_typesT &output_type_vec,
                              const contig_dispatchT &contig_dispatch_vector)
    // map NumPy-style typenums to dispatch lookup ids
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    // destination type must be exactly the output type registered
    // for the source type
    int func_output_typeid = output_type_vec[src_typeid];
    if (dst_typeid != func_output_typeid) {

    // device must support double precision
    // (presumably a VM backend requirement -- TODO confirm)
    if (!exec_q.get_device().has(sycl::aspect::fp64)) {

    // execution queue must be compatible with both arrays' queues
    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {

    // dimensionality must match and be non-zero
    int dst_nd = dst.get_ndim();
    if (dst_nd != src.get_ndim()) {
    else if (dst_nd == 0) {

    // shapes must be equal; element count is accumulated along the way
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);

    // nothing to do for empty arrays
    if (src_nelems == 0) {

    // destination must be ample enough to accommodate all elements
    // (next line continues an assignment whose start is not visible here)
    auto dst_offsets = dst.get_minmax_offsets();
        static_cast<size_t>(dst_offsets.second - dst_offsets.first);
    if (range + 1 < src_nelems) {

    // source and destination must not overlap in memory
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src, dst)) {

    // both arrays must be C-contiguous
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_dst_c_contig = dst.is_c_contiguous();

    bool all_c_contig = (is_src_c_contig && is_dst_c_contig);

    // a contiguous kernel must actually be registered for this type
    if (contig_dispatch_vector[src_typeid] == nullptr) {
// Decides whether the OneMKL VM-based kernel can service a unary
// elementwise call producing two outputs (src -> dst1, dst2); mirrors the
// single-output variant but validates both destinations and all pairwise
// overlaps.
// NOTE(review): several interior lines (the queue parameter, early-return
// bodies and braces) are not visible in this view; only comments added.
template <typename output_typesT, typename contig_dispatchT>
bool need_to_call_unary_two_outputs_ufunc(
    const dpctl::tensor::usm_ndarray &src,
    const dpctl::tensor::usm_ndarray &dst1,
    const dpctl::tensor::usm_ndarray &dst2,
    const output_typesT &output_type_vec,
    const contig_dispatchT &contig_dispatch_vector)
    // map NumPy-style typenums to dispatch lookup ids
    int src_typenum = src.get_typenum();
    int dst1_typenum = dst1.get_typenum();
    int dst2_typenum = dst2.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);

    // both destination types must match the (first, second) output pair
    // registered for the source type
    std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];

    if (dst1_typeid != func_output_typeids.first ||
        dst2_typeid != func_output_typeids.second) {

    // device must support double precision
    // (presumably a VM backend requirement -- TODO confirm)
    if (!exec_q.get_device().has(sycl::aspect::fp64)) {

    // execution queue must be compatible with all three arrays' queues
    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst1, dst2})) {

    // dimensionalities must agree and be non-zero
    int src_nd = src.get_ndim();
    int dst1_nd = dst1.get_ndim();
    int dst2_nd = dst2.get_ndim();
    if (src_nd != dst1_nd || src_nd != dst2_nd) {
    else if (dst1_nd == 0 || dst2_nd == 0) {

    // all three shapes must be equal; element count is accumulated too
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
                       (src_shape[i] == dst2_shape[i]);

    // nothing to do for empty arrays
    if (src_nelems == 0) {

    // both destinations must be ample enough to accommodate all elements
    // (the two static_cast lines continue assignments whose starts are
    // not visible here)
    auto dst1_offsets = dst1.get_minmax_offsets();
    auto dst2_offsets = dst2.get_minmax_offsets();
        static_cast<size_t>(dst1_offsets.second - dst1_offsets.first);
        static_cast<size_t>(dst2_offsets.second - dst2_offsets.first);
    if ((range1 + 1 < src_nelems) || (range2 + 1 < src_nelems)) {

    // no pair of arrays may overlap in memory
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src, dst1) || overlap(src, dst2) || overlap(dst1, dst2)) {

    // all arrays must be C-contiguous
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_dst1_c_contig = dst1.is_c_contiguous();
    bool is_dst2_c_contig = dst2.is_c_contiguous();

        (is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);

    // a contiguous kernel must actually be registered for this type
    if (contig_dispatch_vector[src_typeid] == nullptr) {
// Decides whether the OneMKL VM-based kernel can service a binary
// elementwise call (src1, src2 -> dst); same gauntlet as the unary
// variant plus a same-input-type requirement and a 2D output-type table.
// NOTE(review): several interior lines (early-return bodies and braces)
// are not visible in this view; only comments were added here.
template <typename output_typesT, typename contig_dispatchT>
bool need_to_call_binary_ufunc(sycl::queue &exec_q,
                               const dpctl::tensor::usm_ndarray &src1,
                               const dpctl::tensor::usm_ndarray &src2,
                               const dpctl::tensor::usm_ndarray &dst,
                               const output_typesT &output_type_table,
                               const contig_dispatchT &contig_dispatch_table)
    // map NumPy-style typenums to dispatch lookup ids
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    // destination type must match the table entry for the input pair
    int output_typeid = output_type_table[src1_typeid][src2_typeid];
    if (output_typeid != dst_typeid) {

    // both inputs must have the same type
    if (src1_typeid != src2_typeid) {

    // device must support double precision
    // (presumably a VM backend requirement -- TODO confirm)
    if (!exec_q.get_device().has(sycl::aspect::fp64)) {

    // execution queue must be compatible with all three arrays' queues
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {

    // dimensionalities must agree and be non-zero
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
    else if (dst_nd == 0) {

    // all three shapes must be equal; element count is accumulated too
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);

    // nothing to do for empty arrays
    if (src_nelems == 0) {

    // destination must be ample enough to accommodate all elements
    // (next line continues an assignment whose start is not visible here)
    auto dst_offsets = dst.get_minmax_offsets();
        static_cast<size_t>(dst_offsets.second - dst_offsets.first);
    if (range + 1 < src_nelems) {

    // neither input may overlap the destination in memory
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src1, dst) || overlap(src2, dst)) {

    // all arrays must be C-contiguous
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_dst_c_contig = dst.is_c_contiguous();

        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);

    // a contiguous kernel must actually be registered for this type
    if (contig_dispatch_table[src1_typeid] == nullptr) {
/* Expands, inside a VM operation's namespace, to the dispatch machinery
   for a single-output unary function __name__:
   - ContigFactory: yields __name__##_contig_impl<T> for supported T
     (the condition is partially elided in this view),
   - TypeMapFactory: maps input type T to the typeid of
     OutputType<T>::value_type,
   - populate_dispatch_vectors(): fills output_typeid_vector and
     contig_dispatch_vector via ext_ns::init_dispatch_vector.
   NOTE(review): interior lines of this macro are not visible in this
   view; only this comment was added (comments cannot be placed on the
   continuation lines without breaking the trailing backslashes). */
#define MACRO_POPULATE_DISPATCH_VECTORS(__name__)                              \
    template <typename fnT, typename T>                                        \
    struct ContigFactory                                                       \
        if constexpr (std::is_same_v<typename OutputType<T>::value_type,       \
            return __name__##_contig_impl<T>;                                  \
    template <typename fnT, typename T>                                        \
    struct TypeMapFactory                                                      \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get()             \
            using rT = typename OutputType<T>::value_type;                     \
            return td_ns::GetTypeid<rT>{}.get();                               \
    static void populate_dispatch_vectors(void)                                \
        ext_ns::init_dispatch_vector<int, TypeMapFactory>(                     \
            output_typeid_vector);                                             \
        ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t,               \
                                     ContigFactory>(contig_dispatch_vector);   \
/* Expands, inside a VM operation's namespace, to the dispatch machinery
   for a two-output unary function __name__:
   - ContigFactory: yields __name__##_contig_impl<T> when both
     OutputType<T>::value_type1 and ::value_type2 match (condition is
     partially elided in this view),
   - TypeMapFactory: maps input type T to the pair of typeids for
     OutputType<T>::value_type1 and ::value_type2,
   - populate_dispatch_vectors(): fills output_typeid_vector (pair-typed)
     and contig_dispatch_vector via ext_ns::init_dispatch_vector.
   NOTE(review): interior lines of this macro are not visible in this
   view; only this comment was added. */
#define MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(__name__)                        \
    template <typename fnT, typename T>                                        \
    struct ContigFactory                                                       \
        if constexpr (std::is_same_v<typename OutputType<T>::value_type1,      \
                      std::is_same_v<typename OutputType<T>::value_type2,      \
            fnT fn = __name__##_contig_impl<T>;                                \
    template <typename fnT, typename T>                                        \
    struct TypeMapFactory                                                      \
        std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value,        \
                         std::pair<int, int>>                                  \
            using rT1 = typename OutputType<T>::value_type1;                   \
            using rT2 = typename OutputType<T>::value_type2;                   \
            return std::make_pair(td_ns::GetTypeid<rT1>{}.get(),               \
                                  td_ns::GetTypeid<rT2>{}.get());              \
    static void populate_dispatch_vectors(void)                                \
        ext_ns::init_dispatch_vector<std::pair<int, int>, TypeMapFactory>(     \
            output_typeid_vector);                                             \
        ext_ns::init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t,   \
                                     ContigFactory>(contig_dispatch_vector);   \
/* Expands, inside a VM operation's namespace, to the dispatch machinery
   for a binary function __name__ keyed by a pair of input types:
   - ContigFactory: yields __name__##_contig_impl<T1, T2> for supported
     (T1, T2) combinations (condition is partially elided in this view),
   - TypeMapFactory: maps (T1, T2) to the typeid of
     OutputType<T1, T2>::value_type,
   - populate_dispatch_tables(): fills the 2D output_typeid_vector and
     contig_dispatch_vector via ext_ns::init_dispatch_table.
   NOTE(review): interior lines of this macro are not visible in this
   view, and the macro may continue past the last visible line; only this
   comment was added. */
#define MACRO_POPULATE_DISPATCH_TABLES(__name__)                               \
    template <typename fnT, typename T1, typename T2>                          \
    struct ContigFactory                                                       \
        if constexpr (std::is_same_v<                                          \
                          typename OutputType<T1, T2>::value_type,             \
            return __name__##_contig_impl<T1, T2>;                             \
    template <typename fnT, typename T1, typename T2>                          \
    struct TypeMapFactory                                                      \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get()             \
            using rT = typename OutputType<T1, T2>::value_type;                \
            return td_ns::GetTypeid<rT>{}.get();                               \
    static void populate_dispatch_tables(void)                                 \
        ext_ns::init_dispatch_table<int, TypeMapFactory>(                      \
            output_typeid_vector);                                             \
        ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t,               \
                                    ContigFactory>(contig_dispatch_vector);    \