26#include <pybind11/numpy.h>
27#include <pybind11/pybind11.h>
29#include "ext/common.hpp"
31#include "ext/validation_utils.hpp"
32#include "utils/memory_overlap.hpp"
34namespace td_ns = dpctl::tensor::type_dispatch;
35namespace common = ext::common;
37namespace ext::validation
39inline sycl::queue get_queue(
const std::vector<array_ptr> &inputs,
40 const std::vector<array_ptr> &outputs)
42 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
43 [](
const array_ptr &arr) { return arr != nullptr; });
45 if (it != inputs.cend()) {
46 return (*it)->get_queue();
49 it = std::find_if(outputs.cbegin(), outputs.cend(),
50 [](
const array_ptr &arr) { return arr != nullptr; });
52 if (it != outputs.cend()) {
53 return (*it)->get_queue();
56 throw py::value_error(
"No input or output arrays found");
59inline std::string name_of(
const array_ptr &arr,
const array_names &names)
61 auto name_it = names.find(arr);
62 assert(name_it != names.end());
64 if (name_it != names.end())
65 return "'" + name_it->second +
"'";
70inline void check_writable(
const std::vector<array_ptr> &arrays,
71 const array_names &names)
73 for (
const auto &arr : arrays) {
74 if (arr !=
nullptr && !arr->is_writable()) {
75 throw py::value_error(name_of(arr, names) +
76 " parameter is not writable");
81inline void check_c_contig(
const std::vector<array_ptr> &arrays,
82 const array_names &names)
84 for (
const auto &arr : arrays) {
85 if (arr !=
nullptr && !arr->is_c_contiguous()) {
86 throw py::value_error(name_of(arr, names) +
87 " parameter is not c-contiguos");
92inline void check_queue(
const std::vector<array_ptr> &arrays,
93 const array_names &names,
94 const sycl::queue &exec_q)
97 std::find_if(arrays.cbegin(), arrays.cend(), [&](
const array_ptr &arr) {
98 return arr != nullptr && arr->get_queue() != exec_q;
101 if (unequal_queue != arrays.cend()) {
102 throw py::value_error(
103 name_of(*unequal_queue, names) +
104 " parameter has incompatible queue with other parameters");
108inline void check_no_overlap(
const array_ptr &input,
109 const array_ptr &output,
110 const array_names &names)
112 if (input ==
nullptr || output ==
nullptr) {
116 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
118 if (overlap(*input, *output)) {
119 throw py::value_error(name_of(input, names) +
120 " has overlapping memory segments with " +
121 name_of(output, names));
125inline void check_no_overlap(
const std::vector<array_ptr> &inputs,
126 const std::vector<array_ptr> &outputs,
127 const array_names &names)
129 for (
const auto &input : inputs) {
130 for (
const auto &output : outputs) {
131 check_no_overlap(input, output, names);
136inline void check_num_dims(
const array_ptr &arr,
138 const array_names &names)
140 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
141 if (arr !=
nullptr && arr_n_dim != ndim) {
142 throw py::value_error(
"Array " + name_of(arr, names) +
" must be " +
143 std::to_string(ndim) +
"D, but got " +
144 std::to_string(arr_n_dim) +
"D.");
148inline void check_num_dims(
const std::vector<array_ptr> &arrays,
150 const array_names &names)
152 for (
const auto &arr : arrays) {
153 check_num_dims(arr, ndim, names);
157inline void check_max_dims(
const array_ptr &arr,
158 const size_t max_ndim,
159 const array_names &names)
161 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
162 if (arr !=
nullptr && arr_n_dim > max_ndim) {
163 throw py::value_error(
164 "Array " + name_of(arr, names) +
" must have no more than " +
165 std::to_string(max_ndim) +
" dimensions, but got " +
166 std::to_string(arr_n_dim) +
" dimensions.");
170inline void check_size_at_least(
const array_ptr &arr,
172 const array_names &names)
174 size_t arr_size = arr !=
nullptr ? arr->get_size() : 0;
175 if (arr !=
nullptr && arr_size < size) {
176 throw py::value_error(
"Array " + name_of(arr, names) +
177 " must have at least " + std::to_string(size) +
178 " elements, but got " + std::to_string(arr_size) +
183inline void check_has_dtype(
const array_ptr &arr,
184 const typenum_t dtype,
185 const array_names &names)
187 if (arr ==
nullptr) {
191 auto array_types = td_ns::usm_ndarray_types();
192 int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
193 int expected_type_id =
static_cast<int>(dtype);
195 if (array_type_id != expected_type_id) {
196 py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
197 py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
199 std::string msg =
"Array " + name_of(arr, names) +
" must have dtype " +
200 std::string(py::str(dtype_py)) +
", but got " +
201 std::string(py::str(actual_dtype));
203 throw py::value_error(msg);
207inline void check_same_dtype(
const array_ptr &arr1,
208 const array_ptr &arr2,
209 const array_names &names)
211 if (arr1 ==
nullptr || arr2 ==
nullptr) {
215 auto array_types = td_ns::usm_ndarray_types();
216 int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
217 int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
219 if (first_type_id != second_type_id) {
220 py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
221 py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
223 std::string msg =
"Arrays " + name_of(arr1, names) +
" and " +
224 name_of(arr2, names) +
225 " must have the same dtype, but got " +
226 std::string(py::str(first_dtype)) +
" and " +
227 std::string(py::str(second_dtype));
229 throw py::value_error(msg);
233inline void check_same_dtype(
const std::vector<array_ptr> &arrays,
234 const array_names &names)
236 if (arrays.empty()) {
240 const auto *first = arrays[0];
241 for (
size_t i = 1; i < arrays.size(); ++i) {
242 check_same_dtype(first, arrays[i], names);
246inline void check_same_size(
const array_ptr &arr1,
247 const array_ptr &arr2,
248 const array_names &names)
250 if (arr1 ==
nullptr || arr2 ==
nullptr) {
254 auto size1 = arr1->get_size();
255 auto size2 = arr2->get_size();
257 if (size1 != size2) {
259 "Arrays " + name_of(arr1, names) +
" and " + name_of(arr2, names) +
260 " must have the same size, but got " + std::to_string(size1) +
261 " and " + std::to_string(size2);
263 throw py::value_error(msg);
267inline void check_same_size(
const std::vector<array_ptr> &arrays,
268 const array_names &names)
270 if (arrays.empty()) {
274 auto first = arrays[0];
275 for (
size_t i = 1; i < arrays.size(); ++i) {
276 check_same_size(first, arrays[i], names);
280inline void common_checks(
const std::vector<array_ptr> &inputs,
281 const std::vector<array_ptr> &outputs,
282 const array_names &names)
284 check_writable(outputs, names);
286 check_c_contig(inputs, names);
287 check_c_contig(outputs, names);
289 auto exec_q = get_queue(inputs, outputs);
291 check_queue(inputs, names, exec_q);
292 check_queue(outputs, names, exec_q);
294 check_no_overlap(inputs, outputs, names);