26#include <pybind11/numpy.h>
27#include <pybind11/pybind11.h>
29#include "ext/common.hpp"
31#include "ext/validation_utils.hpp"
32#include "utils/memory_overlap.hpp"
34namespace td_ns = dpctl::tensor::type_dispatch;
35namespace common = ext::common;
37namespace ext::validation
39inline sycl::queue get_queue(
const std::vector<array_ptr> &inputs,
40 const std::vector<array_ptr> &outputs)
42 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
43 [](
const array_ptr &arr) { return arr != nullptr; });
45 if (it != inputs.cend()) {
46 return (*it)->get_queue();
49 it = std::find_if(outputs.cbegin(), outputs.cend(),
50 [](
const array_ptr &arr) { return arr != nullptr; });
52 if (it != outputs.cend()) {
53 return (*it)->get_queue();
56 throw py::value_error(
"No input or output arrays found");
59inline std::string name_of(
const array_ptr &arr,
const array_names &names)
61 auto name_it = names.find(arr);
62 assert(name_it != names.end());
64 if (name_it != names.end())
65 return "'" + name_it->second +
"'";
70inline void check_writable(
const std::vector<array_ptr> &arrays,
71 const array_names &names)
73 for (
const auto &arr : arrays) {
74 if (arr !=
nullptr && !arr->is_writable()) {
75 throw py::value_error(name_of(arr, names) +
76 " parameter is not writable");
81inline void check_c_contig(
const std::vector<array_ptr> &arrays,
82 const array_names &names)
84 for (
const auto &arr : arrays) {
85 if (arr !=
nullptr && !arr->is_c_contiguous()) {
86 throw py::value_error(name_of(arr, names) +
87 " parameter is not c-contiguos");
92inline void check_queue(
const std::vector<array_ptr> &arrays,
93 const array_names &names,
94 const sycl::queue &exec_q)
97 std::find_if(arrays.cbegin(), arrays.cend(), [&](
const array_ptr &arr) {
98 return arr != nullptr && arr->get_queue() != exec_q;
101 if (unequal_queue != arrays.cend()) {
102 throw py::value_error(
103 name_of(*unequal_queue, names) +
104 " parameter has incompatible queue with other parameters");
108inline void check_no_overlap(
const array_ptr &input,
109 const array_ptr &output,
110 const array_names &names)
112 if (input ==
nullptr || output ==
nullptr) {
116 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
117 const auto &same_logical_tensors =
118 dpctl::tensor::overlap::SameLogicalTensors();
120 if (overlap(*input, *output) && !same_logical_tensors(*input, *output)) {
121 throw py::value_error(name_of(input, names) +
122 " has overlapping memory segments with " +
123 name_of(output, names));
127inline void check_no_overlap(
const std::vector<array_ptr> &inputs,
128 const std::vector<array_ptr> &outputs,
129 const array_names &names)
131 for (
const auto &input : inputs) {
132 for (
const auto &output : outputs) {
133 check_no_overlap(input, output, names);
138inline void check_num_dims(
const array_ptr &arr,
140 const array_names &names)
142 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
143 if (arr !=
nullptr && arr_n_dim != ndim) {
144 throw py::value_error(
"Array " + name_of(arr, names) +
" must be " +
145 std::to_string(ndim) +
"D, but got " +
146 std::to_string(arr_n_dim) +
"D.");
150inline void check_num_dims(
const std::vector<array_ptr> &arrays,
152 const array_names &names)
154 for (
const auto &arr : arrays) {
155 check_num_dims(arr, ndim, names);
159inline void check_max_dims(
const array_ptr &arr,
160 const size_t max_ndim,
161 const array_names &names)
163 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
164 if (arr !=
nullptr && arr_n_dim > max_ndim) {
165 throw py::value_error(
166 "Array " + name_of(arr, names) +
" must have no more than " +
167 std::to_string(max_ndim) +
" dimensions, but got " +
168 std::to_string(arr_n_dim) +
" dimensions.");
172inline void check_size_at_least(
const array_ptr &arr,
174 const array_names &names)
176 size_t arr_size = arr !=
nullptr ? arr->get_size() : 0;
177 if (arr !=
nullptr && arr_size < size) {
178 throw py::value_error(
"Array " + name_of(arr, names) +
179 " must have at least " + std::to_string(size) +
180 " elements, but got " + std::to_string(arr_size) +
185inline void check_has_dtype(
const array_ptr &arr,
186 const typenum_t dtype,
187 const array_names &names)
189 if (arr ==
nullptr) {
193 auto array_types = td_ns::usm_ndarray_types();
194 int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
195 int expected_type_id =
static_cast<int>(dtype);
197 if (array_type_id != expected_type_id) {
198 py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
199 py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
201 std::string msg =
"Array " + name_of(arr, names) +
" must have dtype " +
202 std::string(py::str(dtype_py)) +
", but got " +
203 std::string(py::str(actual_dtype));
205 throw py::value_error(msg);
209inline void check_same_dtype(
const array_ptr &arr1,
210 const array_ptr &arr2,
211 const array_names &names)
213 if (arr1 ==
nullptr || arr2 ==
nullptr) {
217 auto array_types = td_ns::usm_ndarray_types();
218 int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
219 int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
221 if (first_type_id != second_type_id) {
222 py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
223 py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
225 std::string msg =
"Arrays " + name_of(arr1, names) +
" and " +
226 name_of(arr2, names) +
227 " must have the same dtype, but got " +
228 std::string(py::str(first_dtype)) +
" and " +
229 std::string(py::str(second_dtype));
231 throw py::value_error(msg);
235inline void check_same_dtype(
const std::vector<array_ptr> &arrays,
236 const array_names &names)
238 if (arrays.empty()) {
242 const auto *first = arrays[0];
243 for (
size_t i = 1; i < arrays.size(); ++i) {
244 check_same_dtype(first, arrays[i], names);
248inline void check_same_size(
const array_ptr &arr1,
249 const array_ptr &arr2,
250 const array_names &names)
252 if (arr1 ==
nullptr || arr2 ==
nullptr) {
256 auto size1 = arr1->get_size();
257 auto size2 = arr2->get_size();
259 if (size1 != size2) {
261 "Arrays " + name_of(arr1, names) +
" and " + name_of(arr2, names) +
262 " must have the same size, but got " + std::to_string(size1) +
263 " and " + std::to_string(size2);
265 throw py::value_error(msg);
269inline void check_same_size(
const std::vector<array_ptr> &arrays,
270 const array_names &names)
272 if (arrays.empty()) {
276 auto first = arrays[0];
277 for (
size_t i = 1; i < arrays.size(); ++i) {
278 check_same_size(first, arrays[i], names);
282inline void common_checks(
const std::vector<array_ptr> &inputs,
283 const std::vector<array_ptr> &outputs,
284 const array_names &names)
286 check_writable(outputs, names);
288 check_c_contig(inputs, names);
289 check_c_contig(outputs, names);
291 auto exec_q = get_queue(inputs, outputs);
293 check_queue(inputs, names, exec_q);
294 check_queue(outputs, names, exec_q);
296 check_no_overlap(inputs, outputs, names);