31#include <pybind11/numpy.h>
32#include <pybind11/pybind11.h>
34#include "ext/common.hpp"
36#include "ext/validation_utils.hpp"
37#include "utils/memory_overlap.hpp"
39namespace td_ns = dpctl::tensor::type_dispatch;
40namespace common = ext::common;
42namespace ext::validation
44inline sycl::queue get_queue(
const std::vector<array_ptr> &inputs,
45 const std::vector<array_ptr> &outputs)
47 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
48 [](
const array_ptr &arr) { return arr != nullptr; });
50 if (it != inputs.cend()) {
51 return (*it)->get_queue();
54 it = std::find_if(outputs.cbegin(), outputs.cend(),
55 [](
const array_ptr &arr) { return arr != nullptr; });
57 if (it != outputs.cend()) {
58 return (*it)->get_queue();
61 throw py::value_error(
"No input or output arrays found");
64inline std::string name_of(
const array_ptr &arr,
const array_names &names)
66 auto name_it = names.find(arr);
67 assert(name_it != names.end());
69 if (name_it != names.end())
70 return "'" + name_it->second +
"'";
75inline void check_writable(
const std::vector<array_ptr> &arrays,
76 const array_names &names)
78 for (
const auto &arr : arrays) {
79 if (arr !=
nullptr && !arr->is_writable()) {
80 throw py::value_error(name_of(arr, names) +
81 " parameter is not writable");
86inline void check_c_contig(
const std::vector<array_ptr> &arrays,
87 const array_names &names)
89 for (
const auto &arr : arrays) {
90 if (arr !=
nullptr && !arr->is_c_contiguous()) {
91 throw py::value_error(name_of(arr, names) +
92 " parameter is not c-contiguos");
97inline void check_queue(
const std::vector<array_ptr> &arrays,
98 const array_names &names,
99 const sycl::queue &exec_q)
102 std::find_if(arrays.cbegin(), arrays.cend(), [&](
const array_ptr &arr) {
103 return arr != nullptr && arr->get_queue() != exec_q;
106 if (unequal_queue != arrays.cend()) {
107 throw py::value_error(
108 name_of(*unequal_queue, names) +
109 " parameter has incompatible queue with other parameters");
113inline void check_no_overlap(
const array_ptr &input,
114 const array_ptr &output,
115 const array_names &names)
117 if (input ==
nullptr || output ==
nullptr) {
121 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
122 const auto &same_logical_tensors =
123 dpctl::tensor::overlap::SameLogicalTensors();
125 if (overlap(*input, *output) && !same_logical_tensors(*input, *output)) {
126 throw py::value_error(name_of(input, names) +
127 " has overlapping memory segments with " +
128 name_of(output, names));
132inline void check_no_overlap(
const std::vector<array_ptr> &inputs,
133 const std::vector<array_ptr> &outputs,
134 const array_names &names)
136 for (
const auto &input : inputs) {
137 for (
const auto &output : outputs) {
138 check_no_overlap(input, output, names);
143inline void check_num_dims(
const array_ptr &arr,
145 const array_names &names)
147 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
148 if (arr !=
nullptr && arr_n_dim != ndim) {
149 throw py::value_error(
"Array " + name_of(arr, names) +
" must be " +
150 std::to_string(ndim) +
"D, but got " +
151 std::to_string(arr_n_dim) +
"D.");
155inline void check_num_dims(
const std::vector<array_ptr> &arrays,
157 const array_names &names)
159 for (
const auto &arr : arrays) {
160 check_num_dims(arr, ndim, names);
164inline void check_max_dims(
const array_ptr &arr,
165 const size_t max_ndim,
166 const array_names &names)
168 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
169 if (arr !=
nullptr && arr_n_dim > max_ndim) {
170 throw py::value_error(
171 "Array " + name_of(arr, names) +
" must have no more than " +
172 std::to_string(max_ndim) +
" dimensions, but got " +
173 std::to_string(arr_n_dim) +
" dimensions.");
177inline void check_size_at_least(
const array_ptr &arr,
179 const array_names &names)
181 size_t arr_size = arr !=
nullptr ? arr->get_size() : 0;
182 if (arr !=
nullptr && arr_size < size) {
183 throw py::value_error(
"Array " + name_of(arr, names) +
184 " must have at least " + std::to_string(size) +
185 " elements, but got " + std::to_string(arr_size) +
190inline void check_has_dtype(
const array_ptr &arr,
191 const typenum_t dtype,
192 const array_names &names)
194 if (arr ==
nullptr) {
198 auto array_types = td_ns::usm_ndarray_types();
199 int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
200 int expected_type_id =
static_cast<int>(dtype);
202 if (array_type_id != expected_type_id) {
203 py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
204 py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
206 std::string msg =
"Array " + name_of(arr, names) +
" must have dtype " +
207 std::string(py::str(dtype_py)) +
", but got " +
208 std::string(py::str(actual_dtype));
210 throw py::value_error(msg);
214inline void check_same_dtype(
const array_ptr &arr1,
215 const array_ptr &arr2,
216 const array_names &names)
218 if (arr1 ==
nullptr || arr2 ==
nullptr) {
222 auto array_types = td_ns::usm_ndarray_types();
223 int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
224 int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
226 if (first_type_id != second_type_id) {
227 py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
228 py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
230 std::string msg =
"Arrays " + name_of(arr1, names) +
" and " +
231 name_of(arr2, names) +
232 " must have the same dtype, but got " +
233 std::string(py::str(first_dtype)) +
" and " +
234 std::string(py::str(second_dtype));
236 throw py::value_error(msg);
240inline void check_same_dtype(
const std::vector<array_ptr> &arrays,
241 const array_names &names)
243 if (arrays.empty()) {
247 const auto *first = arrays[0];
248 for (
size_t i = 1; i < arrays.size(); ++i) {
249 check_same_dtype(first, arrays[i], names);
253inline void check_same_size(
const array_ptr &arr1,
254 const array_ptr &arr2,
255 const array_names &names)
257 if (arr1 ==
nullptr || arr2 ==
nullptr) {
261 auto size1 = arr1->get_size();
262 auto size2 = arr2->get_size();
264 if (size1 != size2) {
266 "Arrays " + name_of(arr1, names) +
" and " + name_of(arr2, names) +
267 " must have the same size, but got " + std::to_string(size1) +
268 " and " + std::to_string(size2);
270 throw py::value_error(msg);
274inline void check_same_size(
const std::vector<array_ptr> &arrays,
275 const array_names &names)
277 if (arrays.empty()) {
281 auto first = arrays[0];
282 for (
size_t i = 1; i < arrays.size(); ++i) {
283 check_same_size(first, arrays[i], names);
287inline void common_checks(
const std::vector<array_ptr> &inputs,
288 const std::vector<array_ptr> &outputs,
289 const array_names &names)
291 check_writable(outputs, names);
293 check_c_contig(inputs, names);
294 check_c_contig(outputs, names);
296 auto exec_q = get_queue(inputs, outputs);
298 check_queue(inputs, names, exec_q);
299 check_queue(outputs, names, exec_q);
301 check_no_overlap(inputs, outputs, names);