26#include <pybind11/numpy.h> 
   27#include <pybind11/pybind11.h> 
   29#include "ext/common.hpp" 
   31#include "ext/validation_utils.hpp" 
   32#include "utils/memory_overlap.hpp" 
   34namespace td_ns = dpctl::tensor::type_dispatch;
 
   35namespace common = ext::common;
 
   37namespace ext::validation
 
   39inline sycl::queue get_queue(
const std::vector<array_ptr> &inputs,
 
   40                             const std::vector<array_ptr> &outputs)
 
   42    auto it = std::find_if(inputs.cbegin(), inputs.cend(),
 
   43                           [](
const array_ptr &arr) { return arr != nullptr; });
 
   45    if (it != inputs.cend()) {
 
   46        return (*it)->get_queue();
 
   49    it = std::find_if(outputs.cbegin(), outputs.cend(),
 
   50                      [](
const array_ptr &arr) { return arr != nullptr; });
 
   52    if (it != outputs.cend()) {
 
   53        return (*it)->get_queue();
 
   56    throw py::value_error(
"No input or output arrays found");
 
   59inline std::string name_of(
const array_ptr &arr, 
const array_names &names)
 
   61    auto name_it = names.find(arr);
 
   62    assert(name_it != names.end());
 
   64    if (name_it != names.end())
 
   65        return "'" + name_it->second + 
"'";
 
   70inline void check_writable(
const std::vector<array_ptr> &arrays,
 
   71                           const array_names &names)
 
   73    for (
const auto &arr : arrays) {
 
   74        if (arr != 
nullptr && !arr->is_writable()) {
 
   75            throw py::value_error(name_of(arr, names) +
 
   76                                  " parameter is not writable");
 
   81inline void check_c_contig(
const std::vector<array_ptr> &arrays,
 
   82                           const array_names &names)
 
   84    for (
const auto &arr : arrays) {
 
   85        if (arr != 
nullptr && !arr->is_c_contiguous()) {
 
   86            throw py::value_error(name_of(arr, names) +
 
   87                                  " parameter is not c-contiguos");
 
   92inline void check_queue(
const std::vector<array_ptr> &arrays,
 
   93                        const array_names &names,
 
   94                        const sycl::queue &exec_q)
 
   97        std::find_if(arrays.cbegin(), arrays.cend(), [&](
const array_ptr &arr) {
 
   98            return arr != nullptr && arr->get_queue() != exec_q;
 
  101    if (unequal_queue != arrays.cend()) {
 
  102        throw py::value_error(
 
  103            name_of(*unequal_queue, names) +
 
  104            " parameter has incompatible queue with other parameters");
 
  108inline void check_no_overlap(
const array_ptr &input,
 
  109                             const array_ptr &output,
 
  110                             const array_names &names)
 
  112    if (input == 
nullptr || output == 
nullptr) {
 
  116    const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
 
  117    const auto &same_logical_tensors =
 
  118        dpctl::tensor::overlap::SameLogicalTensors();
 
  120    if (overlap(*input, *output) && !same_logical_tensors(*input, *output)) {
 
  121        throw py::value_error(name_of(input, names) +
 
  122                              " has overlapping memory segments with " +
 
  123                              name_of(output, names));
 
  127inline void check_no_overlap(
const std::vector<array_ptr> &inputs,
 
  128                             const std::vector<array_ptr> &outputs,
 
  129                             const array_names &names)
 
  131    for (
const auto &input : inputs) {
 
  132        for (
const auto &output : outputs) {
 
  133            check_no_overlap(input, output, names);
 
  138inline void check_num_dims(
const array_ptr &arr,
 
  140                           const array_names &names)
 
  142    size_t arr_n_dim = arr != 
nullptr ? arr->get_ndim() : 0;
 
  143    if (arr != 
nullptr && arr_n_dim != ndim) {
 
  144        throw py::value_error(
"Array " + name_of(arr, names) + 
" must be " +
 
  145                              std::to_string(ndim) + 
"D, but got " +
 
  146                              std::to_string(arr_n_dim) + 
"D.");
 
  150inline void check_num_dims(
const std::vector<array_ptr> &arrays,
 
  152                           const array_names &names)
 
  154    for (
const auto &arr : arrays) {
 
  155        check_num_dims(arr, ndim, names);
 
  159inline void check_max_dims(
const array_ptr &arr,
 
  160                           const size_t max_ndim,
 
  161                           const array_names &names)
 
  163    size_t arr_n_dim = arr != 
nullptr ? arr->get_ndim() : 0;
 
  164    if (arr != 
nullptr && arr_n_dim > max_ndim) {
 
  165        throw py::value_error(
 
  166            "Array " + name_of(arr, names) + 
" must have no more than " +
 
  167            std::to_string(max_ndim) + 
" dimensions, but got " +
 
  168            std::to_string(arr_n_dim) + 
" dimensions.");
 
  172inline void check_size_at_least(
const array_ptr &arr,
 
  174                                const array_names &names)
 
  176    size_t arr_size = arr != 
nullptr ? arr->get_size() : 0;
 
  177    if (arr != 
nullptr && arr_size < size) {
 
  178        throw py::value_error(
"Array " + name_of(arr, names) +
 
  179                              " must have at least " + std::to_string(size) +
 
  180                              " elements, but got " + std::to_string(arr_size) +
 
  185inline void check_has_dtype(
const array_ptr &arr,
 
  186                            const typenum_t dtype,
 
  187                            const array_names &names)
 
  189    if (arr == 
nullptr) {
 
  193    auto array_types = td_ns::usm_ndarray_types();
 
  194    int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
 
  195    int expected_type_id = 
static_cast<int>(dtype);
 
  197    if (array_type_id != expected_type_id) {
 
  198        py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
 
  199        py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
 
  201        std::string msg = 
"Array " + name_of(arr, names) + 
" must have dtype " +
 
  202                          std::string(py::str(dtype_py)) + 
", but got " +
 
  203                          std::string(py::str(actual_dtype));
 
  205        throw py::value_error(msg);
 
  209inline void check_same_dtype(
const array_ptr &arr1,
 
  210                             const array_ptr &arr2,
 
  211                             const array_names &names)
 
  213    if (arr1 == 
nullptr || arr2 == 
nullptr) {
 
  217    auto array_types = td_ns::usm_ndarray_types();
 
  218    int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
 
  219    int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
 
  221    if (first_type_id != second_type_id) {
 
  222        py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
 
  223        py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
 
  225        std::string msg = 
"Arrays " + name_of(arr1, names) + 
" and " +
 
  226                          name_of(arr2, names) +
 
  227                          " must have the same dtype, but got " +
 
  228                          std::string(py::str(first_dtype)) + 
" and " +
 
  229                          std::string(py::str(second_dtype));
 
  231        throw py::value_error(msg);
 
  235inline void check_same_dtype(
const std::vector<array_ptr> &arrays,
 
  236                             const array_names &names)
 
  238    if (arrays.empty()) {
 
  242    const auto *first = arrays[0];
 
  243    for (
size_t i = 1; i < arrays.size(); ++i) {
 
  244        check_same_dtype(first, arrays[i], names);
 
  248inline void check_same_size(
const array_ptr &arr1,
 
  249                            const array_ptr &arr2,
 
  250                            const array_names &names)
 
  252    if (arr1 == 
nullptr || arr2 == 
nullptr) {
 
  256    auto size1 = arr1->get_size();
 
  257    auto size2 = arr2->get_size();
 
  259    if (size1 != size2) {
 
  261            "Arrays " + name_of(arr1, names) + 
" and " + name_of(arr2, names) +
 
  262            " must have the same size, but got " + std::to_string(size1) +
 
  263            " and " + std::to_string(size2);
 
  265        throw py::value_error(msg);
 
  269inline void check_same_size(
const std::vector<array_ptr> &arrays,
 
  270                            const array_names &names)
 
  272    if (arrays.empty()) {
 
  276    auto first = arrays[0];
 
  277    for (
size_t i = 1; i < arrays.size(); ++i) {
 
  278        check_same_size(first, arrays[i], names);
 
  282inline void common_checks(
const std::vector<array_ptr> &inputs,
 
  283                          const std::vector<array_ptr> &outputs,
 
  284                          const array_names &names)
 
  286    check_writable(outputs, names);
 
  288    check_c_contig(inputs, names);
 
  289    check_c_contig(outputs, names);
 
  291    auto exec_q = get_queue(inputs, outputs);
 
  293    check_queue(inputs, names, exec_q);
 
  294    check_queue(outputs, names, exec_q);
 
  296    check_no_overlap(inputs, outputs, names);