DPNP C++ backend kernel library 0.20.0dev4
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common.hpp
1//*****************************************************************************
2// Copyright (c) 2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <cstddef>
32#include <stdexcept>
33#include <tuple>
34#include <type_traits>
35#include <utility>
36#include <vector>
37
38#include <sycl/sycl.hpp>
39
40#include "dpctl4pybind11.hpp"
41#include <pybind11/pybind11.h>
42#include <pybind11/stl.h>
43
44// dpctl tensor headers
45#include "utils/output_validation.hpp"
46#include "utils/type_dispatch.hpp"
47#include "utils/type_utils.hpp"
48
49namespace dpnp::extensions::window
50{
51namespace py = pybind11;
52namespace td_ns = dpctl::tensor::type_dispatch;
53
54typedef sycl::event (*window_fn_ptr_t)(sycl::queue &,
55 char *,
56 const std::size_t,
57 const std::vector<sycl::event> &);
58
59template <typename T, template <typename> class Functor>
60sycl::event window_impl(sycl::queue &exec_q,
61 char *result,
62 const std::size_t nelems,
63 const std::vector<sycl::event> &depends)
64{
65 dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);
66
67 T *res = reinterpret_cast<T *>(result);
68
69 sycl::event window_ev = exec_q.submit([&](sycl::handler &cgh) {
70 cgh.depends_on(depends);
71
72 using WindowKernel = Functor<T>;
73 cgh.parallel_for<WindowKernel>(sycl::range<1>(nelems),
74 WindowKernel(res, nelems));
75 });
76
77 return window_ev;
78}
79
80template <typename fnT, typename T, template <typename> typename FunctorT>
81struct Factory
82{
83 fnT get()
84 {
85 if constexpr (std::is_floating_point_v<T>) {
86 return window_impl<T, FunctorT>;
87 }
88 else {
89 return nullptr;
90 }
91 }
92};
93
94template <typename funcPtrT>
95std::tuple<size_t, char *, funcPtrT>
96 window_fn(sycl::queue &exec_q,
97 const dpctl::tensor::usm_ndarray &result,
98 const funcPtrT *window_dispatch_vector)
99{
100 dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
101
102 const int nd = result.get_ndim();
103 if (nd != 1) {
104 throw py::value_error("Array should be 1d");
105 }
106
107 if (!dpctl::utils::queues_are_compatible(exec_q, {result.get_queue()})) {
108 throw py::value_error(
109 "Execution queue is not compatible with allocation queue.");
110 }
111
112 const bool is_result_c_contig = result.is_c_contiguous();
113 if (!is_result_c_contig) {
114 throw py::value_error("The result array is not c-contiguous.");
115 }
116
117 const std::size_t nelems = result.get_size();
118 if (nelems == 0) {
119 return std::make_tuple(nelems, nullptr, nullptr);
120 }
121
122 const int result_typenum = result.get_typenum();
123 auto array_types = td_ns::usm_ndarray_types();
124 const int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
125 funcPtrT fn = window_dispatch_vector[result_type_id];
126
127 if (fn == nullptr) {
128 throw std::runtime_error("Type of given array is not supported");
129 }
130
131 char *result_typeless_ptr = result.get_data();
132 return std::make_tuple(nelems, result_typeless_ptr, fn);
133}
134
135inline std::pair<sycl::event, sycl::event>
136 py_window(sycl::queue &exec_q,
137 const dpctl::tensor::usm_ndarray &result,
138 const std::vector<sycl::event> &depends,
139 const window_fn_ptr_t *window_dispatch_vector)
140{
141 auto [nelems, result_typeless_ptr, fn] =
142 window_fn<window_fn_ptr_t>(exec_q, result, window_dispatch_vector);
143
144 if (nelems == 0) {
145 return std::make_pair(sycl::event{}, sycl::event{});
146 }
147
148 sycl::event window_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
149 sycl::event args_ev =
150 dpctl::utils::keep_args_alive(exec_q, {result}, {window_ev});
151
152 return std::make_pair(args_ev, window_ev);
153}
154} // namespace dpnp::extensions::window