DPNP C++ backend kernel library 0.18.0rc1
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
validation_utils_internal.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#include <pybind11/numpy.h>
27#include <pybind11/pybind11.h>
28
29#include "ext/common.hpp"
30
31#include "ext/validation_utils.hpp"
32#include "utils/memory_overlap.hpp"
33
34namespace td_ns = dpctl::tensor::type_dispatch;
35namespace common = ext::common;
36
37namespace ext::validation
38{
39inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
40 const std::vector<array_ptr> &outputs)
41{
42 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
43 [](const array_ptr &arr) { return arr != nullptr; });
44
45 if (it != inputs.cend()) {
46 return (*it)->get_queue();
47 }
48
49 it = std::find_if(outputs.cbegin(), outputs.cend(),
50 [](const array_ptr &arr) { return arr != nullptr; });
51
52 if (it != outputs.cend()) {
53 return (*it)->get_queue();
54 }
55
56 throw py::value_error("No input or output arrays found");
57}
58
59inline std::string name_of(const array_ptr &arr, const array_names &names)
60{
61 auto name_it = names.find(arr);
62 assert(name_it != names.end());
63
64 if (name_it != names.end())
65 return "'" + name_it->second + "'";
66
67 return "'unknown'";
68}
69
70inline void check_writable(const std::vector<array_ptr> &arrays,
71 const array_names &names)
72{
73 for (const auto &arr : arrays) {
74 if (arr != nullptr && !arr->is_writable()) {
75 throw py::value_error(name_of(arr, names) +
76 " parameter is not writable");
77 }
78 }
79}
80
81inline void check_c_contig(const std::vector<array_ptr> &arrays,
82 const array_names &names)
83{
84 for (const auto &arr : arrays) {
85 if (arr != nullptr && !arr->is_c_contiguous()) {
86 throw py::value_error(name_of(arr, names) +
87 " parameter is not c-contiguos");
88 }
89 }
90}
91
92inline void check_queue(const std::vector<array_ptr> &arrays,
93 const array_names &names,
94 const sycl::queue &exec_q)
95{
96 auto unequal_queue =
97 std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
98 return arr != nullptr && arr->get_queue() != exec_q;
99 });
100
101 if (unequal_queue != arrays.cend()) {
102 throw py::value_error(
103 name_of(*unequal_queue, names) +
104 " parameter has incompatible queue with other parameters");
105 }
106}
107
108inline void check_no_overlap(const array_ptr &input,
109 const array_ptr &output,
110 const array_names &names)
111{
112 if (input == nullptr || output == nullptr) {
113 return;
114 }
115
116 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
117
118 if (overlap(*input, *output)) {
119 throw py::value_error(name_of(input, names) +
120 " has overlapping memory segments with " +
121 name_of(output, names));
122 }
123}
124
125inline void check_no_overlap(const std::vector<array_ptr> &inputs,
126 const std::vector<array_ptr> &outputs,
127 const array_names &names)
128{
129 for (const auto &input : inputs) {
130 for (const auto &output : outputs) {
131 check_no_overlap(input, output, names);
132 }
133 }
134}
135
136inline void check_num_dims(const array_ptr &arr,
137 const size_t ndim,
138 const array_names &names)
139{
140 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
141 if (arr != nullptr && arr_n_dim != ndim) {
142 throw py::value_error("Array " + name_of(arr, names) + " must be " +
143 std::to_string(ndim) + "D, but got " +
144 std::to_string(arr_n_dim) + "D.");
145 }
146}
147
148inline void check_num_dims(const std::vector<array_ptr> &arrays,
149 const size_t ndim,
150 const array_names &names)
151{
152 for (const auto &arr : arrays) {
153 check_num_dims(arr, ndim, names);
154 }
155}
156
157inline void check_max_dims(const array_ptr &arr,
158 const size_t max_ndim,
159 const array_names &names)
160{
161 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
162 if (arr != nullptr && arr_n_dim > max_ndim) {
163 throw py::value_error(
164 "Array " + name_of(arr, names) + " must have no more than " +
165 std::to_string(max_ndim) + " dimensions, but got " +
166 std::to_string(arr_n_dim) + " dimensions.");
167 }
168}
169
170inline void check_size_at_least(const array_ptr &arr,
171 const size_t size,
172 const array_names &names)
173{
174 size_t arr_size = arr != nullptr ? arr->get_size() : 0;
175 if (arr != nullptr && arr_size < size) {
176 throw py::value_error("Array " + name_of(arr, names) +
177 " must have at least " + std::to_string(size) +
178 " elements, but got " + std::to_string(arr_size) +
179 " elements.");
180 }
181}
182
183inline void check_has_dtype(const array_ptr &arr,
184 const typenum_t dtype,
185 const array_names &names)
186{
187 if (arr == nullptr) {
188 return;
189 }
190
191 auto array_types = td_ns::usm_ndarray_types();
192 int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
193 int expected_type_id = static_cast<int>(dtype);
194
195 if (array_type_id != expected_type_id) {
196 py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
197 py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
198
199 std::string msg = "Array " + name_of(arr, names) + " must have dtype " +
200 std::string(py::str(dtype_py)) + ", but got " +
201 std::string(py::str(actual_dtype));
202
203 throw py::value_error(msg);
204 }
205}
206
207inline void check_same_dtype(const array_ptr &arr1,
208 const array_ptr &arr2,
209 const array_names &names)
210{
211 if (arr1 == nullptr || arr2 == nullptr) {
212 return;
213 }
214
215 auto array_types = td_ns::usm_ndarray_types();
216 int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
217 int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
218
219 if (first_type_id != second_type_id) {
220 py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
221 py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
222
223 std::string msg = "Arrays " + name_of(arr1, names) + " and " +
224 name_of(arr2, names) +
225 " must have the same dtype, but got " +
226 std::string(py::str(first_dtype)) + " and " +
227 std::string(py::str(second_dtype));
228
229 throw py::value_error(msg);
230 }
231}
232
233inline void check_same_dtype(const std::vector<array_ptr> &arrays,
234 const array_names &names)
235{
236 if (arrays.empty()) {
237 return;
238 }
239
240 const auto *first = arrays[0];
241 for (size_t i = 1; i < arrays.size(); ++i) {
242 check_same_dtype(first, arrays[i], names);
243 }
244}
245
246inline void check_same_size(const array_ptr &arr1,
247 const array_ptr &arr2,
248 const array_names &names)
249{
250 if (arr1 == nullptr || arr2 == nullptr) {
251 return;
252 }
253
254 auto size1 = arr1->get_size();
255 auto size2 = arr2->get_size();
256
257 if (size1 != size2) {
258 std::string msg =
259 "Arrays " + name_of(arr1, names) + " and " + name_of(arr2, names) +
260 " must have the same size, but got " + std::to_string(size1) +
261 " and " + std::to_string(size2);
262
263 throw py::value_error(msg);
264 }
265}
266
267inline void check_same_size(const std::vector<array_ptr> &arrays,
268 const array_names &names)
269{
270 if (arrays.empty()) {
271 return;
272 }
273
274 auto first = arrays[0];
275 for (size_t i = 1; i < arrays.size(); ++i) {
276 check_same_size(first, arrays[i], names);
277 }
278}
279
280inline void common_checks(const std::vector<array_ptr> &inputs,
281 const std::vector<array_ptr> &outputs,
282 const array_names &names)
283{
284 check_writable(outputs, names);
285
286 check_c_contig(inputs, names);
287 check_c_contig(outputs, names);
288
289 auto exec_q = get_queue(inputs, outputs);
290
291 check_queue(inputs, names, exec_q);
292 check_queue(outputs, names, exec_q);
293
294 check_no_overlap(inputs, outputs, names);
295}
296
297} // namespace ext::validation