28#include <pybind11/pybind11.h> 
   29#include <pybind11/stl.h> 
   30#include <sycl/sycl.hpp> 
   32#include "dpctl4pybind11.hpp" 
   35#include "utils/output_validation.hpp" 
   36#include "utils/type_dispatch.hpp" 
   37#include "utils/type_utils.hpp" 
   39namespace dpnp::extensions::window
 
   42namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
 
   44namespace py = pybind11;
 
   46typedef sycl::event (*window_fn_ptr_t)(sycl::queue &,
 
   49                                       const std::vector<sycl::event> &);
 
   51template <
typename T, 
template <
typename> 
class Functor>
 
   52sycl::event window_impl(sycl::queue &exec_q,
 
   54                        const std::size_t nelems,
 
   55                        const std::vector<sycl::event> &depends)
 
   57    dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);
 
   59    T *res = 
reinterpret_cast<T *
>(result);
 
   61    sycl::event window_ev = exec_q.submit([&](sycl::handler &cgh) {
 
   62        cgh.depends_on(depends);
 
   64        using WindowKernel = Functor<T>;
 
   65        cgh.parallel_for<WindowKernel>(sycl::range<1>(nelems),
 
   66                                       WindowKernel(res, nelems));
 
   72template <
typename funcPtrT>
 
   73std::tuple<size_t, char *, funcPtrT>
 
   74    window_fn(sycl::queue &exec_q,
 
   75              const dpctl::tensor::usm_ndarray &result,
 
   76              const funcPtrT *window_dispatch_vector)
 
   78    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
 
   80    const int nd = result.get_ndim();
 
   82        throw py::value_error(
"Array should be 1d");
 
   85    if (!dpctl::utils::queues_are_compatible(exec_q, {result.get_queue()})) {
 
   86        throw py::value_error(
 
   87            "Execution queue is not compatible with allocation queue.");
 
   90    const bool is_result_c_contig = result.is_c_contiguous();
 
   91    if (!is_result_c_contig) {
 
   92        throw py::value_error(
"The result array is not c-contiguous.");
 
   95    const std::size_t nelems = result.get_size();
 
   97        return std::make_tuple(nelems, 
nullptr, 
nullptr);
 
  100    const int result_typenum = result.get_typenum();
 
  101    auto array_types = dpctl_td_ns::usm_ndarray_types();
 
  102    const int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
 
  103    funcPtrT fn = window_dispatch_vector[result_type_id];
 
  106        throw std::runtime_error(
"Type of given array is not supported");
 
  109    char *result_typeless_ptr = result.get_data();
 
  110    return std::make_tuple(nelems, result_typeless_ptr, fn);
 
  113inline std::pair<sycl::event, sycl::event>
 
  114    py_window(sycl::queue &exec_q,
 
  115              const dpctl::tensor::usm_ndarray &result,
 
  116              const std::vector<sycl::event> &depends,
 
  117              const window_fn_ptr_t *window_dispatch_vector)
 
  119    auto [nelems, result_typeless_ptr, fn] =
 
  120        window_fn<window_fn_ptr_t>(exec_q, result, window_dispatch_vector);
 
  123        return std::make_pair(sycl::event{}, sycl::event{});
 
  126    sycl::event window_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
 
  127    sycl::event args_ev =
 
  128        dpctl::utils::keep_args_alive(exec_q, {result}, {window_ev});
 
  130    return std::make_pair(args_ev, window_ev);