31#include <oneapi/mkl.hpp>
32#include <pybind11/pybind11.h>
35#include "utils/type_dispatch.hpp"
37#include "common_helpers.hpp"
38#include "evd_common_utils.hpp"
39#include "types_matrix.hpp"
41namespace dpnp::extensions::lapack::evd
43using dpnp::extensions::lapack::helper::check_zeros_shape;
45typedef sycl::event (*evd_impl_fn_ptr_t)(sycl::queue &,
46 const oneapi::mkl::job,
47 const oneapi::mkl::uplo,
51 const std::vector<sycl::event> &);
53namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
54namespace py = pybind11;
56template <
typename dispatchT>
57std::pair<sycl::event, sycl::event>
58 evd_func(sycl::queue &exec_q,
59 const std::int8_t jobz,
60 const std::int8_t upper_lower,
61 const dpctl::tensor::usm_ndarray &eig_vecs,
62 const dpctl::tensor::usm_ndarray &eig_vals,
63 const std::vector<sycl::event> &depends,
64 const dispatchT &evd_dispatch_table)
66 const int eig_vecs_nd = eig_vecs.get_ndim();
68 const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
69 const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
71 constexpr int expected_eig_vecs_nd = 2;
72 constexpr int expected_eig_vals_nd = 1;
74 common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
75 expected_eig_vecs_nd, expected_eig_vals_nd);
77 if (eig_vecs_shape[0] != eig_vals_shape[0]) {
78 throw py::value_error(
79 "Eigenvectors and eigenvalues have different shapes");
82 if (check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
84 return std::make_pair(sycl::event(), sycl::event());
87 auto array_types = dpctl_td_ns::usm_ndarray_types();
88 const int eig_vecs_type_id =
89 array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
90 const int eig_vals_type_id =
91 array_types.typenum_to_lookup_id(eig_vals.get_typenum());
93 evd_impl_fn_ptr_t evd_fn =
94 evd_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
95 if (evd_fn ==
nullptr) {
96 throw py::value_error(
97 "No evd implementation is available for the specified data type "
98 "of the input and output arrays.");
101 char *eig_vecs_data = eig_vecs.get_data();
102 char *eig_vals_data = eig_vals.get_data();
104 const std::int64_t n = eig_vecs_shape[0];
105 const oneapi::mkl::job jobz_val =
static_cast<oneapi::mkl::job
>(jobz);
106 const oneapi::mkl::uplo uplo_val =
107 static_cast<oneapi::mkl::uplo
>(upper_lower);
109 sycl::event evd_ev = evd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data,
110 eig_vals_data, depends);
113 dpctl::utils::keep_args_alive(exec_q, {eig_vecs, eig_vals}, {evd_ev});
115 return std::make_pair(ht_ev, evd_ev);