#include <sycl/sycl.hpp>

#include "dpctl4pybind11.hpp"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "elementwise_functions_type_utils.hpp"
#include "simplify_iteration_space.hpp"

#include "kernels/alignment.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/offset_utils.hpp"
#include "utils/output_validation.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include "utils/type_dispatch.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);

namespace dpnp::extensions::py_internal
{

using dpctl::tensor::kernels::alignment_utils::is_aligned;
using dpctl::tensor::kernels::alignment_utils::required_alignment;
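
/*! @brief Template implementing Python API for unary elementwise functions.
 *
 *  Returns a pair of SYCL events: a host-task event that keeps the Python
 *  arguments alive for the duration of the computation, and the computation
 *  event itself. */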
 
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT>
std::pair<sycl::event, sycl::event>
    py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
                   const dpctl::tensor::usm_ndarray &dst,
                   sycl::queue &q,
                   const std::vector<sycl::event> &depends,
                   const output_typesT &output_type_vec,
                   const contig_dispatchT &contig_dispatch_vector,
                   const strided_dispatchT &strided_dispatch_vector)
{
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    const auto &array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int func_output_typeid = output_type_vec[src_typeid];

    // check that types are supported
    if (dst_typeid != func_output_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(q, {src, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check that dimensions are the same
    int src_nd = src.get_ndim();
    if (src_nd != dst.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < src_nd; ++i) {
        src_nelems *= static_cast<size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src_data = src.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_src_f_contig = src.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
    bool both_f_contig = (is_src_f_contig && is_dst_f_contig);

    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        auto comp_ev = contig_fn(q, src_nelems, src_data, dst_data, depends);
        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // simplify iteration space:
    //   if 1d with unit strides, dispatch to the contiguous implementation,
    //   otherwise dispatch to the strided implementation
    auto const &src_strides = src.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src_strides;
    shT simplified_dst_strides;
    py::ssize_t src_offset(0);
    py::ssize_t dst_offset(0);

    int nd = src_nd;
    const py::ssize_t *shape = src_shape;

    simplify_iteration_space(nd, shape, src_strides, dst_strides,
                             // output
                             simplified_shape, simplified_src_strides,
                             simplified_dst_strides, src_offset, dst_offset);

    if (nd == 1 && simplified_src_strides[0] == 1 &&
        simplified_dst_strides[0] == 1) {
        // special case of contiguous data
        auto contig_fn = contig_dispatch_vector[src_typeid];

        if (contig_fn == nullptr) {
            throw std::runtime_error(
                "Contiguous implementation is missing for src_typeid=" +
                std::to_string(src_typeid));
        }

        int src_elem_size = src.get_elemsize();
        int dst_elem_size = dst.get_elemsize();
        auto comp_ev =
            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
                      dst_data + dst_elem_size * dst_offset, depends);

        sycl::event ht_ev =
            dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});

        return std::make_pair(ht_ev, comp_ev);
    }

    // dispatch to strided implementation
    auto strided_fn = strided_dispatch_vector[src_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src_typeid=" +
            std::to_string(src_typeid));
    }

    using dpctl::tensor::offset_utils::device_allocate_and_pack;

    std::vector<sycl::event> host_tasks{};
    host_tasks.reserve(2);

    // pack shape and strides into a single USM device allocation
    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        q, host_tasks, simplified_shape, simplified_src_strides,
        simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
                   dst_data, dst_offset, depends, {copy_shape_ev});

    // asynchronously free the packed shape/strides allocation
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks),
        strided_fn_ev);
}
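
// Illustrative sketch only: a concrete unary function is typically exposed to
// Python by binding a thin wrapper that forwards pre-populated dispatch
// vectors to py_unary_ufunc. The names below ("_abs", abs_output_typeid_vector,
// abs_contig_dispatch_vector, abs_strided_dispatch_vector) are hypothetical
// placeholders, not part of this header.
//
//   m.def(
//       "_abs",
//       [](const dpctl::tensor::usm_ndarray &src,
//          const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q,
//          const std::vector<sycl::event> &depends) {
//           return py_unary_ufunc(src, dst, exec_q, depends,
//                                 abs_output_typeid_vector,
//                                 abs_contig_dispatch_vector,
//                                 abs_strided_dispatch_vector);
//       },
//       py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
//       py::arg("depends") = py::list());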
 
/*! @brief Template implementing Python API for querying type support of
 *         unary elementwise functions */
template <typename output_typesT>
py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
                                      const output_typesT &output_types)
{
    int tn = input_dtype.num(); // NumPy type numbers are the same in dpctl
    int src_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src_typeid = array_types.typenum_to_lookup_id(tn);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    using type_utils::_result_typeid;
    int dst_typeid = _result_typeid(src_typeid, output_types);

    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}
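
// ======================== Binary functions ===========================

// Helper comparing a strides container against a short stride pattern such as
// {1}, {0, 1} or {1, 0}; used below to detect unit-stride and matrix/row
// broadcast layouts.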
 
template <class Container, class T>
bool isEqual(Container const &c, std::initializer_list<T> const &l)
{
    return std::equal(std::begin(c), std::end(c), std::begin(l), std::end(l));
}
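
/*! @brief Template implementing Python API for binary elementwise functions */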
 
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_matrix_row_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event> py_binary_ufunc(
    const dpctl::tensor::usm_ndarray &src1,
    const dpctl::tensor::usm_ndarray &src2,
    const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
    sycl::queue &exec_q,
    const std::vector<sycl::event> depends,
    const output_typesT &output_type_table,
    const contig_dispatchT &contig_dispatch_table,
    const strided_dispatchT &strided_dispatch_table,
    const contig_matrix_row_dispatchT
        &contig_matrix_row_broadcast_dispatch_table,
    const contig_row_matrix_dispatchT
        &contig_row_matrix_broadcast_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    int output_typeid = output_type_table[src1_typeid][src2_typeid];

    if (output_typeid != dst_typeid) {
        throw py::value_error(
            "Destination array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // check shapes; broadcasting is assumed done by the caller
    // check that dimensions are the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (src_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) ||
        (overlap(src2, dst) && !same_logical_tensors(src2, dst)))
    {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *src1_data = src1.get_data();
    const char *src2_data = src2.get_data();
    char *dst_data = dst.get_data();

    // handle contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src1_f_contig = src1.is_f_contiguous();

    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_src2_f_contig = src2.is_f_contiguous();

    bool is_dst_c_contig = dst.is_c_contiguous();
    bool is_dst_f_contig = dst.is_f_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    bool all_f_contig =
        (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig);

    // dispatch for contiguous inputs
    if (all_c_contig || all_f_contig) {
        auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0,
                                     src2_data, 0, dst_data, 0, depends);
            sycl::event ht_ev = dpctl::utils::keep_args_alive(
                exec_q, {src1, src2, dst}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &src1_strides = src1.get_strides_vector();
    auto const &src2_strides = src2.get_strides_vector();
    auto const &dst_strides = dst.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_src1_strides;
    shT simplified_src2_strides;
    shT simplified_dst_strides;
    py::ssize_t src1_offset(0);
    py::ssize_t src2_offset(0);
    py::ssize_t dst_offset(0);

    int nd = dst_nd;
    const py::ssize_t *shape = src1_shape;

    simplify_iteration_space_3(
        nd, shape, src1_strides, src2_strides, dst_strides,
        // outputs
        simplified_shape, simplified_src1_strides, simplified_src2_strides,
        simplified_dst_strides, src1_offset, src2_offset, dst_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) &&
            isEqual(simplified_src2_strides, unit_stride) &&
            isEqual(simplified_dst_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev = contig_fn(exec_q, src_nelems, src1_data,
                                         src1_offset, src2_data, src2_offset,
                                         dst_data, dst_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {src1, src2, dst}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto zero_one_strides =
                std::initializer_list<py::ssize_t>{0, 1};
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            constexpr py::ssize_t one{1};
            // special case of C-contiguous matrix and a row
            if (isEqual(simplified_src2_strides, zero_one_strides) &&
                isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
                isEqual(simplified_dst_strides, {simplified_shape[1], one}))
            {
                auto matrix_row_broadcast_fn =
                    contig_matrix_row_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (matrix_row_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        size_t n0 = simplified_shape[0];
                        size_t n1 = simplified_shape[1];
                        sycl::event comp_ev = matrix_row_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
            // special case of a row and C-contiguous matrix
            if (isEqual(simplified_src1_strides, one_zero_strides) &&
                isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
                isEqual(simplified_dst_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[src1_typeid]
                                                              [src2_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    int src1_itemsize = src1.get_elemsize();
                    int src2_itemsize = src2.get_elemsize();
                    int dst_itemsize = dst.get_elemsize();

                    if (is_aligned<required_alignment>(
                            src1_data + src1_offset * src1_itemsize) &&
                        is_aligned<required_alignment>(
                            src2_data + src2_offset * src2_itemsize) &&
                        is_aligned<required_alignment>(
                            dst_data + dst_offset * dst_itemsize))
                    {
                        size_t n0 = simplified_shape[1];
                        size_t n1 = simplified_shape[0];
                        sycl::event comp_ev = row_matrix_broadcast_fn(
                            exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                            src2_data, src2_offset, dst_data, dst_offset,
                            depends);

                        return std::make_pair(
                            dpctl::utils::keep_args_alive(
                                exec_q, {src1, src2, dst}, host_tasks),
                            comp_ev);
                    }
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for src1_typeid=" +
            std::to_string(src1_typeid) +
            " and src2_typeid=" + std::to_string(src2_typeid));
    }

    // pack shape and strides into a single USM device allocation
    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_src1_strides,
        simplified_src2_strides, simplified_dst_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev = strided_fn(
        exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
        src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});

    // asynchronously free the packed shape/strides allocation
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks),
        strided_fn_ev);
}
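
/*! @brief Template implementing Python API for querying type support of
 *         binary elementwise functions */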
 
template <typename output_typesT>
py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
                                       const py::dtype &input2_dtype,
                                       const output_typesT &output_types_table)
{
    int tn1 = input1_dtype.num(); // NumPy type numbers are the same in dpctl
    int tn2 = input2_dtype.num(); // NumPy type numbers are the same in dpctl
    int src1_typeid = -1;
    int src2_typeid = -1;

    auto array_types = td_ns::usm_ndarray_types();

    try {
        src1_typeid = array_types.typenum_to_lookup_id(tn1);
        src2_typeid = array_types.typenum_to_lookup_id(tn2);
    } catch (const std::exception &e) {
        throw py::value_error(e.what());
    }

    if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 ||
        src2_typeid >= td_ns::num_types)
    {
        throw std::runtime_error("binary output type lookup failed");
    }
    int dst_typeid = output_types_table[src1_typeid][src2_typeid];

    if (dst_typeid < 0) {
        auto res = py::none();
        return py::cast<py::object>(res);
    }
    else {
        using type_utils::_dtype_from_typenum;

        auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
        auto dt = _dtype_from_typenum(dst_typenum_t);

        return py::cast<py::object>(dt);
    }
}
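
/*! @brief Template implementing Python API for in-place binary elementwise
 *         functions, writing the result into the left-hand side array */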
 
template <typename output_typesT,
          typename contig_dispatchT,
          typename strided_dispatchT,
          typename contig_row_matrix_dispatchT>
std::pair<sycl::event, sycl::event>
    py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
                            const dpctl::tensor::usm_ndarray &rhs,
                            sycl::queue &exec_q,
                            const std::vector<sycl::event> depends,
                            const output_typesT &output_type_table,
                            const contig_dispatchT &contig_dispatch_table,
                            const strided_dispatchT &strided_dispatch_table,
                            const contig_row_matrix_dispatchT
                                &contig_row_matrix_broadcast_dispatch_table)
{
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs);

    // check type_nums
    int rhs_typenum = rhs.get_typenum();
    int lhs_typenum = lhs.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
    int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum);

    int output_typeid = output_type_table[rhs_typeid][lhs_typeid];

    if (output_typeid != lhs_typeid) {
        throw py::value_error(
            "Left-hand side array has unexpected elemental data type.");
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {rhs, lhs})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // check shapes; broadcasting is assumed done by the caller
    // check that dimensions are the same
    int lhs_nd = lhs.get_ndim();
    if (lhs_nd != rhs.get_ndim()) {
        throw py::value_error("Array dimensions are not the same.");
    }

    // check that shapes are the same
    const py::ssize_t *rhs_shape = rhs.get_shape_raw();
    const py::ssize_t *lhs_shape = lhs.get_shape_raw();
    bool shapes_equal(true);
    size_t rhs_nelems(1);

    for (int i = 0; i < lhs_nd; ++i) {
        rhs_nelems *= static_cast<size_t>(rhs_shape[i]);
        shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
    }
    if (!shapes_equal) {
        throw py::value_error("Array shapes are not the same.");
    }

    // if nelems is zero, return
    if (rhs_nelems == 0) {
        return std::make_pair(sycl::event(), sycl::event());
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems);

    // check memory overlap
    auto const &same_logical_tensors =
        dpctl::tensor::overlap::SameLogicalTensors();
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) {
        throw py::value_error("Arrays index overlapping segments of memory");
    }

    const char *rhs_data = rhs.get_data();
    char *lhs_data = lhs.get_data();

    // handle contiguous inputs
    bool is_rhs_c_contig = rhs.is_c_contiguous();
    bool is_rhs_f_contig = rhs.is_f_contiguous();

    bool is_lhs_c_contig = lhs.is_c_contiguous();
    bool is_lhs_f_contig = lhs.is_f_contiguous();

    bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig);
    bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig);

    // dispatch for contiguous inputs
    if (both_c_contig || both_f_contig) {
        auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

        if (contig_fn != nullptr) {
            auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
                                     0, depends);
            sycl::event ht_ev =
                dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev});

            return std::make_pair(ht_ev, comp_ev);
        }
    }

    // simplify strides
    auto const &rhs_strides = rhs.get_strides_vector();
    auto const &lhs_strides = lhs.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT simplified_shape;
    shT simplified_rhs_strides;
    shT simplified_lhs_strides;
    py::ssize_t rhs_offset(0);
    py::ssize_t lhs_offset(0);

    int nd = lhs_nd;
    const py::ssize_t *shape = rhs_shape;

    simplify_iteration_space(nd, shape, rhs_strides, lhs_strides,
                             // output
                             simplified_shape, simplified_rhs_strides,
                             simplified_lhs_strides, rhs_offset, lhs_offset);

    std::vector<sycl::event> host_tasks{};
    if (nd < 3) {
        static constexpr auto unit_stride =
            std::initializer_list<py::ssize_t>{1};

        if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) &&
            isEqual(simplified_lhs_strides, unit_stride))
        {
            auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

            if (contig_fn != nullptr) {
                auto comp_ev =
                    contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset,
                              lhs_data, lhs_offset, depends);
                sycl::event ht_ev = dpctl::utils::keep_args_alive(
                    exec_q, {rhs, lhs}, {comp_ev});

                return std::make_pair(ht_ev, comp_ev);
            }
        }
        if (nd == 2) {
            static constexpr auto one_zero_strides =
                std::initializer_list<py::ssize_t>{1, 0};
            constexpr py::ssize_t one{1};
            // special case of a row and C-contiguous matrix
            if (isEqual(simplified_rhs_strides, one_zero_strides) &&
                isEqual(simplified_lhs_strides, {one, simplified_shape[0]}))
            {
                auto row_matrix_broadcast_fn =
                    contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                              [lhs_typeid];
                if (row_matrix_broadcast_fn != nullptr) {
                    size_t n0 = simplified_shape[1];
                    size_t n1 = simplified_shape[0];
                    sycl::event comp_ev = row_matrix_broadcast_fn(
                        exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                        lhs_data, lhs_offset, depends);

                    return std::make_pair(dpctl::utils::keep_args_alive(
                                              exec_q, {lhs, rhs}, host_tasks),
                                          comp_ev);
                }
            }
        }
    }

    // dispatch to strided code
    auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid];

    if (strided_fn == nullptr) {
        throw std::runtime_error(
            "Strided implementation is missing for rhs_typeid=" +
            std::to_string(rhs_typeid) +
            " and lhs_typeid=" + std::to_string(lhs_typeid));
    }

    // pack shape and strides into a single USM device allocation
    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
        simplified_lhs_strides);
    auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
    auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_);

    const py::ssize_t *shape_strides = shape_strides_owner.get();

    sycl::event strided_fn_ev =
        strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
                   lhs_data, lhs_offset, depends, {copy_shape_ev});

    // asynchronously free the packed shape/strides allocation
    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
        exec_q, {strided_fn_ev}, shape_strides_owner);

    host_tasks.push_back(tmp_cleanup_ev);

    return std::make_pair(
        dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
        strided_fn_ev);
}