30#include <oneapi/mkl.hpp> 
   31#include <sycl/sycl.hpp> 
   33#include <dpctl4pybind11.hpp> 
   34#include <pybind11/pybind11.h> 
   37#include "ext/common.hpp" 
   40#include "utils/memory_overlap.hpp" 
   41#include "utils/type_dispatch.hpp" 
   // Minimum oneMKL version number accepted by this header (20230002L); the
   // macro name indicates this corresponds to release 2023.2.0. The definition
   // is guarded so an earlier definition of the same macro is kept.
   49#ifndef __INTEL_MKL_2023_2_0_VERSION_REQUIRED 
   50#define __INTEL_MKL_2023_2_0_VERSION_REQUIRED 20230002L 
   // Reject the build at compile time when INTEL_MKL_VERSION is below the
   // required value.
   53static_assert(INTEL_MKL_VERSION >= __INTEL_MKL_2023_2_0_VERSION_REQUIRED,
 
   54              "OneMKL does not meet minimum version requirement");
 
   // Short namespace aliases used throughout the rest of this header.
   56namespace ext_ns = ext::common;
 
   57namespace py = pybind11;
 
   58namespace td_ns = dpctl::tensor::type_dispatch;
 
   60namespace dpnp::extensions::vm::py_internal
 
   // Runs a sequence of precondition checks for a unary element-wise call
   // that reads `src` and writes `dst` on `exec_q`.
   //
   // NOTE(review): this excerpt does not show the bodies of the `if` branches
   // nor the end of the function, so the result of each check is not visible
   // here. Judging by the name and the `bool` return type, presumably every
   // failed check yields `false` and passing all of them yields `true` --
   // TODO confirm against the full file.
   //
   // exec_q                 - queue whose device aspects and compatibility
   //                          with the arrays are examined
   // src, dst               - input and output arrays
   // output_type_vec        - indexed by the src type id; gives the type id
   //                          that dst is compared against
   // contig_dispatch_vector - indexed by the src type id; the entry is
   //                          compared against nullptr
   62template <
typename output_typesT, 
typename contig_dispatchT>
 
   63bool need_to_call_unary_ufunc(sycl::queue &exec_q,
 
   64                              const dpctl::tensor::usm_ndarray &src,
 
   65                              const dpctl::tensor::usm_ndarray &dst,
 
   66                              const output_typesT &output_type_vec,
 
   67                              const contig_dispatchT &contig_dispatch_vector)
 
         // Translate the arrays' type numbers into lookup ids.
   70    int src_typenum = src.get_typenum();
 
   71    int dst_typenum = dst.get_typenum();
 
   73    auto array_types = td_ns::usm_ndarray_types();
 
   74    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
 
   75    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
 
         // dst's type id must equal the output type id listed for the src type.
   78    int func_output_typeid = output_type_vec[src_typeid];
 
   79    if (dst_typeid != func_output_typeid) {
 
         // Branch taken when the queue's device does not report the fp64 aspect.
   84    if (!exec_q.get_device().has(sycl::aspect::fp64)) {
 
         // Branch taken when exec_q and the arrays' queues are not compatible.
   89    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
 
         // src and dst are compared by number of dimensions; zero-dimensional
         // arrays get a branch of their own.
   94    int dst_nd = dst.get_ndim();
 
   95    if (dst_nd != src.get_ndim()) {
 
   98    else if (dst_nd == 0) {
 
         // One pass over the axes: accumulate the number of src elements and
         // record whether every src extent equals the matching dst extent.
  104    const py::ssize_t *src_shape = src.get_shape_raw();
 
  105    const py::ssize_t *dst_shape = dst.get_shape_raw();
 
  106    bool shapes_equal(
true);
 
  107    size_t src_nelems(1);
 
  109    for (
int i = 0; i < dst_nd; ++i) {
 
  110        src_nelems *= 
static_cast<size_t>(src_shape[i]);
 
  111        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
 
         // Branch taken for an empty src.
  118    if (src_nelems == 0) {
 
         // Compare the distance between dst's smallest and largest offsets
         // (plus one) with the number of src elements. The declaration of
         // `range` is not part of this excerpt; presumably it is initialised
         // with the cast expression below -- TODO confirm.
  123    auto dst_offsets = dst.get_minmax_offsets();
 
  127            static_cast<size_t>(dst_offsets.second - dst_offsets.first);
 
  128        if (range + 1 < src_nelems) {
 
         // Branch taken when the MemoryOverlap functor reports that src and
         // dst overlap.
  134    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
 
  135    if (overlap(src, dst)) {
 
         // all_c_contig is true only when both arrays are C-contiguous.
  140    bool is_src_c_contig = src.is_c_contiguous();
 
  141    bool is_dst_c_contig = dst.is_c_contiguous();
 
  143    bool all_c_contig = (is_src_c_contig && is_dst_c_contig);
 
         // Branch taken when the dispatch vector holds no entry (nullptr) for
         // the src type.
  149    if (contig_dispatch_vector[src_typeid] == 
nullptr) {
 
  // Runs a sequence of precondition checks for a binary element-wise call
  // that reads `src1` and `src2` and writes `dst` on `exec_q`.
  //
  // NOTE(review): this excerpt does not show the bodies of the `if` branches
  // nor the end of the function, so the result of each check is not visible
  // here. Judging by the name and the `bool` return type, presumably every
  // failed check yields `false` and passing all of them yields `true` --
  // TODO confirm against the full file.
  //
  // exec_q                - queue whose device aspects and compatibility with
  //                         the arrays are examined
  // src1, src2, dst       - the two inputs and the output array
  // output_type_table     - indexed by [src1 type id][src2 type id]; gives the
  //                         type id that dst is compared against
  // contig_dispatch_table - indexed by [src1 type id][src2 type id]; the
  //                         entry is compared against nullptr
  155template <
typename output_typesT, 
typename contig_dispatchT>
 
  156bool need_to_call_binary_ufunc(sycl::queue &exec_q,
 
  157                               const dpctl::tensor::usm_ndarray &src1,
 
  158                               const dpctl::tensor::usm_ndarray &src2,
 
  159                               const dpctl::tensor::usm_ndarray &dst,
 
  160                               const output_typesT &output_type_table,
 
  161                               const contig_dispatchT &contig_dispatch_table)
 
         // Translate the arrays' type numbers into lookup ids.
  164    int src1_typenum = src1.get_typenum();
 
  165    int src2_typenum = src2.get_typenum();
 
  166    int dst_typenum = dst.get_typenum();
 
  168    auto array_types = td_ns::usm_ndarray_types();
 
  169    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
 
  170    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
 
  171    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
 
         // dst's type id must equal the output type id listed for this pair of
         // input types.
  174    int output_typeid = output_type_table[src1_typeid][src2_typeid];
 
  175    if (output_typeid != dst_typeid) {
 
         // Branch taken when the two inputs have different type ids.
  180    if (src1_typeid != src2_typeid) {
 
         // Branch taken when the queue's device does not report the fp64 aspect.
  185    if (!exec_q.get_device().has(sycl::aspect::fp64)) {
 
         // Branch taken when exec_q and the arrays' queues are not compatible.
  190    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
 
         // All three arrays are compared by number of dimensions;
         // zero-dimensional arrays get a branch of their own.
  195    int dst_nd = dst.get_ndim();
 
  196    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
 
  199    else if (dst_nd == 0) {
 
         // One pass over the axes: accumulate the number of src1 elements and
         // record whether both inputs match dst's extent on every axis.
  205    const py::ssize_t *src1_shape = src1.get_shape_raw();
 
  206    const py::ssize_t *src2_shape = src2.get_shape_raw();
 
  207    const py::ssize_t *dst_shape = dst.get_shape_raw();
 
  208    bool shapes_equal(
true);
 
  209    size_t src_nelems(1);
 
  211    for (
int i = 0; i < dst_nd; ++i) {
 
  212        src_nelems *= 
static_cast<size_t>(src1_shape[i]);
 
  213        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
 
  214                                        src2_shape[i] == dst_shape[i]);
 
         // Branch taken for empty inputs.
  221    if (src_nelems == 0) {
 
         // Compare the distance between dst's smallest and largest offsets
         // (plus one) with the number of input elements. The declaration of
         // `range` is not part of this excerpt; presumably it is initialised
         // with the cast expression below -- TODO confirm.
  226    auto dst_offsets = dst.get_minmax_offsets();
 
  230            static_cast<size_t>(dst_offsets.second - dst_offsets.first);
 
  231        if (range + 1 < src_nelems) {
 
         // Branch taken when the MemoryOverlap functor reports that either
         // input overlaps dst.
  237    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
 
  238    if (overlap(src1, dst) || overlap(src2, dst)) {
 
         // The three C-contiguity flags are combined by the conjunction below.
  243    bool is_src1_c_contig = src1.is_c_contiguous();
 
  244    bool is_src2_c_contig = src2.is_c_contiguous();
 
  245    bool is_dst_c_contig = dst.is_c_contiguous();
 
  248        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
 
         // Branch taken when the dispatch table holds no entry (nullptr) for
         // this pair of input types. The table is indexed with both type ids,
         // the same way output_type_table is above: both tables are filled by
         // init_dispatch_table in MACRO_POPULATE_DISPATCH_TABLES, so a single
         // index would select a whole row, which decays to a non-null pointer
         // and would make this comparison always false.
  254    if (contig_dispatch_table[src1_typeid][src2_typeid] == 
nullptr) {
 
  // MACRO_POPULATE_DISPATCH_VECTORS(__name__): declares the per-type dispatch
  // helpers for a unary operation called `__name__`.
  //
  // Visible parts of the expansion:
  //  - ContigFactory<fnT, T>: contains an `if constexpr` test on
  //    OutputType<T>::value_type (the type it is compared with is not shown
  //    in this excerpt) and a path returning `__name__##_contig_impl<T>`.
  //  - TypeMapFactory<fnT, T>: its get(), enabled when fnT is int, returns
  //    td_ns::GetTypeid<rT>{}.get() with rT = OutputType<T>::value_type.
  //  - populate_dispatch_vectors(): passes output_typeid_vector to
  //    ext_ns::init_dispatch_vector<int, TypeMapFactory> and
  //    contig_dispatch_vector to
  //    ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, ContigFactory>.
  //
  // OutputType, __name__##_contig_impl, unary_contig_impl_fn_ptr_t,
  // output_typeid_vector and contig_dispatch_vector are referenced but not
  // defined in the visible part of the macro, so they presumably have to be
  // in scope where the macro is expanded -- TODO confirm.
  // NOTE(review): several continuation lines of this macro are missing from
  // this excerpt.
  265#define MACRO_POPULATE_DISPATCH_VECTORS(__name__)                              \ 
  266    template <typename fnT, typename T>                                        \ 
  267    struct ContigFactory                                                       \ 
  271            if constexpr (std::is_same_v<typename OutputType<T>::value_type,   \ 
  276                return __name__##_contig_impl<T>;                              \ 
  281    template <typename fnT, typename T>                                        \ 
  282    struct TypeMapFactory                                                      \ 
  284        std::enable_if_t<std::is_same<fnT, int>::value, int> get()             \ 
  286            using rT = typename OutputType<T>::value_type;                     \ 
  287            return td_ns::GetTypeid<rT>{}.get();                               \ 
  291    static void populate_dispatch_vectors(void)                                \ 
  293        ext_ns::init_dispatch_vector<int, TypeMapFactory>(                     \ 
  294            output_typeid_vector);                                             \ 
  295        ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t,               \ 
  296                                     ContigFactory>(contig_dispatch_vector);   \ 
  304#define MACRO_POPULATE_DISPATCH_TABLES(__name__)                               \ 
  305    template <typename fnT, typename T1, typename T2>                          \ 
  306    struct ContigFactory                                                       \ 
  310            if constexpr (std::is_same_v<                                      \ 
  311                              typename OutputType<T1, T2>::value_type, void>)  \ 
  316                return __name__##_contig_impl<T1, T2>;                         \ 
  321    template <typename fnT, typename T1, typename T2>                          \ 
  322    struct TypeMapFactory                                                      \ 
  324        std::enable_if_t<std::is_same<fnT, int>::value, int> get()             \ 
  326            using rT = typename OutputType<T1, T2>::value_type;                \ 
  327            return td_ns::GetTypeid<rT>{}.get();                               \ 
  331    static void populate_dispatch_tables(void)                                 \ 
  333        ext_ns::init_dispatch_table<int, TypeMapFactory>(                      \ 
  334            output_typeid_vector);                                             \ 
  335        ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t,               \ 
  336                                    ContigFactory>(contig_dispatch_vector);    \