DPNP C++ backend kernel library 0.18.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
evd_batch_common.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28#include <oneapi/mkl.hpp>
29#include <pybind11/pybind11.h>
30
31// dpctl tensor headers
32#include "utils/type_dispatch.hpp"
33
34#include "common_helpers.hpp"
35#include "evd_common_utils.hpp"
36#include "types_matrix.hpp"
37
38namespace dpnp::extensions::lapack::evd
39{
40typedef sycl::event (*evd_batch_impl_fn_ptr_t)(
41 sycl::queue &,
42 const oneapi::mkl::job,
43 const oneapi::mkl::uplo,
44 const std::int64_t,
45 const std::int64_t,
46 char *,
47 char *,
48 const std::vector<sycl::event> &);
49
50namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
51namespace py = pybind11;
52
53template <typename dispatchT>
54std::pair<sycl::event, sycl::event>
55 evd_batch_func(sycl::queue &exec_q,
56 const std::int8_t jobz,
57 const std::int8_t upper_lower,
58 const dpctl::tensor::usm_ndarray &eig_vecs,
59 const dpctl::tensor::usm_ndarray &eig_vals,
60 const std::vector<sycl::event> &depends,
61 const dispatchT &evd_batch_dispatch_table)
62{
63 const int eig_vecs_nd = eig_vecs.get_ndim();
64
65 const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
66 const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
67
68 constexpr int expected_eig_vecs_nd = 3;
69 constexpr int expected_eig_vals_nd = 2;
70
71 common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
72 expected_eig_vecs_nd, expected_eig_vals_nd);
73
74 if (eig_vecs_shape[2] != eig_vals_shape[0] ||
75 eig_vecs_shape[0] != eig_vals_shape[1])
76 {
77 throw py::value_error(
78 "The shape of 'eig_vals' must be (batch_size, n), "
79 "where batch_size = " +
80 std::to_string(eig_vecs_shape[0]) +
81 " and n = " + std::to_string(eig_vecs_shape[1]));
82 }
83
84 // Ensure `batch_size` and `n` are non-zero, otherwise return empty events
85 if (helper::check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
86 // nothing to do
87 return std::make_pair(sycl::event(), sycl::event());
88 }
89
90 auto array_types = dpctl_td_ns::usm_ndarray_types();
91 const int eig_vecs_type_id =
92 array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
93 const int eig_vals_type_id =
94 array_types.typenum_to_lookup_id(eig_vals.get_typenum());
95
96 evd_batch_impl_fn_ptr_t evd_batch_fn =
97 evd_batch_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
98 if (evd_batch_fn == nullptr) {
99 throw py::value_error(
100 "Types of input vectors and result array are mismatched.");
101 }
102
103 char *eig_vecs_data = eig_vecs.get_data();
104 char *eig_vals_data = eig_vals.get_data();
105
106 const std::int64_t batch_size = eig_vecs_shape[2];
107 const std::int64_t n = eig_vecs_shape[1];
108
109 const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
110 const oneapi::mkl::uplo uplo_val =
111 static_cast<oneapi::mkl::uplo>(upper_lower);
112
113 sycl::event evd_batch_ev =
114 evd_batch_fn(exec_q, jobz_val, uplo_val, batch_size, n, eig_vecs_data,
115 eig_vals_data, depends);
116
117 sycl::event ht_ev = dpctl::utils::keep_args_alive(
118 exec_q, {eig_vecs, eig_vals}, {evd_batch_ev});
119
120 return std::make_pair(ht_ev, evd_batch_ev);
121}
122} // namespace dpnp::extensions::lapack::evd