DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
validation_utils_internal.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#include <pybind11/numpy.h>
30#include <pybind11/pybind11.h>
31
32#include "ext/common.hpp"
33
34#include "ext/validation_utils.hpp"
35#include "utils/memory_overlap.hpp"
36
37namespace td_ns = dpctl::tensor::type_dispatch;
38namespace common = ext::common;
39
40namespace ext::validation
41{
42inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
43 const std::vector<array_ptr> &outputs)
44{
45 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
46 [](const array_ptr &arr) { return arr != nullptr; });
47
48 if (it != inputs.cend()) {
49 return (*it)->get_queue();
50 }
51
52 it = std::find_if(outputs.cbegin(), outputs.cend(),
53 [](const array_ptr &arr) { return arr != nullptr; });
54
55 if (it != outputs.cend()) {
56 return (*it)->get_queue();
57 }
58
59 throw py::value_error("No input or output arrays found");
60}
61
62inline std::string name_of(const array_ptr &arr, const array_names &names)
63{
64 auto name_it = names.find(arr);
65 assert(name_it != names.end());
66
67 if (name_it != names.end())
68 return "'" + name_it->second + "'";
69
70 return "'unknown'";
71}
72
73inline void check_writable(const std::vector<array_ptr> &arrays,
74 const array_names &names)
75{
76 for (const auto &arr : arrays) {
77 if (arr != nullptr && !arr->is_writable()) {
78 throw py::value_error(name_of(arr, names) +
79 " parameter is not writable");
80 }
81 }
82}
83
84inline void check_c_contig(const std::vector<array_ptr> &arrays,
85 const array_names &names)
86{
87 for (const auto &arr : arrays) {
88 if (arr != nullptr && !arr->is_c_contiguous()) {
89 throw py::value_error(name_of(arr, names) +
90 " parameter is not c-contiguos");
91 }
92 }
93}
94
95inline void check_queue(const std::vector<array_ptr> &arrays,
96 const array_names &names,
97 const sycl::queue &exec_q)
98{
99 auto unequal_queue =
100 std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
101 return arr != nullptr && arr->get_queue() != exec_q;
102 });
103
104 if (unequal_queue != arrays.cend()) {
105 throw py::value_error(
106 name_of(*unequal_queue, names) +
107 " parameter has incompatible queue with other parameters");
108 }
109}
110
111inline void check_no_overlap(const array_ptr &input,
112 const array_ptr &output,
113 const array_names &names)
114{
115 if (input == nullptr || output == nullptr) {
116 return;
117 }
118
119 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
120 const auto &same_logical_tensors =
121 dpctl::tensor::overlap::SameLogicalTensors();
122
123 if (overlap(*input, *output) && !same_logical_tensors(*input, *output)) {
124 throw py::value_error(name_of(input, names) +
125 " has overlapping memory segments with " +
126 name_of(output, names));
127 }
128}
129
130inline void check_no_overlap(const std::vector<array_ptr> &inputs,
131 const std::vector<array_ptr> &outputs,
132 const array_names &names)
133{
134 for (const auto &input : inputs) {
135 for (const auto &output : outputs) {
136 check_no_overlap(input, output, names);
137 }
138 }
139}
140
141inline void check_num_dims(const array_ptr &arr,
142 const size_t ndim,
143 const array_names &names)
144{
145 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
146 if (arr != nullptr && arr_n_dim != ndim) {
147 throw py::value_error("Array " + name_of(arr, names) + " must be " +
148 std::to_string(ndim) + "D, but got " +
149 std::to_string(arr_n_dim) + "D.");
150 }
151}
152
153inline void check_num_dims(const std::vector<array_ptr> &arrays,
154 const size_t ndim,
155 const array_names &names)
156{
157 for (const auto &arr : arrays) {
158 check_num_dims(arr, ndim, names);
159 }
160}
161
162inline void check_max_dims(const array_ptr &arr,
163 const size_t max_ndim,
164 const array_names &names)
165{
166 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
167 if (arr != nullptr && arr_n_dim > max_ndim) {
168 throw py::value_error(
169 "Array " + name_of(arr, names) + " must have no more than " +
170 std::to_string(max_ndim) + " dimensions, but got " +
171 std::to_string(arr_n_dim) + " dimensions.");
172 }
173}
174
175inline void check_size_at_least(const array_ptr &arr,
176 const size_t size,
177 const array_names &names)
178{
179 size_t arr_size = arr != nullptr ? arr->get_size() : 0;
180 if (arr != nullptr && arr_size < size) {
181 throw py::value_error("Array " + name_of(arr, names) +
182 " must have at least " + std::to_string(size) +
183 " elements, but got " + std::to_string(arr_size) +
184 " elements.");
185 }
186}
187
188inline void check_has_dtype(const array_ptr &arr,
189 const typenum_t dtype,
190 const array_names &names)
191{
192 if (arr == nullptr) {
193 return;
194 }
195
196 auto array_types = td_ns::usm_ndarray_types();
197 int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
198 int expected_type_id = static_cast<int>(dtype);
199
200 if (array_type_id != expected_type_id) {
201 py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
202 py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
203
204 std::string msg = "Array " + name_of(arr, names) + " must have dtype " +
205 std::string(py::str(dtype_py)) + ", but got " +
206 std::string(py::str(actual_dtype));
207
208 throw py::value_error(msg);
209 }
210}
211
212inline void check_same_dtype(const array_ptr &arr1,
213 const array_ptr &arr2,
214 const array_names &names)
215{
216 if (arr1 == nullptr || arr2 == nullptr) {
217 return;
218 }
219
220 auto array_types = td_ns::usm_ndarray_types();
221 int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
222 int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
223
224 if (first_type_id != second_type_id) {
225 py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
226 py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
227
228 std::string msg = "Arrays " + name_of(arr1, names) + " and " +
229 name_of(arr2, names) +
230 " must have the same dtype, but got " +
231 std::string(py::str(first_dtype)) + " and " +
232 std::string(py::str(second_dtype));
233
234 throw py::value_error(msg);
235 }
236}
237
238inline void check_same_dtype(const std::vector<array_ptr> &arrays,
239 const array_names &names)
240{
241 if (arrays.empty()) {
242 return;
243 }
244
245 const auto *first = arrays[0];
246 for (size_t i = 1; i < arrays.size(); ++i) {
247 check_same_dtype(first, arrays[i], names);
248 }
249}
250
251inline void check_same_size(const array_ptr &arr1,
252 const array_ptr &arr2,
253 const array_names &names)
254{
255 if (arr1 == nullptr || arr2 == nullptr) {
256 return;
257 }
258
259 auto size1 = arr1->get_size();
260 auto size2 = arr2->get_size();
261
262 if (size1 != size2) {
263 std::string msg =
264 "Arrays " + name_of(arr1, names) + " and " + name_of(arr2, names) +
265 " must have the same size, but got " + std::to_string(size1) +
266 " and " + std::to_string(size2);
267
268 throw py::value_error(msg);
269 }
270}
271
272inline void check_same_size(const std::vector<array_ptr> &arrays,
273 const array_names &names)
274{
275 if (arrays.empty()) {
276 return;
277 }
278
279 auto first = arrays[0];
280 for (size_t i = 1; i < arrays.size(); ++i) {
281 check_same_size(first, arrays[i], names);
282 }
283}
284
285inline void common_checks(const std::vector<array_ptr> &inputs,
286 const std::vector<array_ptr> &outputs,
287 const array_names &names)
288{
289 check_writable(outputs, names);
290
291 check_c_contig(inputs, names);
292 check_c_contig(outputs, names);
293
294 auto exec_q = get_queue(inputs, outputs);
295
296 check_queue(inputs, names, exec_q);
297 check_queue(outputs, names, exec_q);
298
299 check_no_overlap(inputs, outputs, names);
300}
301
302} // namespace ext::validation