DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
validation_utils_internal.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <pybind11/numpy.h>
32#include <pybind11/pybind11.h>
33
34#include "ext/common.hpp"
35
36#include "ext/validation_utils.hpp"
37#include "utils/memory_overlap.hpp"
38
39namespace td_ns = dpctl::tensor::type_dispatch;
40namespace common = ext::common;
41
42namespace ext::validation
43{
44inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
45 const std::vector<array_ptr> &outputs)
46{
47 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
48 [](const array_ptr &arr) { return arr != nullptr; });
49
50 if (it != inputs.cend()) {
51 return (*it)->get_queue();
52 }
53
54 it = std::find_if(outputs.cbegin(), outputs.cend(),
55 [](const array_ptr &arr) { return arr != nullptr; });
56
57 if (it != outputs.cend()) {
58 return (*it)->get_queue();
59 }
60
61 throw py::value_error("No input or output arrays found");
62}
63
64inline std::string name_of(const array_ptr &arr, const array_names &names)
65{
66 auto name_it = names.find(arr);
67 assert(name_it != names.end());
68
69 if (name_it != names.end())
70 return "'" + name_it->second + "'";
71
72 return "'unknown'";
73}
74
75inline void check_writable(const std::vector<array_ptr> &arrays,
76 const array_names &names)
77{
78 for (const auto &arr : arrays) {
79 if (arr != nullptr && !arr->is_writable()) {
80 throw py::value_error(name_of(arr, names) +
81 " parameter is not writable");
82 }
83 }
84}
85
86inline void check_c_contig(const std::vector<array_ptr> &arrays,
87 const array_names &names)
88{
89 for (const auto &arr : arrays) {
90 if (arr != nullptr && !arr->is_c_contiguous()) {
91 throw py::value_error(name_of(arr, names) +
92 " parameter is not c-contiguos");
93 }
94 }
95}
96
97inline void check_queue(const std::vector<array_ptr> &arrays,
98 const array_names &names,
99 const sycl::queue &exec_q)
100{
101 auto unequal_queue =
102 std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
103 return arr != nullptr && arr->get_queue() != exec_q;
104 });
105
106 if (unequal_queue != arrays.cend()) {
107 throw py::value_error(
108 name_of(*unequal_queue, names) +
109 " parameter has incompatible queue with other parameters");
110 }
111}
112
113inline void check_no_overlap(const array_ptr &input,
114 const array_ptr &output,
115 const array_names &names)
116{
117 if (input == nullptr || output == nullptr) {
118 return;
119 }
120
121 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
122 const auto &same_logical_tensors =
123 dpctl::tensor::overlap::SameLogicalTensors();
124
125 if (overlap(*input, *output) && !same_logical_tensors(*input, *output)) {
126 throw py::value_error(name_of(input, names) +
127 " has overlapping memory segments with " +
128 name_of(output, names));
129 }
130}
131
132inline void check_no_overlap(const std::vector<array_ptr> &inputs,
133 const std::vector<array_ptr> &outputs,
134 const array_names &names)
135{
136 for (const auto &input : inputs) {
137 for (const auto &output : outputs) {
138 check_no_overlap(input, output, names);
139 }
140 }
141}
142
143inline void check_num_dims(const array_ptr &arr,
144 const size_t ndim,
145 const array_names &names)
146{
147 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
148 if (arr != nullptr && arr_n_dim != ndim) {
149 throw py::value_error("Array " + name_of(arr, names) + " must be " +
150 std::to_string(ndim) + "D, but got " +
151 std::to_string(arr_n_dim) + "D.");
152 }
153}
154
155inline void check_num_dims(const std::vector<array_ptr> &arrays,
156 const size_t ndim,
157 const array_names &names)
158{
159 for (const auto &arr : arrays) {
160 check_num_dims(arr, ndim, names);
161 }
162}
163
164inline void check_max_dims(const array_ptr &arr,
165 const size_t max_ndim,
166 const array_names &names)
167{
168 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
169 if (arr != nullptr && arr_n_dim > max_ndim) {
170 throw py::value_error(
171 "Array " + name_of(arr, names) + " must have no more than " +
172 std::to_string(max_ndim) + " dimensions, but got " +
173 std::to_string(arr_n_dim) + " dimensions.");
174 }
175}
176
177inline void check_size_at_least(const array_ptr &arr,
178 const size_t size,
179 const array_names &names)
180{
181 size_t arr_size = arr != nullptr ? arr->get_size() : 0;
182 if (arr != nullptr && arr_size < size) {
183 throw py::value_error("Array " + name_of(arr, names) +
184 " must have at least " + std::to_string(size) +
185 " elements, but got " + std::to_string(arr_size) +
186 " elements.");
187 }
188}
189
190inline void check_has_dtype(const array_ptr &arr,
191 const typenum_t dtype,
192 const array_names &names)
193{
194 if (arr == nullptr) {
195 return;
196 }
197
198 auto array_types = td_ns::usm_ndarray_types();
199 int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
200 int expected_type_id = static_cast<int>(dtype);
201
202 if (array_type_id != expected_type_id) {
203 py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
204 py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
205
206 std::string msg = "Array " + name_of(arr, names) + " must have dtype " +
207 std::string(py::str(dtype_py)) + ", but got " +
208 std::string(py::str(actual_dtype));
209
210 throw py::value_error(msg);
211 }
212}
213
214inline void check_same_dtype(const array_ptr &arr1,
215 const array_ptr &arr2,
216 const array_names &names)
217{
218 if (arr1 == nullptr || arr2 == nullptr) {
219 return;
220 }
221
222 auto array_types = td_ns::usm_ndarray_types();
223 int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
224 int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
225
226 if (first_type_id != second_type_id) {
227 py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
228 py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
229
230 std::string msg = "Arrays " + name_of(arr1, names) + " and " +
231 name_of(arr2, names) +
232 " must have the same dtype, but got " +
233 std::string(py::str(first_dtype)) + " and " +
234 std::string(py::str(second_dtype));
235
236 throw py::value_error(msg);
237 }
238}
239
240inline void check_same_dtype(const std::vector<array_ptr> &arrays,
241 const array_names &names)
242{
243 if (arrays.empty()) {
244 return;
245 }
246
247 const auto *first = arrays[0];
248 for (size_t i = 1; i < arrays.size(); ++i) {
249 check_same_dtype(first, arrays[i], names);
250 }
251}
252
253inline void check_same_size(const array_ptr &arr1,
254 const array_ptr &arr2,
255 const array_names &names)
256{
257 if (arr1 == nullptr || arr2 == nullptr) {
258 return;
259 }
260
261 auto size1 = arr1->get_size();
262 auto size2 = arr2->get_size();
263
264 if (size1 != size2) {
265 std::string msg =
266 "Arrays " + name_of(arr1, names) + " and " + name_of(arr2, names) +
267 " must have the same size, but got " + std::to_string(size1) +
268 " and " + std::to_string(size2);
269
270 throw py::value_error(msg);
271 }
272}
273
274inline void check_same_size(const std::vector<array_ptr> &arrays,
275 const array_names &names)
276{
277 if (arrays.empty()) {
278 return;
279 }
280
281 auto first = arrays[0];
282 for (size_t i = 1; i < arrays.size(); ++i) {
283 check_same_size(first, arrays[i], names);
284 }
285}
286
287inline void common_checks(const std::vector<array_ptr> &inputs,
288 const std::vector<array_ptr> &outputs,
289 const array_names &names)
290{
291 check_writable(outputs, names);
292
293 check_c_contig(inputs, names);
294 check_c_contig(outputs, names);
295
296 auto exec_q = get_queue(inputs, outputs);
297
298 check_queue(inputs, names, exec_q);
299 check_queue(outputs, names, exec_q);
300
301 check_no_overlap(inputs, outputs, names);
302}
303
304} // namespace ext::validation