DPNP C++ backend kernel library 0.18.0dev1
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
validation_utils_internal.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#include "ext/validation_utils.hpp"
27#include "utils/memory_overlap.hpp"
28
29namespace ext::validation
30{
31inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
32 const std::vector<array_ptr> &outputs)
33{
34 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
35 [](const array_ptr &arr) { return arr != nullptr; });
36
37 if (it != inputs.cend()) {
38 return (*it)->get_queue();
39 }
40
41 it = std::find_if(outputs.cbegin(), outputs.cend(),
42 [](const array_ptr &arr) { return arr != nullptr; });
43
44 if (it != outputs.cend()) {
45 return (*it)->get_queue();
46 }
47
48 throw py::value_error("No input or output arrays found");
49}
50
51inline std::string name_of(const array_ptr &arr, const array_names &names)
52{
53 auto name_it = names.find(arr);
54 assert(name_it != names.end());
55
56 if (name_it != names.end())
57 return "'" + name_it->second + "'";
58
59 return "'unknown'";
60}
61
62inline void check_writable(const std::vector<array_ptr> &arrays,
63 const array_names &names)
64{
65 for (const auto &arr : arrays) {
66 if (arr != nullptr && !arr->is_writable()) {
67 throw py::value_error(name_of(arr, names) +
68 " parameter is not writable");
69 }
70 }
71}
72
73inline void check_c_contig(const std::vector<array_ptr> &arrays,
74 const array_names &names)
75{
76 for (const auto &arr : arrays) {
77 if (arr != nullptr && !arr->is_c_contiguous()) {
78 throw py::value_error(name_of(arr, names) +
79 " parameter is not c-contiguos");
80 }
81 }
82}
83
84inline void check_queue(const std::vector<array_ptr> &arrays,
85 const array_names &names,
86 const sycl::queue &exec_q)
87{
88 auto unequal_queue =
89 std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
90 return arr != nullptr && arr->get_queue() != exec_q;
91 });
92
93 if (unequal_queue != arrays.cend()) {
94 throw py::value_error(
95 name_of(*unequal_queue, names) +
96 " parameter has incompatible queue with other parameters");
97 }
98}
99
100inline void check_no_overlap(const array_ptr &input,
101 const array_ptr &output,
102 const array_names &names)
103{
104 if (input == nullptr || output == nullptr) {
105 return;
106 }
107
108 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
109
110 if (overlap(*input, *output)) {
111 throw py::value_error(name_of(input, names) +
112 " has overlapping memory segments with " +
113 name_of(output, names));
114 }
115}
116
117inline void check_no_overlap(const std::vector<array_ptr> &inputs,
118 const std::vector<array_ptr> &outputs,
119 const array_names &names)
120{
121 for (const auto &input : inputs) {
122 for (const auto &output : outputs) {
123 check_no_overlap(input, output, names);
124 }
125 }
126}
127
128inline void check_num_dims(const array_ptr &arr,
129 const size_t ndim,
130 const array_names &names)
131{
132 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
133 if (arr != nullptr && arr_n_dim != ndim) {
134 throw py::value_error("Array " + name_of(arr, names) + " must be " +
135 std::to_string(ndim) + "D, but got " +
136 std::to_string(arr_n_dim) + "D.");
137 }
138}
139
140inline void check_max_dims(const array_ptr &arr,
141 const size_t max_ndim,
142 const array_names &names)
143{
144 size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
145 if (arr != nullptr && arr_n_dim > max_ndim) {
146 throw py::value_error(
147 "Array " + name_of(arr, names) + " must have no more than " +
148 std::to_string(max_ndim) + " dimensions, but got " +
149 std::to_string(arr_n_dim) + " dimensions.");
150 }
151}
152
153inline void check_size_at_least(const array_ptr &arr,
154 const size_t size,
155 const array_names &names)
156{
157 size_t arr_size = arr != nullptr ? arr->get_size() : 0;
158 if (arr != nullptr && arr_size < size) {
159 throw py::value_error("Array " + name_of(arr, names) +
160 " must have at least " + std::to_string(size) +
161 " elements, but got " + std::to_string(arr_size) +
162 " elements.");
163 }
164}
165
166inline void common_checks(const std::vector<array_ptr> &inputs,
167 const std::vector<array_ptr> &outputs,
168 const array_names &names)
169{
170 check_writable(outputs, names);
171
172 check_c_contig(inputs, names);
173 check_c_contig(outputs, names);
174
175 auto exec_q = get_queue(inputs, outputs);
176
177 check_queue(inputs, names, exec_q);
178 check_queue(outputs, names, exec_q);
179
180 check_no_overlap(inputs, outputs, names);
181}
182
183} // namespace ext::validation