28#include <oneapi/mkl.hpp> 
   29#include <pybind11/pybind11.h> 
   32#include "utils/memory_overlap.hpp" 
   33#include "utils/output_validation.hpp" 
   34#include "utils/type_dispatch.hpp" 
   35#include "utils/type_utils.hpp" 
   37#include "types_matrix.hpp" 
   39namespace dpnp::extensions::blas::dot
 
   41typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &,
 
   48                                         const std::vector<sycl::event> &);
 
   50namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
 
   51namespace py = pybind11;
 
   53std::pair<sycl::event, sycl::event>
 
   54    dot_func(sycl::queue &exec_q,
 
   55             const dpctl::tensor::usm_ndarray &vectorX,
 
   56             const dpctl::tensor::usm_ndarray &vectorY,
 
   57             const dpctl::tensor::usm_ndarray &result,
 
   58             const std::vector<sycl::event> &depends,
 
   59             const dot_impl_fn_ptr_t *dot_dispatch_vector)
 
   61    const int vectorX_nd = vectorX.get_ndim();
 
   62    const int vectorY_nd = vectorY.get_ndim();
 
   63    const int result_nd = result.get_ndim();
 
   65    if ((vectorX_nd != 1)) {
 
   66        throw py::value_error(
 
   67            "The first input array has ndim=" + std::to_string(vectorX_nd) +
 
   68            ", but a 1-dimensional array is expected.");
 
   71    if ((vectorY_nd != 1)) {
 
   72        throw py::value_error(
 
   73            "The second input array has ndim=" + std::to_string(vectorY_nd) +
 
   74            ", but a 1-dimensional array is expected.");
 
   77    if ((result_nd != 0)) {
 
   78        throw py::value_error(
 
   79            "The output array has ndim=" + std::to_string(result_nd) +
 
   80            ", but a 0-dimensional array is expected.");
 
   83    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
 
   84    if (overlap(vectorX, result)) {
 
   85        throw py::value_error(
 
   86            "The first input array and output array are overlapping " 
   87            "segments of memory");
 
   89    if (overlap(vectorY, result)) {
 
   90        throw py::value_error(
 
   91            "The second input array and output array are overlapping " 
   92            "segments of memory");
 
   95    if (!dpctl::utils::queues_are_compatible(
 
   97            {vectorX.get_queue(), vectorY.get_queue(), result.get_queue()}))
 
   99        throw py::value_error(
 
  100            "USM allocations are not compatible with the execution queue.");
 
  103    const int src_nelems = 1;
 
  104    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
 
  105    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(result,
 
  108    const py::ssize_t x_size = vectorX.get_size();
 
  109    const py::ssize_t y_size = vectorY.get_size();
 
  110    const std::int64_t n = x_size;
 
  111    if (x_size != y_size) {
 
  112        throw py::value_error(
"The size of the first input array must be " 
  113                              "equal to the size of the second input array.");
 
  116    const int vectorX_typenum = vectorX.get_typenum();
 
  117    const int vectorY_typenum = vectorY.get_typenum();
 
  118    const int result_typenum = result.get_typenum();
 
  120    if (result_typenum != vectorX_typenum || result_typenum != vectorY_typenum)
 
  122        throw py::value_error(
"Given arrays must be of the same type.");
 
  125    auto array_types = dpctl_td_ns::usm_ndarray_types();
 
  126    const int type_id = array_types.typenum_to_lookup_id(vectorX_typenum);
 
  128    dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id];
 
  129    if (dot_fn == 
nullptr) {
 
  130        throw py::value_error(
 
  131            "No dot implementation is available for the specified data type " 
  132            "of the input and output arrays.");
 
  135    char *x_typeless_ptr = vectorX.get_data();
 
  136    char *y_typeless_ptr = vectorY.get_data();
 
  137    char *r_typeless_ptr = result.get_data();
 
  139    const std::vector<py::ssize_t> x_stride = vectorX.get_strides_vector();
 
  140    const std::vector<py::ssize_t> y_stride = vectorY.get_strides_vector();
 
  141    const int x_elemsize = vectorX.get_elemsize();
 
  142    const int y_elemsize = vectorY.get_elemsize();
 
  144    const std::int64_t incx = x_stride[0];
 
  145    const std::int64_t incy = y_stride[0];
 
  154        x_typeless_ptr -= (n - 1) * std::abs(incx) * x_elemsize;
 
  157        y_typeless_ptr -= (n - 1) * std::abs(incy) * y_elemsize;
 
  160    sycl::event dot_ev = dot_fn(exec_q, n, x_typeless_ptr, incx, y_typeless_ptr,
 
  161                                incy, r_typeless_ptr, depends);
 
  163    sycl::event args_ev = dpctl::utils::keep_args_alive(
 
  164        exec_q, {vectorX, vectorY, result}, {dot_ev});
 
  166    return std::make_pair(args_ev, dot_ev);