DPNP C++ backend kernel library 0.20.0dev4
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <type_traits>
32#include <utility>
33#include <vector>
34
35#include <pybind11/pybind11.h>
36
37// utils extension header
38#include "ext/common.hpp"
39
40namespace ext_ns = ext::common;
41
46#define MACRO_POPULATE_DISPATCH_VECTORS(__name__) \
47 template <typename T1, typename T2, unsigned int vec_sz, \
48 unsigned int n_vecs> \
49 class __name__##_contig_kernel; \
50 \
51 template <typename argTy> \
52 sycl::event __name__##_contig_impl( \
53 sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
54 const std::vector<sycl::event> &depends = {}) \
55 { \
56 return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
57 __name__##_contig_kernel>( \
58 exec_q, nelems, arg_p, res_p, depends); \
59 } \
60 \
61 template <typename fnT, typename T> \
62 struct ContigFactory \
63 { \
64 fnT get() \
65 { \
66 if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
67 void>) { \
68 fnT fn = nullptr; \
69 return fn; \
70 } \
71 else { \
72 fnT fn = __name__##_contig_impl<T>; \
73 return fn; \
74 } \
75 } \
76 }; \
77 \
78 template <typename fnT, typename T> \
79 struct TypeMapFactory \
80 { \
81 std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
82 { \
83 using rT = typename OutputType<T>::value_type; \
84 return td_ns::GetTypeid<rT>{}.get(); \
85 } \
86 }; \
87 \
88 template <typename T1, typename T2, typename T3> \
89 class __name__##_strided_kernel; \
90 \
91 template <typename argTy> \
92 sycl::event __name__##_strided_impl( \
93 sycl::queue &exec_q, size_t nelems, int nd, \
94 const py::ssize_t *shape_and_strides, const char *arg_p, \
95 py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
96 const std::vector<sycl::event> &depends, \
97 const std::vector<sycl::event> &additional_depends) \
98 { \
99 return ew_cmn_ns::unary_strided_impl< \
100 argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
101 exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
102 res_offset, depends, additional_depends); \
103 } \
104 \
105 template <typename fnT, typename T> \
106 struct StridedFactory \
107 { \
108 fnT get() \
109 { \
110 if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
111 void>) { \
112 fnT fn = nullptr; \
113 return fn; \
114 } \
115 else { \
116 fnT fn = __name__##_strided_impl<T>; \
117 return fn; \
118 } \
119 } \
120 }; \
121 \
122 void populate_##__name__##_dispatch_vectors(void) \
123 { \
124 ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
125 ContigFactory>( \
126 __name__##_contig_dispatch_vector); \
127 ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
128 StridedFactory>( \
129 __name__##_strided_dispatch_vector); \
130 ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
131 __name__##_output_typeid_vector); \
132 };
133
138#define MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(__name__) \
139 template <typename T1, typename T2, typename T3, unsigned int vec_sz, \
140 unsigned int n_vecs> \
141 class __name__##_contig_kernel; \
142 \
143 template <typename argTy> \
144 sycl::event __name__##_contig_impl( \
145 sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res1_p, \
146 char *res2_p, const std::vector<sycl::event> &depends = {}) \
147 { \
148 return ew_cmn_ns::unary_two_outputs_contig_impl< \
149 argTy, OutputType, ContigFunctor, __name__##_contig_kernel>( \
150 exec_q, nelems, arg_p, res1_p, res2_p, depends); \
151 } \
152 \
153 template <typename fnT, typename T> \
154 struct ContigFactory \
155 { \
156 fnT get() \
157 { \
158 if constexpr (std::is_same_v<typename OutputType<T>::value_type1, \
159 void> || \
160 std::is_same_v<typename OutputType<T>::value_type2, \
161 void>) { \
162 fnT fn = nullptr; \
163 return fn; \
164 } \
165 else { \
166 fnT fn = __name__##_contig_impl<T>; \
167 return fn; \
168 } \
169 } \
170 }; \
171 \
172 template <typename fnT, typename T> \
173 struct TypeMapFactory \
174 { \
175 std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value, \
176 std::pair<int, int>> \
177 get() \
178 { \
179 using rT1 = typename OutputType<T>::value_type1; \
180 using rT2 = typename OutputType<T>::value_type2; \
181 return std::make_pair(td_ns::GetTypeid<rT1>{}.get(), \
182 td_ns::GetTypeid<rT2>{}.get()); \
183 } \
184 }; \
185 \
186 template <typename T1, typename T2, typename T3, typename T4> \
187 class __name__##_strided_kernel; \
188 \
189 template <typename argTy> \
190 sycl::event __name__##_strided_impl( \
191 sycl::queue &exec_q, size_t nelems, int nd, \
192 const py::ssize_t *shape_and_strides, const char *arg_p, \
193 py::ssize_t arg_offset, char *res1_p, py::ssize_t res1_offset, \
194 char *res2_p, py::ssize_t res2_offset, \
195 const std::vector<sycl::event> &depends, \
196 const std::vector<sycl::event> &additional_depends) \
197 { \
198 return ew_cmn_ns::unary_two_outputs_strided_impl< \
199 argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
200 exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p, \
201 res1_offset, res2_p, res2_offset, depends, additional_depends); \
202 } \
203 \
204 template <typename fnT, typename T> \
205 struct StridedFactory \
206 { \
207 fnT get() \
208 { \
209 if constexpr (std::is_same_v<typename OutputType<T>::value_type1, \
210 void> || \
211 std::is_same_v<typename OutputType<T>::value_type2, \
212 void>) { \
213 fnT fn = nullptr; \
214 return fn; \
215 } \
216 else { \
217 fnT fn = __name__##_strided_impl<T>; \
218 return fn; \
219 } \
220 } \
221 }; \
222 \
223 void populate_##__name__##_dispatch_vectors(void) \
224 { \
225 ext_ns::init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t, \
226 ContigFactory>( \
227 __name__##_contig_dispatch_vector); \
228 ext_ns::init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t, \
229 StridedFactory>( \
230 __name__##_strided_dispatch_vector); \
231 ext_ns::init_dispatch_vector<std::pair<int, int>, TypeMapFactory>( \
232 __name__##_output_typeid_vector); \
233 };
234
239#define MACRO_POPULATE_DISPATCH_TABLES(__name__) \
240 template <typename argT1, typename argT2, typename resT, \
241 unsigned int vec_sz, unsigned int n_vecs> \
242 class __name__##_contig_kernel; \
243 \
244 template <typename argTy1, typename argTy2> \
245 sycl::event __name__##_contig_impl( \
246 sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
247 py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
248 char *res_p, py::ssize_t res_offset, \
249 const std::vector<sycl::event> &depends = {}) \
250 { \
251 return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
252 ContigFunctor, \
253 __name__##_contig_kernel>( \
254 exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
255 res_offset, depends); \
256 } \
257 \
258 template <typename fnT, typename T1, typename T2> \
259 struct ContigFactory \
260 { \
261 fnT get() \
262 { \
263 if constexpr (std::is_same_v< \
264 typename OutputType<T1, T2>::value_type, \
265 void>) { \
266 \
267 fnT fn = nullptr; \
268 return fn; \
269 } \
270 else { \
271 fnT fn = __name__##_contig_impl<T1, T2>; \
272 return fn; \
273 } \
274 } \
275 }; \
276 \
277 template <typename fnT, typename T1, typename T2> \
278 struct TypeMapFactory \
279 { \
280 std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
281 { \
282 using rT = typename OutputType<T1, T2>::value_type; \
283 return td_ns::GetTypeid<rT>{}.get(); \
284 } \
285 }; \
286 \
287 template <typename T1, typename T2, typename resT, typename IndexerT> \
288 class __name__##_strided_kernel; \
289 \
290 template <typename argTy1, typename argTy2> \
291 sycl::event __name__##_strided_impl( \
292 sycl::queue &exec_q, size_t nelems, int nd, \
293 const py::ssize_t *shape_and_strides, const char *arg1_p, \
294 py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
295 char *res_p, py::ssize_t res_offset, \
296 const std::vector<sycl::event> &depends, \
297 const std::vector<sycl::event> &additional_depends) \
298 { \
299 return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
300 StridedFunctor, \
301 __name__##_strided_kernel>( \
302 exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
303 arg2_p, arg2_offset, res_p, res_offset, depends, \
304 additional_depends); \
305 } \
306 \
307 template <typename fnT, typename T1, typename T2> \
308 struct StridedFactory \
309 { \
310 fnT get() \
311 { \
312 if constexpr (std::is_same_v< \
313 typename OutputType<T1, T2>::value_type, \
314 void>) { \
315 fnT fn = nullptr; \
316 return fn; \
317 } \
318 else { \
319 fnT fn = __name__##_strided_impl<T1, T2>; \
320 return fn; \
321 } \
322 } \
323 }; \
324 \
325 void populate_##__name__##_dispatch_tables(void) \
326 { \
327 ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
328 ContigFactory>( \
329 __name__##_contig_dispatch_table); \
330 ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
331 StridedFactory>( \
332 __name__##_strided_dispatch_table); \
333 ext_ns::init_dispatch_table<int, TypeMapFactory>( \
334 __name__##_output_typeid_table); \
335 };
336
341#define MACRO_POPULATE_DISPATCH_2OUTS_TABLES(__name__) \
342 template <typename argT1, typename argT2, typename resT1, typename resT2, \
343 unsigned int vec_sz, unsigned int n_vecs> \
344 class __name__##_contig_kernel; \
345 \
346 template <typename argTy1, typename argTy2> \
347 sycl::event __name__##_contig_impl( \
348 sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
349 py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
350 char *res1_p, py::ssize_t res1_offset, char *res2_p, \
351 py::ssize_t res2_offset, const std::vector<sycl::event> &depends = {}) \
352 { \
353 return ew_cmn_ns::binary_two_outputs_contig_impl< \
354 argTy1, argTy2, OutputType, ContigFunctor, \
355 __name__##_contig_kernel>( \
356 exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res1_p, \
357 res1_offset, res2_p, res2_offset, depends); \
358 } \
359 \
360 template <typename fnT, typename T1, typename T2> \
361 struct ContigFactory \
362 { \
363 fnT get() \
364 { \
365 if constexpr (std::is_same_v< \
366 typename OutputType<T1, T2>::value_type1, \
367 void> || \
368 std::is_same_v< \
369 typename OutputType<T1, T2>::value_type2, \
370 void>) { \
371 \
372 fnT fn = nullptr; \
373 return fn; \
374 } \
375 else { \
376 fnT fn = __name__##_contig_impl<T1, T2>; \
377 return fn; \
378 } \
379 } \
380 }; \
381 \
382 template <typename fnT, typename T1, typename T2> \
383 struct TypeMapFactory \
384 { \
385 std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value, \
386 std::pair<int, int>> \
387 get() \
388 { \
389 using rT1 = typename OutputType<T1, T2>::value_type1; \
390 using rT2 = typename OutputType<T1, T2>::value_type2; \
391 return std::make_pair(td_ns::GetTypeid<rT1>{}.get(), \
392 td_ns::GetTypeid<rT2>{}.get()); \
393 } \
394 }; \
395 \
396 template <typename T1, typename T2, typename resT1, typename resT2, \
397 typename IndexerT> \
398 class __name__##_strided_kernel; \
399 \
400 template <typename argTy1, typename argTy2> \
401 sycl::event __name__##_strided_impl( \
402 sycl::queue &exec_q, size_t nelems, int nd, \
403 const py::ssize_t *shape_and_strides, const char *arg1_p, \
404 py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
405 char *res1_p, py::ssize_t res1_offset, char *res2_p, \
406 py::ssize_t res2_offset, const std::vector<sycl::event> &depends, \
407 const std::vector<sycl::event> &additional_depends) \
408 { \
409 return ew_cmn_ns::binary_two_outputs_strided_impl< \
410 argTy1, argTy2, OutputType, StridedFunctor, \
411 __name__##_strided_kernel>( \
412 exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
413 arg2_p, arg2_offset, res1_p, res1_offset, res2_p, res2_offset, \
414 depends, additional_depends); \
415 } \
416 \
417 template <typename fnT, typename T1, typename T2> \
418 struct StridedFactory \
419 { \
420 fnT get() \
421 { \
422 if constexpr (std::is_same_v< \
423 typename OutputType<T1, T2>::value_type1, \
424 void> || \
425 std::is_same_v< \
426 typename OutputType<T1, T2>::value_type2, \
427 void>) { \
428 fnT fn = nullptr; \
429 return fn; \
430 } \
431 else { \
432 fnT fn = __name__##_strided_impl<T1, T2>; \
433 return fn; \
434 } \
435 } \
436 }; \
437 \
438 void populate_##__name__##_dispatch_tables(void) \
439 { \
440 ext_ns::init_dispatch_table<binary_two_outputs_contig_impl_fn_ptr_t, \
441 ContigFactory>( \
442 __name__##_contig_dispatch_table); \
443 ext_ns::init_dispatch_table<binary_two_outputs_strided_impl_fn_ptr_t, \
444 StridedFactory>( \
445 __name__##_strided_dispatch_table); \
446 ext_ns::init_dispatch_table<std::pair<int, int>, TypeMapFactory>( \
447 __name__##_output_typeid_table); \
448 };