29#include <pybind11/numpy.h>
30#include <pybind11/pybind11.h>
32#include "ext/common.hpp"
34#include "ext/validation_utils.hpp"
35#include "utils/memory_overlap.hpp"
37namespace td_ns = dpctl::tensor::type_dispatch;
38namespace common = ext::common;
40namespace ext::validation
42inline sycl::queue get_queue(
const std::vector<array_ptr> &inputs,
43 const std::vector<array_ptr> &outputs)
45 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
46 [](
const array_ptr &arr) { return arr != nullptr; });
48 if (it != inputs.cend()) {
49 return (*it)->get_queue();
52 it = std::find_if(outputs.cbegin(), outputs.cend(),
53 [](
const array_ptr &arr) { return arr != nullptr; });
55 if (it != outputs.cend()) {
56 return (*it)->get_queue();
59 throw py::value_error(
"No input or output arrays found");
62inline std::string name_of(
const array_ptr &arr,
const array_names &names)
64 auto name_it = names.find(arr);
65 assert(name_it != names.end());
67 if (name_it != names.end())
68 return "'" + name_it->second +
"'";
73inline void check_writable(
const std::vector<array_ptr> &arrays,
74 const array_names &names)
76 for (
const auto &arr : arrays) {
77 if (arr !=
nullptr && !arr->is_writable()) {
78 throw py::value_error(name_of(arr, names) +
79 " parameter is not writable");
84inline void check_c_contig(
const std::vector<array_ptr> &arrays,
85 const array_names &names)
87 for (
const auto &arr : arrays) {
88 if (arr !=
nullptr && !arr->is_c_contiguous()) {
89 throw py::value_error(name_of(arr, names) +
90 " parameter is not c-contiguos");
95inline void check_queue(
const std::vector<array_ptr> &arrays,
96 const array_names &names,
97 const sycl::queue &exec_q)
100 std::find_if(arrays.cbegin(), arrays.cend(), [&](
const array_ptr &arr) {
101 return arr != nullptr && arr->get_queue() != exec_q;
104 if (unequal_queue != arrays.cend()) {
105 throw py::value_error(
106 name_of(*unequal_queue, names) +
107 " parameter has incompatible queue with other parameters");
111inline void check_no_overlap(
const array_ptr &input,
112 const array_ptr &output,
113 const array_names &names)
115 if (input ==
nullptr || output ==
nullptr) {
119 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
120 const auto &same_logical_tensors =
121 dpctl::tensor::overlap::SameLogicalTensors();
123 if (overlap(*input, *output) && !same_logical_tensors(*input, *output)) {
124 throw py::value_error(name_of(input, names) +
125 " has overlapping memory segments with " +
126 name_of(output, names));
130inline void check_no_overlap(
const std::vector<array_ptr> &inputs,
131 const std::vector<array_ptr> &outputs,
132 const array_names &names)
134 for (
const auto &input : inputs) {
135 for (
const auto &output : outputs) {
136 check_no_overlap(input, output, names);
141inline void check_num_dims(
const array_ptr &arr,
143 const array_names &names)
145 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
146 if (arr !=
nullptr && arr_n_dim != ndim) {
147 throw py::value_error(
"Array " + name_of(arr, names) +
" must be " +
148 std::to_string(ndim) +
"D, but got " +
149 std::to_string(arr_n_dim) +
"D.");
153inline void check_num_dims(
const std::vector<array_ptr> &arrays,
155 const array_names &names)
157 for (
const auto &arr : arrays) {
158 check_num_dims(arr, ndim, names);
162inline void check_max_dims(
const array_ptr &arr,
163 const size_t max_ndim,
164 const array_names &names)
166 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
167 if (arr !=
nullptr && arr_n_dim > max_ndim) {
168 throw py::value_error(
169 "Array " + name_of(arr, names) +
" must have no more than " +
170 std::to_string(max_ndim) +
" dimensions, but got " +
171 std::to_string(arr_n_dim) +
" dimensions.");
175inline void check_size_at_least(
const array_ptr &arr,
177 const array_names &names)
179 size_t arr_size = arr !=
nullptr ? arr->get_size() : 0;
180 if (arr !=
nullptr && arr_size < size) {
181 throw py::value_error(
"Array " + name_of(arr, names) +
182 " must have at least " + std::to_string(size) +
183 " elements, but got " + std::to_string(arr_size) +
188inline void check_has_dtype(
const array_ptr &arr,
189 const typenum_t dtype,
190 const array_names &names)
192 if (arr ==
nullptr) {
196 auto array_types = td_ns::usm_ndarray_types();
197 int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
198 int expected_type_id =
static_cast<int>(dtype);
200 if (array_type_id != expected_type_id) {
201 py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
202 py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
204 std::string msg =
"Array " + name_of(arr, names) +
" must have dtype " +
205 std::string(py::str(dtype_py)) +
", but got " +
206 std::string(py::str(actual_dtype));
208 throw py::value_error(msg);
212inline void check_same_dtype(
const array_ptr &arr1,
213 const array_ptr &arr2,
214 const array_names &names)
216 if (arr1 ==
nullptr || arr2 ==
nullptr) {
220 auto array_types = td_ns::usm_ndarray_types();
221 int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
222 int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
224 if (first_type_id != second_type_id) {
225 py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
226 py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
228 std::string msg =
"Arrays " + name_of(arr1, names) +
" and " +
229 name_of(arr2, names) +
230 " must have the same dtype, but got " +
231 std::string(py::str(first_dtype)) +
" and " +
232 std::string(py::str(second_dtype));
234 throw py::value_error(msg);
238inline void check_same_dtype(
const std::vector<array_ptr> &arrays,
239 const array_names &names)
241 if (arrays.empty()) {
245 const auto *first = arrays[0];
246 for (
size_t i = 1; i < arrays.size(); ++i) {
247 check_same_dtype(first, arrays[i], names);
251inline void check_same_size(
const array_ptr &arr1,
252 const array_ptr &arr2,
253 const array_names &names)
255 if (arr1 ==
nullptr || arr2 ==
nullptr) {
259 auto size1 = arr1->get_size();
260 auto size2 = arr2->get_size();
262 if (size1 != size2) {
264 "Arrays " + name_of(arr1, names) +
" and " + name_of(arr2, names) +
265 " must have the same size, but got " + std::to_string(size1) +
266 " and " + std::to_string(size2);
268 throw py::value_error(msg);
272inline void check_same_size(
const std::vector<array_ptr> &arrays,
273 const array_names &names)
275 if (arrays.empty()) {
279 auto first = arrays[0];
280 for (
size_t i = 1; i < arrays.size(); ++i) {
281 check_same_size(first, arrays[i], names);
285inline void common_checks(
const std::vector<array_ptr> &inputs,
286 const std::vector<array_ptr> &outputs,
287 const array_names &names)
289 check_writable(outputs, names);
291 check_c_contig(inputs, names);
292 check_c_contig(outputs, names);
294 auto exec_q = get_queue(inputs, outputs);
296 check_queue(inputs, names, exec_q);
297 check_queue(outputs, names, exec_q);
299 check_no_overlap(inputs, outputs, names);