31#include <oneapi/mkl.hpp>
32#include <pybind11/pybind11.h>
34#include "dpnp4pybind11.hpp"
37#include "utils/type_dispatch.hpp"
39#include "common_helpers.hpp"
40#include "evd_common_utils.hpp"
41#include "types_matrix.hpp"
43namespace dpnp::extensions::lapack::evd
45typedef sycl::event (*evd_batch_impl_fn_ptr_t)(
47 const oneapi::mkl::job,
48 const oneapi::mkl::uplo,
53 const std::vector<sycl::event> &);
55namespace dpnp_td_ns = dpnp::tensor::type_dispatch;
56namespace py = pybind11;
58template <
typename dispatchT>
59std::pair<sycl::event, sycl::event>
60 evd_batch_func(sycl::queue &exec_q,
61 const std::int8_t jobz,
62 const std::int8_t upper_lower,
63 const dpnp::tensor::usm_ndarray &eig_vecs,
64 const dpnp::tensor::usm_ndarray &eig_vals,
65 const std::vector<sycl::event> &depends,
66 const dispatchT &evd_batch_dispatch_table)
68 const int eig_vecs_nd = eig_vecs.get_ndim();
70 const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
71 const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
73 constexpr int expected_eig_vecs_nd = 3;
74 constexpr int expected_eig_vals_nd = 2;
76 common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
77 expected_eig_vecs_nd, expected_eig_vals_nd);
79 if (eig_vecs_shape[2] != eig_vals_shape[0] ||
80 eig_vecs_shape[0] != eig_vals_shape[1]) {
81 throw py::value_error(
82 "The shape of 'eig_vals' must be (batch_size, n), "
83 "where batch_size = " +
84 std::to_string(eig_vecs_shape[0]) +
85 " and n = " + std::to_string(eig_vecs_shape[1]));
89 if (helper::check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
91 return std::make_pair(sycl::event(), sycl::event());
94 auto array_types = dpnp_td_ns::usm_ndarray_types();
95 const int eig_vecs_type_id =
96 array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
97 const int eig_vals_type_id =
98 array_types.typenum_to_lookup_id(eig_vals.get_typenum());
100 evd_batch_impl_fn_ptr_t evd_batch_fn =
101 evd_batch_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
102 if (evd_batch_fn ==
nullptr) {
103 throw py::value_error(
104 "No evd_batch implementation is available for the specified data "
105 "type of the input and output arrays.");
108 char *eig_vecs_data = eig_vecs.get_data();
109 char *eig_vals_data = eig_vals.get_data();
111 const std::int64_t batch_size = eig_vecs_shape[2];
112 const std::int64_t n = eig_vecs_shape[1];
114 const oneapi::mkl::job jobz_val =
static_cast<oneapi::mkl::job
>(jobz);
115 const oneapi::mkl::uplo uplo_val =
116 static_cast<oneapi::mkl::uplo
>(upper_lower);
118 sycl::event evd_batch_ev =
119 evd_batch_fn(exec_q, jobz_val, uplo_val, batch_size, n, eig_vecs_data,
120 eig_vals_data, depends);
122 sycl::event ht_ev = dpnp::utils::keep_args_alive(
123 exec_q, {eig_vecs, eig_vals}, {evd_batch_ev});
125 return std::make_pair(ht_ev, evd_batch_ev);