DPNP C++ backend kernel library 0.19.0dev2
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
validation_utils_internal.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#include <pybind11/numpy.h>
27#include <pybind11/pybind11.h>
28
29#include "ext/common.hpp"
30
31#include "ext/validation_utils.hpp"
32#include "utils/memory_overlap.hpp"
33
34namespace td_ns = dpctl::tensor::type_dispatch;
35namespace common = ext::common;
36
37namespace ext::validation
38{
39inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
40 const std::vector<array_ptr> &outputs)
41{
42 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
43 [](const array_ptr &arr) { return arr != nullptr; });
44
45 if (it != inputs.cend()) {
46 return (*it)->get_queue();
47 }
48
49 it = std::find_if(outputs.cbegin(), outputs.cend(),
50 [](const array_ptr &arr) { return arr != nullptr; });
51
52 if (it != outputs.cend()) {
53 return (*it)->get_queue();
54 }
55
56 throw py::value_error("No input or output arrays found");
57}
58
59inline std::string name_of(const array_ptr &arr, const array_names &names)
60{
61 auto name_it = names.find(arr);
62 assert(name_it != names.end());
63
64 if (name_it != names.end())
65 return "'" + name_it->second + "'";
66
67 return "'unknown'";
68}
69
70inline void check_writable(const std::vector<array_ptr> &arrays,
71 const array_names &names)
72{
73 for (const auto &arr : arrays) {
74 if (arr != nullptr && !arr->is_writable()) {
75 throw py::value_error(name_of(arr, names) +
76 " parameter is not writable");
77 }
78 }
79}
80
81inline void check_c_contig(const std::vector<array_ptr> &arrays,
82 const array_names &names)
83{
84 for (const auto &arr : arrays) {
85 if (arr != nullptr && !arr->is_c_contiguous()) {
86 throw py::value_error(name_of(arr, names) +
87 " parameter is not c-contiguos");
88 }
89 }
90}
91
92inline void check_queue(const std::vector<array_ptr> &arrays,
93 const array_names &names,
94 const sycl::queue &exec_q)
95{
96 auto unequal_queue =
97 std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
98 return arr != nullptr && arr->get_queue() != exec_q;
99 });
100
101 if (unequal_queue != arrays.cend()) {
102 throw py::value_error(
103 name_of(*unequal_queue, names) +
104 " parameter has incompatible queue with other parameters");
105 }
106}
107
108inline void check_no_overlap(const array_ptr &input,
109 const array_ptr &output,
110 const array_names &names)
111{
112 if (input == nullptr || output == nullptr) {
113 return;
114 }
115
116 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
117 const auto &same_logical_tensors =
118 dpctl::tensor::overlap::SameLogicalTensors();
119
120 if (overlap(*input, *output) && !same_logical_tensors(*input, *output)) {
121 throw py::value_error(name_of(input, names) +
122 " has overlapping memory segments with " +
123 name_of(output, names));
124 }
125}
126
127inline void check_no_overlap(const std::vector<array_ptr> &inputs,
128 const std::vector<array_ptr> &outputs,
129 const array_names &names)
130{
131 for (const auto &input : inputs) {
132 for (const auto &output : outputs) {
133 check_no_overlap(input, output, names);
134 }
135 }
136}
137
138inline void check_num_dims(const array_ptr &arr,
139 const size_t ndim,
140 const array_names &names)
141{
142 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
143 if (arr != nullptr && arr_n_dim != ndim) {
144 throw py::value_error("Array " + name_of(arr, names) + " must be " +
145 std::to_string(ndim) + "D, but got " +
146 std::to_string(arr_n_dim) + "D.");
147 }
148}
149
150inline void check_num_dims(const std::vector<array_ptr> &arrays,
151 const size_t ndim,
152 const array_names &names)
153{
154 for (const auto &arr : arrays) {
155 check_num_dims(arr, ndim, names);
156 }
157}
158
159inline void check_max_dims(const array_ptr &arr,
160 const size_t max_ndim,
161 const array_names &names)
162{
163 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
164 if (arr != nullptr && arr_n_dim > max_ndim) {
165 throw py::value_error(
166 "Array " + name_of(arr, names) + " must have no more than " +
167 std::to_string(max_ndim) + " dimensions, but got " +
168 std::to_string(arr_n_dim) + " dimensions.");
169 }
170}
171
172inline void check_size_at_least(const array_ptr &arr,
173 const size_t size,
174 const array_names &names)
175{
176 size_t arr_size = arr != nullptr ? arr->get_size() : 0;
177 if (arr != nullptr && arr_size < size) {
178 throw py::value_error("Array " + name_of(arr, names) +
179 " must have at least " + std::to_string(size) +
180 " elements, but got " + std::to_string(arr_size) +
181 " elements.");
182 }
183}
184
185inline void check_has_dtype(const array_ptr &arr,
186 const typenum_t dtype,
187 const array_names &names)
188{
189 if (arr == nullptr) {
190 return;
191 }
192
193 auto array_types = td_ns::usm_ndarray_types();
194 int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
195 int expected_type_id = static_cast<int>(dtype);
196
197 if (array_type_id != expected_type_id) {
198 py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
199 py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
200
201 std::string msg = "Array " + name_of(arr, names) + " must have dtype " +
202 std::string(py::str(dtype_py)) + ", but got " +
203 std::string(py::str(actual_dtype));
204
205 throw py::value_error(msg);
206 }
207}
208
209inline void check_same_dtype(const array_ptr &arr1,
210 const array_ptr &arr2,
211 const array_names &names)
212{
213 if (arr1 == nullptr || arr2 == nullptr) {
214 return;
215 }
216
217 auto array_types = td_ns::usm_ndarray_types();
218 int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
219 int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
220
221 if (first_type_id != second_type_id) {
222 py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
223 py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
224
225 std::string msg = "Arrays " + name_of(arr1, names) + " and " +
226 name_of(arr2, names) +
227 " must have the same dtype, but got " +
228 std::string(py::str(first_dtype)) + " and " +
229 std::string(py::str(second_dtype));
230
231 throw py::value_error(msg);
232 }
233}
234
235inline void check_same_dtype(const std::vector<array_ptr> &arrays,
236 const array_names &names)
237{
238 if (arrays.empty()) {
239 return;
240 }
241
242 const auto *first = arrays[0];
243 for (size_t i = 1; i < arrays.size(); ++i) {
244 check_same_dtype(first, arrays[i], names);
245 }
246}
247
248inline void check_same_size(const array_ptr &arr1,
249 const array_ptr &arr2,
250 const array_names &names)
251{
252 if (arr1 == nullptr || arr2 == nullptr) {
253 return;
254 }
255
256 auto size1 = arr1->get_size();
257 auto size2 = arr2->get_size();
258
259 if (size1 != size2) {
260 std::string msg =
261 "Arrays " + name_of(arr1, names) + " and " + name_of(arr2, names) +
262 " must have the same size, but got " + std::to_string(size1) +
263 " and " + std::to_string(size2);
264
265 throw py::value_error(msg);
266 }
267}
268
269inline void check_same_size(const std::vector<array_ptr> &arrays,
270 const array_names &names)
271{
272 if (arrays.empty()) {
273 return;
274 }
275
276 auto first = arrays[0];
277 for (size_t i = 1; i < arrays.size(); ++i) {
278 check_same_size(first, arrays[i], names);
279 }
280}
281
282inline void common_checks(const std::vector<array_ptr> &inputs,
283 const std::vector<array_ptr> &outputs,
284 const array_names &names)
285{
286 check_writable(outputs, names);
287
288 check_c_contig(inputs, names);
289 check_c_contig(outputs, names);
290
291 auto exec_q = get_queue(inputs, outputs);
292
293 check_queue(inputs, names, exec_q);
294 check_queue(outputs, names, exec_q);
295
296 check_no_overlap(inputs, outputs, names);
297}
298
299} // namespace ext::validation