DPNP C++ backend kernel library 0.21.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
evd_common.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <oneapi/mkl.hpp>
32#include <pybind11/pybind11.h>
33
34#include "dpnp4pybind11.hpp"
35
36// dpnp tensor headers
37#include "utils/type_dispatch.hpp"
38
39#include "common_helpers.hpp"
40#include "evd_common_utils.hpp"
41#include "types_matrix.hpp"
42
43namespace dpnp::extensions::lapack::evd
44{
45using dpnp::extensions::lapack::helper::check_zeros_shape;
46
47typedef sycl::event (*evd_impl_fn_ptr_t)(sycl::queue &,
48 const oneapi::mkl::job,
49 const oneapi::mkl::uplo,
50 const std::int64_t,
51 char *,
52 char *,
53 const std::vector<sycl::event> &);
54
55namespace dpnp_td_ns = dpnp::tensor::type_dispatch;
56namespace py = pybind11;
57
58template <typename dispatchT>
59std::pair<sycl::event, sycl::event>
60 evd_func(sycl::queue &exec_q,
61 const std::int8_t jobz,
62 const std::int8_t upper_lower,
63 const dpnp::tensor::usm_ndarray &eig_vecs,
64 const dpnp::tensor::usm_ndarray &eig_vals,
65 const std::vector<sycl::event> &depends,
66 const dispatchT &evd_dispatch_table)
67{
68 const int eig_vecs_nd = eig_vecs.get_ndim();
69
70 const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
71 const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
72
73 constexpr int expected_eig_vecs_nd = 2;
74 constexpr int expected_eig_vals_nd = 1;
75
76 common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
77 expected_eig_vecs_nd, expected_eig_vals_nd);
78
79 if (eig_vecs_shape[0] != eig_vals_shape[0]) {
80 throw py::value_error(
81 "Eigenvectors and eigenvalues have different shapes");
82 }
83
84 if (check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
85 // nothing to do
86 return std::make_pair(sycl::event(), sycl::event());
87 }
88
89 auto array_types = dpnp_td_ns::usm_ndarray_types();
90 const int eig_vecs_type_id =
91 array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
92 const int eig_vals_type_id =
93 array_types.typenum_to_lookup_id(eig_vals.get_typenum());
94
95 evd_impl_fn_ptr_t evd_fn =
96 evd_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
97 if (evd_fn == nullptr) {
98 throw py::value_error(
99 "No evd implementation is available for the specified data type "
100 "of the input and output arrays.");
101 }
102
103 char *eig_vecs_data = eig_vecs.get_data();
104 char *eig_vals_data = eig_vals.get_data();
105
106 const std::int64_t n = eig_vecs_shape[0];
107 const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
108 const oneapi::mkl::uplo uplo_val =
109 static_cast<oneapi::mkl::uplo>(upper_lower);
110
111 sycl::event evd_ev = evd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data,
112 eig_vals_data, depends);
113
114 sycl::event ht_ev =
115 dpnp::utils::keep_args_alive(exec_q, {eig_vecs, eig_vals}, {evd_ev});
116
117 return std::make_pair(ht_ev, evd_ev);
118}
119} // namespace dpnp::extensions::lapack::evd