DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
//   may be used to endorse or promote products derived from this software
//   without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************
28
29#pragma once
30
31// utils extension header
32#include "ext/common.hpp"
33
34namespace ext_ns = ext::common;
35
/**
 * @brief Expands to the definitions used to fill the dispatch vectors of a
 *        unary function that writes a single output.
 *
 * For an argument `<name>` the expansion defines, in the enclosing scope:
 *  - `<name>_contig_kernel` and `<name>_strided_kernel`: forward-declared class
 *    templates that are only passed on as template arguments;
 *  - `<name>_contig_impl<argTy>()` and `<name>_strided_impl<argTy>()`: wrappers
 *    forwarding their arguments to `ew_cmn_ns::unary_contig_impl` and
 *    `ew_cmn_ns::unary_strided_impl` respectively;
 *  - `ContigFactory<fnT, T>` and `StridedFactory<fnT, T>`: `get()` returns
 *    `nullptr` when `OutputType<T>::value_type` is `void`, otherwise the
 *    matching `<name>_*_impl<T>` function;
 *  - `TypeMapFactory<fnT, T>`: `get()` returns
 *    `td_ns::GetTypeid<typename OutputType<T>::value_type>{}.get()`; the
 *    `enable_if_t` in its signature requires `fnT` to be `int`;
 *  - `populate_<name>_dispatch_vectors()`: passes the three factories to
 *    `ext_ns::init_dispatch_vector` for `<name>_contig_dispatch_vector`,
 *    `<name>_strided_dispatch_vector` and `<name>_output_typeid_vector`.
 *
 * The expansion site must already have in scope: `sycl`, `std::vector`,
 * `<type_traits>`, `py`, `ew_cmn_ns`, `td_ns`, `OutputType`, `ContigFunctor`,
 * `StridedFunctor`, `unary_contig_impl_fn_ptr_t`,
 * `unary_strided_impl_fn_ptr_t` and the three `<name>_*` vector objects.
 *
 * @note The factory struct names do not depend on the argument, so the macro
 *       can be expanded at most once per scope.
 * @note No trailing semicolon is emitted; both `MACRO(x)` and `MACRO(x);` are
 *       valid at namespace scope.
 *
 * @param fn_name Identifier prefix pasted into every generated name.
 */
#define MACRO_POPULATE_DISPATCH_VECTORS(fn_name) \
    template <typename T1, typename T2, unsigned int vec_sz, \
              unsigned int n_vecs> \
    class fn_name##_contig_kernel; \
    \
    /* Forwards to ew_cmn_ns::unary_contig_impl with this kernel name. */ \
    template <typename argTy> \
    sycl::event fn_name##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
        const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
                                            fn_name##_contig_kernel>( \
            exec_q, nelems, arg_p, res_p, depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            using resTy = typename OutputType<T>::value_type; \
            /* A void result type means there is nothing to dispatch to. */ \
            if constexpr (std::is_same_v<resTy, void>) { \
                return nullptr; \
            } \
            else { \
                return fn_name##_contig_impl<T>; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same_v<fnT, int>, int> get() \
        { \
            using resTy = typename OutputType<T>::value_type; \
            return td_ns::GetTypeid<resTy>{}.get(); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename T3> \
    class fn_name##_strided_kernel; \
    \
    /* Forwards to ew_cmn_ns::unary_strided_impl with this kernel name. */ \
    template <typename argTy> \
    sycl::event fn_name##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg_p, \
        py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::unary_strided_impl< \
            argTy, OutputType, StridedFunctor, fn_name##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
            res_offset, depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            using resTy = typename OutputType<T>::value_type; \
            if constexpr (std::is_same_v<resTy, void>) { \
                return nullptr; \
            } \
            else { \
                return fn_name##_strided_impl<T>; \
            } \
        } \
    }; \
    \
    void populate_##fn_name##_dispatch_vectors() \
    { \
        ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
                                     ContigFactory>( \
            fn_name##_contig_dispatch_vector); \
        ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
                                     StridedFactory>( \
            fn_name##_strided_dispatch_vector); \
        ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
            fn_name##_output_typeid_vector); \
    }
127
/**
 * @brief Expands to the definitions used to fill the dispatch vectors of a
 *        unary function that writes two outputs.
 *
 * For an argument `<name>` the expansion defines, in the enclosing scope:
 *  - `<name>_contig_kernel` and `<name>_strided_kernel`: forward-declared class
 *    templates that are only passed on as template arguments;
 *  - `<name>_contig_impl<argTy>()` and `<name>_strided_impl<argTy>()`: wrappers
 *    forwarding their arguments to `ew_cmn_ns::unary_two_outputs_contig_impl`
 *    and `ew_cmn_ns::unary_two_outputs_strided_impl` respectively;
 *  - `ContigFactory<fnT, T>` and `StridedFactory<fnT, T>`: `get()` returns
 *    `nullptr` when either `OutputType<T>::value_type1` or
 *    `OutputType<T>::value_type2` is `void`, otherwise the matching
 *    `<name>_*_impl<T>` function;
 *  - `TypeMapFactory<fnT, T>`: `get()` returns a `std::pair<int, int>` holding
 *    `td_ns::GetTypeid<...>{}.get()` of `value_type1` and `value_type2`; the
 *    `enable_if_t` in its signature requires `fnT` to be
 *    `std::pair<int, int>`;
 *  - `populate_<name>_dispatch_vectors()`: passes the three factories to
 *    `ext_ns::init_dispatch_vector` for `<name>_contig_dispatch_vector`,
 *    `<name>_strided_dispatch_vector` and `<name>_output_typeid_vector`.
 *
 * The expansion site must already have in scope: `sycl`, `std::vector`,
 * `std::pair`, `<type_traits>`, `py`, `ew_cmn_ns`, `td_ns`, `OutputType`,
 * `ContigFunctor`, `StridedFunctor`, `unary_two_outputs_contig_impl_fn_ptr_t`,
 * `unary_two_outputs_strided_impl_fn_ptr_t` and the three `<name>_*` vector
 * objects.
 *
 * @note The factory struct names do not depend on the argument, so the macro
 *       can be expanded at most once per scope.
 * @note No trailing semicolon is emitted; both `MACRO(x)` and `MACRO(x);` are
 *       valid at namespace scope.
 *
 * @param fn_name Identifier prefix pasted into every generated name.
 */
#define MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(fn_name) \
    template <typename T1, typename T2, typename T3, unsigned int vec_sz, \
              unsigned int n_vecs> \
    class fn_name##_contig_kernel; \
    \
    /* Forwards to ew_cmn_ns::unary_two_outputs_contig_impl. */ \
    template <typename argTy> \
    sycl::event fn_name##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res1_p, \
        char *res2_p, const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::unary_two_outputs_contig_impl< \
            argTy, OutputType, ContigFunctor, fn_name##_contig_kernel>( \
            exec_q, nelems, arg_p, res1_p, res2_p, depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            using resTy1 = typename OutputType<T>::value_type1; \
            using resTy2 = typename OutputType<T>::value_type2; \
            /* Either result type being void disables the entry. */ \
            if constexpr (std::is_same_v<resTy1, void> || \
                          std::is_same_v<resTy2, void>) \
            { \
                return nullptr; \
            } \
            else { \
                return fn_name##_contig_impl<T>; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same_v<fnT, std::pair<int, int>>, \
                         std::pair<int, int>> \
        get() \
        { \
            using resTy1 = typename OutputType<T>::value_type1; \
            using resTy2 = typename OutputType<T>::value_type2; \
            return std::make_pair(td_ns::GetTypeid<resTy1>{}.get(), \
                                  td_ns::GetTypeid<resTy2>{}.get()); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename T3, typename T4> \
    class fn_name##_strided_kernel; \
    \
    /* Forwards to ew_cmn_ns::unary_two_outputs_strided_impl. */ \
    template <typename argTy> \
    sycl::event fn_name##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg_p, \
        py::ssize_t arg_offset, char *res1_p, py::ssize_t res1_offset, \
        char *res2_p, py::ssize_t res2_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::unary_two_outputs_strided_impl< \
            argTy, OutputType, StridedFunctor, fn_name##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p, \
            res1_offset, res2_p, res2_offset, depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            using resTy1 = typename OutputType<T>::value_type1; \
            using resTy2 = typename OutputType<T>::value_type2; \
            if constexpr (std::is_same_v<resTy1, void> || \
                          std::is_same_v<resTy2, void>) \
            { \
                return nullptr; \
            } \
            else { \
                return fn_name##_strided_impl<T>; \
            } \
        } \
    }; \
    \
    void populate_##fn_name##_dispatch_vectors() \
    { \
        ext_ns::init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t, \
                                     ContigFactory>( \
            fn_name##_contig_dispatch_vector); \
        ext_ns::init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t, \
                                     StridedFactory>( \
            fn_name##_strided_dispatch_vector); \
        ext_ns::init_dispatch_vector<std::pair<int, int>, TypeMapFactory>( \
            fn_name##_output_typeid_vector); \
    }
230
/**
 * @brief Expands to the definitions used to fill the dispatch tables of a
 *        binary function that writes a single output.
 *
 * For an argument `<name>` the expansion defines, in the enclosing scope:
 *  - `<name>_contig_kernel` and `<name>_strided_kernel`: forward-declared class
 *    templates that are only passed on as template arguments;
 *  - `<name>_contig_impl<argTy1, argTy2>()` and
 *    `<name>_strided_impl<argTy1, argTy2>()`: wrappers forwarding their
 *    arguments to `ew_cmn_ns::binary_contig_impl` and
 *    `ew_cmn_ns::binary_strided_impl` respectively;
 *  - `ContigFactory<fnT, T1, T2>` and `StridedFactory<fnT, T1, T2>`: `get()`
 *    returns `nullptr` when `OutputType<T1, T2>::value_type` is `void`,
 *    otherwise the matching `<name>_*_impl<T1, T2>` function;
 *  - `TypeMapFactory<fnT, T1, T2>`: `get()` returns
 *    `td_ns::GetTypeid<typename OutputType<T1, T2>::value_type>{}.get()`; the
 *    `enable_if_t` in its signature requires `fnT` to be `int`;
 *  - `populate_<name>_dispatch_tables()`: passes the three factories to
 *    `ext_ns::init_dispatch_table` for `<name>_contig_dispatch_table`,
 *    `<name>_strided_dispatch_table` and `<name>_output_typeid_table`.
 *
 * The expansion site must already have in scope: `sycl`, `std::vector`,
 * `<type_traits>`, `py`, `ew_cmn_ns`, `td_ns`, `OutputType`, `ContigFunctor`,
 * `StridedFunctor`, `binary_contig_impl_fn_ptr_t`,
 * `binary_strided_impl_fn_ptr_t` and the three `<name>_*` table objects.
 *
 * @note The factory struct names do not depend on the argument, so the macro
 *       can be expanded at most once per scope.
 * @note No trailing semicolon is emitted; both `MACRO(x)` and `MACRO(x);` are
 *       valid at namespace scope.
 *
 * @param fn_name Identifier prefix pasted into every generated name.
 */
#define MACRO_POPULATE_DISPATCH_TABLES(fn_name) \
    template <typename argT1, typename argT2, typename resT, \
              unsigned int vec_sz, unsigned int n_vecs> \
    class fn_name##_contig_kernel; \
    \
    /* Forwards to ew_cmn_ns::binary_contig_impl with this kernel name. */ \
    template <typename argTy1, typename argTy2> \
    sycl::event fn_name##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
                                             ContigFunctor, \
                                             fn_name##_contig_kernel>( \
            exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
            res_offset, depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            using resTy = typename OutputType<T1, T2>::value_type; \
            /* A void result type means there is nothing to dispatch to. */ \
            if constexpr (std::is_same_v<resTy, void>) { \
                return nullptr; \
            } \
            else { \
                return fn_name##_contig_impl<T1, T2>; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T1, typename T2> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same_v<fnT, int>, int> get() \
        { \
            using resTy = typename OutputType<T1, T2>::value_type; \
            return td_ns::GetTypeid<resTy>{}.get(); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename resT, typename IndexerT> \
    class fn_name##_strided_kernel; \
    \
    /* Forwards to ew_cmn_ns::binary_strided_impl with this kernel name. */ \
    template <typename argTy1, typename argTy2> \
    sycl::event fn_name##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
                                              StridedFunctor, \
                                              fn_name##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
            arg2_p, arg2_offset, res_p, res_offset, depends, \
            additional_depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            using resTy = typename OutputType<T1, T2>::value_type; \
            if constexpr (std::is_same_v<resTy, void>) { \
                return nullptr; \
            } \
            else { \
                return fn_name##_strided_impl<T1, T2>; \
            } \
        } \
    }; \
    \
    void populate_##fn_name##_dispatch_tables() \
    { \
        ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
                                    ContigFactory>( \
            fn_name##_contig_dispatch_table); \
        ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
                                    StridedFactory>( \
            fn_name##_strided_dispatch_table); \
        ext_ns::init_dispatch_table<int, TypeMapFactory>( \
            fn_name##_output_typeid_table); \
    }