38#include <sycl/sycl.hpp>
40#include "dpctl4pybind11.hpp"
41#include <pybind11/pybind11.h>
42#include <pybind11/stl.h>
45#include "utils/output_validation.hpp"
46#include "utils/type_dispatch.hpp"
47#include "utils/type_utils.hpp"
49namespace dpnp::extensions::window
51namespace py = pybind11;
52namespace td_ns = dpctl::tensor::type_dispatch;
54typedef sycl::event (*window_fn_ptr_t)(sycl::queue &,
57 const std::vector<sycl::event> &);
59template <
typename T,
template <
typename>
class Functor>
60sycl::event window_impl(sycl::queue &exec_q,
62 const std::size_t nelems,
63 const std::vector<sycl::event> &depends)
65 dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);
67 T *res =
reinterpret_cast<T *
>(result);
69 sycl::event window_ev = exec_q.submit([&](sycl::handler &cgh) {
70 cgh.depends_on(depends);
72 using WindowKernel = Functor<T>;
73 cgh.parallel_for<WindowKernel>(sycl::range<1>(nelems),
74 WindowKernel(res, nelems));
80template <
typename fnT,
typename T,
template <
typename>
typename FunctorT>
85 if constexpr (std::is_floating_point_v<T>) {
86 return window_impl<T, FunctorT>;
94template <
typename funcPtrT>
95std::tuple<size_t, char *, funcPtrT>
96 window_fn(sycl::queue &exec_q,
97 const dpctl::tensor::usm_ndarray &result,
98 const funcPtrT *window_dispatch_vector)
100 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
102 const int nd = result.get_ndim();
104 throw py::value_error(
"Array should be 1d");
107 if (!dpctl::utils::queues_are_compatible(exec_q, {result.get_queue()})) {
108 throw py::value_error(
109 "Execution queue is not compatible with allocation queue.");
112 const bool is_result_c_contig = result.is_c_contiguous();
113 if (!is_result_c_contig) {
114 throw py::value_error(
"The result array is not c-contiguous.");
117 const std::size_t nelems = result.get_size();
119 return std::make_tuple(nelems,
nullptr,
nullptr);
122 const int result_typenum = result.get_typenum();
123 auto array_types = td_ns::usm_ndarray_types();
124 const int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
125 funcPtrT fn = window_dispatch_vector[result_type_id];
128 throw std::runtime_error(
"Type of given array is not supported");
131 char *result_typeless_ptr = result.get_data();
132 return std::make_tuple(nelems, result_typeless_ptr, fn);
135inline std::pair<sycl::event, sycl::event>
136 py_window(sycl::queue &exec_q,
137 const dpctl::tensor::usm_ndarray &result,
138 const std::vector<sycl::event> &depends,
139 const window_fn_ptr_t *window_dispatch_vector)
141 auto [nelems, result_typeless_ptr, fn] =
142 window_fn<window_fn_ptr_t>(exec_q, result, window_dispatch_vector);
145 return std::make_pair(sycl::event{}, sycl::event{});
148 sycl::event window_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
149 sycl::event args_ev =
150 dpctl::utils::keep_args_alive(exec_q, {result}, {window_ev});
152 return std::make_pair(args_ev, window_ev);