27#include <oneapi/mkl.hpp>
28#include <pybind11/pybind11.h>
31#include "utils/memory_overlap.hpp"
32#include "utils/output_validation.hpp"
33#include "utils/type_dispatch.hpp"
35#include "common_helpers.hpp"
37namespace dpnp::extensions::lapack::gesvd_utils
// Short aliases used by the helpers below: dpctl_td_ns names dpctl's tensor
// type-dispatch namespace and py names pybind11.
39namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
40namespace py = pybind11;
// Converts an 8-bit job selector into a oneapi::mkl::jobsvd enumerator.
//
// The visible statements return jobsvd::vectors, jobsvd::somevec,
// jobsvd::vectorsina and jobsvd::novec, and throw
// std::invalid_argument("Unknown value for job").
//
// NOTE(review): the lines that select between these statements (and the
// enclosing braces) are not present in this view, so the exact
// job_val -> enumerator mapping cannot be confirmed from here. Presumably it
// follows the LAPACK character convention ('A', 'S', 'O', 'N') with the throw
// as the fallback -- verify against the complete file.
44inline oneapi::mkl::jobsvd process_job(
const std::int8_t job_val)
48 return oneapi::mkl::jobsvd::vectors;
50 return oneapi::mkl::jobsvd::somevec;
52 return oneapi::mkl::jobsvd::vectorsina;
54 return oneapi::mkl::jobsvd::novec;
56 throw std::invalid_argument(
"Unknown value for job");
60inline void common_gesvd_checks(sycl::queue &exec_q,
61 const dpctl::tensor::usm_ndarray &a_array,
62 const dpctl::tensor::usm_ndarray &out_s,
63 const dpctl::tensor::usm_ndarray &out_u,
64 const dpctl::tensor::usm_ndarray &out_vt,
65 const std::int8_t jobu_val,
66 const std::int8_t jobvt_val,
67 const int expected_a_u_vt_ndim,
68 const int expected_s_ndim)
70 const int a_array_nd = a_array.get_ndim();
71 const int out_u_array_nd = out_u.get_ndim();
72 const int out_s_array_nd = out_s.get_ndim();
73 const int out_vt_array_nd = out_vt.get_ndim();
75 if (a_array_nd != expected_a_u_vt_ndim) {
76 throw py::value_error(
77 "The input array has ndim=" + std::to_string(a_array_nd) +
78 ", but a " + std::to_string(expected_a_u_vt_ndim) +
79 "-dimensional array is expected.");
82 if (out_s_array_nd != expected_s_ndim) {
83 throw py::value_error(
"The output array of singular values has ndim=" +
84 std::to_string(out_s_array_nd) +
", but a " +
85 std::to_string(expected_s_ndim) +
86 "-dimensional array is expected.");
89 if (jobu_val ==
'N' && jobvt_val ==
'N') {
90 if (out_u_array_nd != 0) {
91 throw py::value_error(
92 "The output array of the left singular vectors has ndim=" +
93 std::to_string(out_u_array_nd) +
94 ", but it is not used and should have ndim=0.");
96 if (out_vt_array_nd != 0) {
97 throw py::value_error(
98 "The output array of the right singular vectors has ndim=" +
99 std::to_string(out_vt_array_nd) +
100 ", but it is not used and should have ndim=0.");
104 if (out_u_array_nd != expected_a_u_vt_ndim) {
105 throw py::value_error(
106 "The output array of the left singular vectors has ndim=" +
107 std::to_string(out_u_array_nd) +
", but a " +
108 std::to_string(expected_a_u_vt_ndim) +
109 "-dimensional array is expected.");
111 if (out_vt_array_nd != expected_a_u_vt_ndim) {
112 throw py::value_error(
113 "The output array of the right singular vectors has ndim=" +
114 std::to_string(out_vt_array_nd) +
", but a " +
115 std::to_string(expected_a_u_vt_ndim) +
116 "-dimensional array is expected.");
121 if (!dpctl::utils::queues_are_compatible(exec_q,
122 {a_array, out_s, out_u, out_vt}))
124 throw py::value_error(
125 "Execution queue is not compatible with allocation queues.");
128 auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
129 if (overlap(a_array, out_s) || overlap(a_array, out_u) ||
130 overlap(a_array, out_vt) || overlap(out_s, out_u) ||
131 overlap(out_s, out_vt) || overlap(out_u, out_vt))
133 throw py::value_error(
"Arrays have overlapping segments of memory");
136 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(a_array);
137 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out_s);
138 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out_u);
139 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out_vt);
141 const bool is_a_array_f_contig = a_array.is_f_contiguous();
142 if (!is_a_array_f_contig) {
143 throw py::value_error(
"The input array must be F-contiguous");
146 const bool is_out_u_array_f_contig = out_u.is_f_contiguous();
147 const bool is_out_vt_array_f_contig = out_vt.is_f_contiguous();
149 if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
150 throw py::value_error(
"The output arrays of the left and right "
151 "singular vectors must be F-contiguous");
154 const bool is_out_s_array_c_contig = out_s.is_c_contiguous();
156 if (!is_out_s_array_c_contig) {
157 throw py::value_error(
"The output array of singular values "
158 "must be C-contiguous");
161 auto array_types = dpctl_td_ns::usm_ndarray_types();
162 const int a_array_type_id =
163 array_types.typenum_to_lookup_id(a_array.get_typenum());
164 const int out_u_type_id =
165 array_types.typenum_to_lookup_id(out_u.get_typenum());
166 const int out_vt_type_id =
167 array_types.typenum_to_lookup_id(out_vt.get_typenum());
169 if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
170 throw py::type_error(
171 "Input array, output left singular vectors array, "
172 "and outpuy right singular vectors array must have "
173 "the same data type");
// Computes a boolean from the shapes of the gesvd arrays by delegating to
// helper::check_zeros_shape(ndim, shape) and returns it.
//
// The flag is initialised from the input array. When both jobu_val and
// jobvt_val equal 'N', the result for out_vt is OR-ed into it.
//
// NOTE(review): presumably helper::check_zeros_shape reports whether a shape
// contains a zero-length dimension -- its definition is not visible here.
// NOTE(review): the lines between the 'N'/'N' branch and the trailing `||`
// chain are missing from this view; that chain (out_u, out_s, out_vt) is
// presumably assigned to is_zeros_shape in an else branch -- verify against
// the complete file.
// NOTE(review): in the 'N'/'N' case out_vt is tested and out_s is not --
// confirm this is intended.
178inline bool check_zeros_shape_gesvd(
const dpctl::tensor::usm_ndarray &a_array,
179 const dpctl::tensor::usm_ndarray &out_s,
180 const dpctl::tensor::usm_ndarray &out_u,
181 const dpctl::tensor::usm_ndarray &out_vt,
182 const std::int8_t jobu_val,
183 const std::int8_t jobvt_val)
186 const int a_array_nd = a_array.get_ndim();
187 const int out_u_array_nd = out_u.get_ndim();
188 const int out_s_array_nd = out_s.get_ndim();
189 const int out_vt_array_nd = out_vt.get_ndim();
191 const py::ssize_t *a_array_shape = a_array.get_shape_raw();
192 const py::ssize_t *s_out_shape = out_s.get_shape_raw();
193 const py::ssize_t *u_out_shape = out_u.get_shape_raw();
194 const py::ssize_t *vt_out_shape = out_vt.get_shape_raw();
196 bool is_zeros_shape = helper::check_zeros_shape(a_array_nd, a_array_shape);
197 if (jobu_val ==
'N' && jobvt_val ==
'N') {
198 is_zeros_shape = is_zeros_shape || helper::check_zeros_shape(
199 out_vt_array_nd, vt_out_shape);
204 helper::check_zeros_shape(out_u_array_nd, u_out_shape) ||
205 helper::check_zeros_shape(out_s_array_nd, s_out_shape) ||
206 helper::check_zeros_shape(out_vt_array_nd, vt_out_shape);
209 return is_zeros_shape;
// Appends a human-readable description of a oneMKL LAPACK exception to
// error_msg, chosen from e.info() (and e.detail() for the scratchpad case).
//
// Visible messages:
//  - "Parameter number <-info> had an illegal value."
//  - "Insufficient scratchpad size. ..." when info == scratchpad_size and
//    e.detail() != 0
//  - a "failed to converge" message that reports info
//  - a generic message with e.what() and e.info()
//
// NOTE(review): the conditions guarding the first, third and last messages,
// and the continuation of the scratchpad message, are missing from this view;
// presumably they are `info < 0`, `info > 0` and a final else, with
// e.detail() streamed as the required size -- verify against the complete
// file.
// NOTE(review): the generic message names "gesv()" although these are gesvd
// helpers -- likely a typo in the message; confirm and fix in code.
212inline void handle_lapack_exc(
const std::int64_t scratchpad_size,
213 const oneapi::mkl::lapack::exception &e,
214 std::stringstream &error_msg)
216 const std::int64_t info = e.info();
218 error_msg <<
"Parameter number " << -info <<
" had an illegal value.";
220 else if (info == scratchpad_size && e.detail() != 0) {
221 error_msg <<
"Insufficient scratchpad size. Required size is at least "
225 error_msg <<
"The algorithm computing SVD failed to converge; " << info
226 <<
" off-diagonal elements of an intermediate "
227 <<
"bidiagonal form did not converge to zero.\n";
231 <<
"Unexpected MKL exception caught during gesv() call:\nreason: "
232 << e.what() <<
"\ninfo: " << e.info();