26#include "ext/validation_utils.hpp"
27#include "utils/memory_overlap.hpp"
29namespace ext::validation
31inline sycl::queue get_queue(
const std::vector<array_ptr> &inputs,
32 const std::vector<array_ptr> &outputs)
34 auto it = std::find_if(inputs.cbegin(), inputs.cend(),
35 [](
const array_ptr &arr) { return arr != nullptr; });
37 if (it != inputs.cend()) {
38 return (*it)->get_queue();
41 it = std::find_if(outputs.cbegin(), outputs.cend(),
42 [](
const array_ptr &arr) { return arr != nullptr; });
44 if (it != outputs.cend()) {
45 return (*it)->get_queue();
48 throw py::value_error(
"No input or output arrays found");
51inline std::string name_of(
const array_ptr &arr,
const array_names &names)
53 auto name_it = names.find(arr);
54 assert(name_it != names.end());
56 if (name_it != names.end())
57 return "'" + name_it->second +
"'";
62inline void check_writable(
const std::vector<array_ptr> &arrays,
63 const array_names &names)
65 for (
const auto &arr : arrays) {
66 if (arr !=
nullptr && !arr->is_writable()) {
67 throw py::value_error(name_of(arr, names) +
68 " parameter is not writable");
73inline void check_c_contig(
const std::vector<array_ptr> &arrays,
74 const array_names &names)
76 for (
const auto &arr : arrays) {
77 if (arr !=
nullptr && !arr->is_c_contiguous()) {
78 throw py::value_error(name_of(arr, names) +
79 " parameter is not c-contiguos");
84inline void check_queue(
const std::vector<array_ptr> &arrays,
85 const array_names &names,
86 const sycl::queue &exec_q)
89 std::find_if(arrays.cbegin(), arrays.cend(), [&](
const array_ptr &arr) {
90 return arr != nullptr && arr->get_queue() != exec_q;
93 if (unequal_queue != arrays.cend()) {
94 throw py::value_error(
95 name_of(*unequal_queue, names) +
96 " parameter has incompatible queue with other parameters");
100inline void check_no_overlap(
const array_ptr &input,
101 const array_ptr &output,
102 const array_names &names)
104 if (input ==
nullptr || output ==
nullptr) {
108 const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
110 if (overlap(*input, *output)) {
111 throw py::value_error(name_of(input, names) +
112 " has overlapping memory segments with " +
113 name_of(output, names));
117inline void check_no_overlap(
const std::vector<array_ptr> &inputs,
118 const std::vector<array_ptr> &outputs,
119 const array_names &names)
121 for (
const auto &input : inputs) {
122 for (
const auto &output : outputs) {
123 check_no_overlap(input, output, names);
128inline void check_num_dims(
const array_ptr &arr,
130 const array_names &names)
132 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
133 if (arr !=
nullptr && arr_n_dim != ndim) {
134 throw py::value_error(
"Array " + name_of(arr, names) +
" must be " +
135 std::to_string(ndim) +
"D, but got " +
136 std::to_string(arr_n_dim) +
"D.");
140inline void check_max_dims(
const array_ptr &arr,
141 const size_t max_ndim,
142 const array_names &names)
144 size_t arr_n_dim = arr !=
nullptr ? arr->get_ndim() : 0;
145 if (arr !=
nullptr && arr_n_dim > max_ndim) {
146 throw py::value_error(
147 "Array " + name_of(arr, names) +
" must have no more than " +
148 std::to_string(max_ndim) +
" dimensions, but got " +
149 std::to_string(arr_n_dim) +
" dimensions.");
153inline void check_size_at_least(
const array_ptr &arr,
155 const array_names &names)
157 size_t arr_size = arr !=
nullptr ? arr->get_size() : 0;
158 if (arr !=
nullptr && arr_size < size) {
159 throw py::value_error(
"Array " + name_of(arr, names) +
160 " must have at least " + std::to_string(size) +
161 " elements, but got " + std::to_string(arr_size) +
166inline void common_checks(
const std::vector<array_ptr> &inputs,
167 const std::vector<array_ptr> &outputs,
168 const array_names &names)
170 check_writable(outputs, names);
172 check_c_contig(inputs, names);
173 check_c_contig(outputs, names);
175 auto exec_q = get_queue(inputs, outputs);
177 check_queue(inputs, names, exec_q);
178 check_queue(outputs, names, exec_q);
180 check_no_overlap(inputs, outputs, names);