28#include <pybind11/pybind11.h> 
   31#include "utils/memory_overlap.hpp" 
   32#include "utils/output_validation.hpp" 
   33#include "utils/sycl_alloc_utils.hpp" 
   34#include "utils/type_dispatch.hpp" 
   36#include "common_helpers.hpp" 
   37#include "linalg_exceptions.hpp" 
   39namespace dpnp::extensions::lapack::gesv_utils
 
   41namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
 
   42namespace py = pybind11;
 
   44inline void common_gesv_checks(sycl::queue &exec_q,
 
   45                               const dpctl::tensor::usm_ndarray &coeff_matrix,
 
   46                               const dpctl::tensor::usm_ndarray &dependent_vals,
 
   47                               const py::ssize_t *coeff_matrix_shape,
 
   48                               const py::ssize_t *dependent_vals_shape,
 
   49                               const int expected_coeff_matrix_ndim,
 
   50                               const int min_dependent_vals_ndim,
 
   51                               const int max_dependent_vals_ndim)
 
   53    const int coeff_matrix_nd = coeff_matrix.get_ndim();
 
   54    const int dependent_vals_nd = dependent_vals.get_ndim();
 
   56    if (coeff_matrix_nd != expected_coeff_matrix_ndim) {
 
   57        throw py::value_error(
"The coefficient matrix has ndim=" +
 
   58                              std::to_string(coeff_matrix_nd) + 
", but a " +
 
   59                              std::to_string(expected_coeff_matrix_ndim) +
 
   60                              "-dimensional array is expected.");
 
   63    if (dependent_vals_nd < min_dependent_vals_ndim ||
 
   64        dependent_vals_nd > max_dependent_vals_ndim)
 
   66        throw py::value_error(
"The dependent values array has ndim=" +
 
   67                              std::to_string(dependent_vals_nd) + 
", but a " +
 
   68                              std::to_string(min_dependent_vals_ndim) +
 
   69                              "-dimensional or a " +
 
   70                              std::to_string(max_dependent_vals_ndim) +
 
   71                              "-dimensional array is expected.");
 
   80    if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) {
 
   81        throw py::value_error(
"The coefficient matrix must be square," 
   82                              " but got a shape of (" +
 
   83                              std::to_string(coeff_matrix_shape[0]) + 
", " +
 
   84                              std::to_string(coeff_matrix_shape[1]) + 
").");
 
   86    if (coeff_matrix_shape[0] != dependent_vals_shape[0]) {
 
   87        throw py::value_error(
"The first dimension (n) of coeff_matrix and" 
   88                              " dependent_vals must be the same, but got " +
 
   89                              std::to_string(coeff_matrix_shape[0]) + 
" and " +
 
   90                              std::to_string(dependent_vals_shape[0]) + 
".");
 
   94    if (!dpctl::utils::queues_are_compatible(exec_q,
 
   95                                             {coeff_matrix, dependent_vals}))
 
   97        throw py::value_error(
 
   98            "Execution queue is not compatible with allocation queues.");
 
  101    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
 
  102    if (overlap(coeff_matrix, dependent_vals)) {
 
  103        throw py::value_error(
 
  104            "The arrays of coefficients and dependent variables " 
  105            "are overlapping segments of memory.");
 
  108    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(
 
  111    const bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous();
 
  112    if (!is_coeff_matrix_f_contig) {
 
  113        throw py::value_error(
"The coefficient matrix " 
  114                              "must be F-contiguous.");
 
  117    const bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous();
 
  118    if (!is_dependent_vals_f_contig) {
 
  119        throw py::value_error(
"The array of dependent variables " 
  120                              "must be F-contiguous.");
 
  123    auto array_types = dpctl_td_ns::usm_ndarray_types();
 
  124    const int coeff_matrix_type_id =
 
  125        array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
 
  126    const int dependent_vals_type_id =
 
  127        array_types.typenum_to_lookup_id(dependent_vals.get_typenum());
 
  129    if (coeff_matrix_type_id != dependent_vals_type_id) {
 
  130        throw py::value_error(
"The types of the coefficient matrix and " 
  131                              "dependent variables are mismatched.");
 
  136inline void handle_lapack_exc(sycl::queue &exec_q,
 
  137                              const std::int64_t lda,
 
  139                              std::int64_t scratchpad_size,
 
  142                              const oneapi::mkl::lapack::exception &e,
 
  143                              std::stringstream &error_msg)
 
  145    std::int64_t info = e.info();
 
  147        error_msg << 
"Parameter number " << -info << 
" had an illegal value.";
 
  149    else if (info == scratchpad_size && e.detail() != 0) {
 
  150        error_msg << 
"Insufficient scratchpad size. Required size is at least " 
  155        exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], 
sizeof(T))
 
  158        using ThresholdType = 
typename helper::value_type_of<T>::type;
 
  160        const auto threshold =
 
  161            std::numeric_limits<ThresholdType>::epsilon() * 100;
 
  162        if (std::abs(host_U) < threshold) {
 
  163            using dpctl::tensor::alloc_utils::sycl_free_noexcept;
 
  165            if (scratchpad != 
nullptr)
 
  166                sycl_free_noexcept(scratchpad, exec_q);
 
  168                sycl_free_noexcept(ipiv, exec_q);
 
  169            throw LinAlgError(
"The input coefficient matrix is singular.");
 
  172            error_msg << 
"Unexpected MKL exception caught during gesv() " 
  174                      << e.what() << 
"\ninfo: " << e.info();
 
  179            << 
"Unexpected MKL exception caught during gesv() call:\nreason: " 
  180            << e.what() << 
"\ninfo: " << e.info();