28#include <oneapi/mkl.hpp> 
   29#include <pybind11/pybind11.h> 
   32#include "utils/type_dispatch.hpp" 
   34#include "common_helpers.hpp" 
   35#include "evd_common_utils.hpp" 
   36#include "types_matrix.hpp" 
   38namespace dpnp::extensions::lapack::evd
 
   40typedef sycl::event (*evd_batch_impl_fn_ptr_t)(
 
   42    const oneapi::mkl::job,
 
   43    const oneapi::mkl::uplo,
 
   48    const std::vector<sycl::event> &);
 
   50namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
 
   51namespace py = pybind11;
 
   53template <
typename dispatchT>
 
   54std::pair<sycl::event, sycl::event>
 
   55    evd_batch_func(sycl::queue &exec_q,
 
   56                   const std::int8_t jobz,
 
   57                   const std::int8_t upper_lower,
 
   58                   const dpctl::tensor::usm_ndarray &eig_vecs,
 
   59                   const dpctl::tensor::usm_ndarray &eig_vals,
 
   60                   const std::vector<sycl::event> &depends,
 
   61                   const dispatchT &evd_batch_dispatch_table)
 
   63    const int eig_vecs_nd = eig_vecs.get_ndim();
 
   65    const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
 
   66    const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
 
   68    constexpr int expected_eig_vecs_nd = 3;
 
   69    constexpr int expected_eig_vals_nd = 2;
 
   71    common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
 
   72                      expected_eig_vecs_nd, expected_eig_vals_nd);
 
   74    if (eig_vecs_shape[2] != eig_vals_shape[0] ||
 
   75        eig_vecs_shape[0] != eig_vals_shape[1])
 
   77        throw py::value_error(
 
   78            "The shape of 'eig_vals' must be (batch_size, n), " 
   79            "where batch_size = " +
 
   80            std::to_string(eig_vecs_shape[0]) +
 
   81            " and n = " + std::to_string(eig_vecs_shape[1]));
 
   85    if (helper::check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
 
   87        return std::make_pair(sycl::event(), sycl::event());
 
   90    auto array_types = dpctl_td_ns::usm_ndarray_types();
 
   91    const int eig_vecs_type_id =
 
   92        array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
 
   93    const int eig_vals_type_id =
 
   94        array_types.typenum_to_lookup_id(eig_vals.get_typenum());
 
   96    evd_batch_impl_fn_ptr_t evd_batch_fn =
 
   97        evd_batch_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
 
   98    if (evd_batch_fn == 
nullptr) {
 
   99        throw py::value_error(
 
  100            "No evd_batch implementation is available for the specified data " 
  101            "type of the input and output arrays.");
 
  104    char *eig_vecs_data = eig_vecs.get_data();
 
  105    char *eig_vals_data = eig_vals.get_data();
 
  107    const std::int64_t batch_size = eig_vecs_shape[2];
 
  108    const std::int64_t n = eig_vecs_shape[1];
 
  110    const oneapi::mkl::job jobz_val = 
static_cast<oneapi::mkl::job
>(jobz);
 
  111    const oneapi::mkl::uplo uplo_val =
 
  112        static_cast<oneapi::mkl::uplo
>(upper_lower);
 
  114    sycl::event evd_batch_ev =
 
  115        evd_batch_fn(exec_q, jobz_val, uplo_val, batch_size, n, eig_vecs_data,
 
  116                     eig_vals_data, depends);
 
  118    sycl::event ht_ev = dpctl::utils::keep_args_alive(
 
  119        exec_q, {eig_vecs, eig_vals}, {evd_batch_ev});
 
  121    return std::make_pair(ht_ev, evd_batch_ev);