DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <type_traits>
32#include <utility>
33#include <vector>
34
35#include <pybind11/pybind11.h>
36
37// utils extension header
38#include "ext/common.hpp"
39
40namespace ext_ns = ext::common;
41
/**
 * @brief Generates the dispatch machinery of a unary element-wise function
 * that produces a single output.
 *
 * For the given `__name__` the expansion provides:
 *  - forward declarations of the class templates `<name>_contig_kernel` and
 *    `<name>_strided_kernel`;
 *  - `<name>_contig_impl<argTy>` and `<name>_strided_impl<argTy>`, which
 *    forward their arguments to `ew_cmn_ns::unary_contig_impl` and
 *    `ew_cmn_ns::unary_strided_impl` respectively;
 *  - `ContigFactory` / `StridedFactory`, whose `get()` yields `nullptr` when
 *    `OutputType<T>::value_type` is `void` and the matching `*_impl<T>`
 *    otherwise;
 *  - `TypeMapFactory`, whose `get()` yields
 *    `td_ns::GetTypeid<OutputType<T>::value_type>{}.get()`;
 *  - `populate_<name>_dispatch_vectors()`, which passes the three factories
 *    to `ext_ns::init_dispatch_vector`.
 *
 * The macro body is only compiled where it is expanded, so the following
 * names must be visible there: `sycl`, `py`, `ew_cmn_ns`, `td_ns`,
 * `OutputType`, `ContigFunctor`, `StridedFunctor`,
 * `unary_contig_impl_fn_ptr_t`, `unary_strided_impl_fn_ptr_t`,
 * `<name>_contig_dispatch_vector`, `<name>_strided_dispatch_vector` and
 * `<name>_output_typeid_vector`.
 *
 * @note The factory types have fixed names (they do not embed `__name__`),
 * so the macro can be expanded at most once per enclosing scope.
 * @note The expansion ends with the closing brace of the populate function;
 * a trailing `;` at the call site is optional.
 */
#define MACRO_POPULATE_DISPATCH_VECTORS(__name__) \
    /* Tag template handed to unary_contig_impl as a template argument. */ \
    template <typename T1, typename T2, unsigned int vec_sz, \
              unsigned int n_vecs> \
    class __name__##_contig_kernel; \
    \
    template <typename argTy> \
    sycl::event __name__##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
        const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
                                            __name__##_contig_kernel>( \
            exec_q, nelems, arg_p, res_p, depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            using resTy = typename OutputType<T>::value_type; \
            /* A void result type marks an unsupported input type. */ \
            if constexpr (std::is_same_v<resTy, void>) { \
                return nullptr; \
            } \
            else { \
                return __name__##_contig_impl<T>; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same_v<fnT, int>, int> get() \
        { \
            using rT = typename OutputType<T>::value_type; \
            return td_ns::GetTypeid<rT>{}.get(); \
        } \
    }; \
    \
    /* Tag template handed to unary_strided_impl as a template argument. */ \
    template <typename T1, typename T2, typename T3> \
    class __name__##_strided_kernel; \
    \
    template <typename argTy> \
    sycl::event __name__##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg_p, \
        py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::unary_strided_impl< \
            argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
            res_offset, depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            using resTy = typename OutputType<T>::value_type; \
            /* A void result type marks an unsupported input type. */ \
            if constexpr (std::is_same_v<resTy, void>) { \
                return nullptr; \
            } \
            else { \
                return __name__##_strided_impl<T>; \
            } \
        } \
    }; \
    \
    void populate_##__name__##_dispatch_vectors(void) \
    { \
        ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
                                     ContigFactory>( \
            __name__##_contig_dispatch_vector); \
        ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
                                     StridedFactory>( \
            __name__##_strided_dispatch_vector); \
        ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
            __name__##_output_typeid_vector); \
    }
133
/**
 * @brief Generates the dispatch machinery of a unary element-wise function
 * that produces two outputs.
 *
 * For the given `__name__` the expansion provides:
 *  - forward declarations of the class templates `<name>_contig_kernel` and
 *    `<name>_strided_kernel`;
 *  - `<name>_contig_impl<argTy>` and `<name>_strided_impl<argTy>`, which
 *    forward their arguments to `ew_cmn_ns::unary_two_outputs_contig_impl`
 *    and `ew_cmn_ns::unary_two_outputs_strided_impl` respectively;
 *  - `ContigFactory` / `StridedFactory`, whose `get()` yields `nullptr` when
 *    either `OutputType<T>::value_type1` or `OutputType<T>::value_type2` is
 *    `void`, and the matching `*_impl<T>` otherwise;
 *  - `TypeMapFactory`, whose `get()` yields the pair of
 *    `td_ns::GetTypeid<...>{}.get()` results for both output types;
 *  - `populate_<name>_dispatch_vectors()`, which passes the three factories
 *    to `ext_ns::init_dispatch_vector`.
 *
 * The macro body is only compiled where it is expanded, so the following
 * names must be visible there: `sycl`, `py`, `ew_cmn_ns`, `td_ns`,
 * `OutputType`, `ContigFunctor`, `StridedFunctor`,
 * `unary_two_outputs_contig_impl_fn_ptr_t`,
 * `unary_two_outputs_strided_impl_fn_ptr_t`,
 * `<name>_contig_dispatch_vector`, `<name>_strided_dispatch_vector` and
 * `<name>_output_typeid_vector`.
 *
 * @note The factory types have fixed names (they do not embed `__name__`),
 * so the macro can be expanded at most once per enclosing scope.
 * @note The expansion ends with the closing brace of the populate function;
 * a trailing `;` at the call site is optional.
 */
#define MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(__name__) \
    /* Tag template handed to unary_two_outputs_contig_impl. */ \
    template <typename T1, typename T2, typename T3, unsigned int vec_sz, \
              unsigned int n_vecs> \
    class __name__##_contig_kernel; \
    \
    template <typename argTy> \
    sycl::event __name__##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res1_p, \
        char *res2_p, const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::unary_two_outputs_contig_impl< \
            argTy, OutputType, ContigFunctor, __name__##_contig_kernel>( \
            exec_q, nelems, arg_p, res1_p, res2_p, depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            using resTy1 = typename OutputType<T>::value_type1; \
            using resTy2 = typename OutputType<T>::value_type2; \
            /* Either result type being void marks an unsupported input. */ \
            if constexpr (std::is_same_v<resTy1, void> || \
                          std::is_same_v<resTy2, void>) { \
                return nullptr; \
            } \
            else { \
                return __name__##_contig_impl<T>; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same_v<fnT, std::pair<int, int>>, \
                         std::pair<int, int>> \
        get() \
        { \
            using rT1 = typename OutputType<T>::value_type1; \
            using rT2 = typename OutputType<T>::value_type2; \
            return std::make_pair(td_ns::GetTypeid<rT1>{}.get(), \
                                  td_ns::GetTypeid<rT2>{}.get()); \
        } \
    }; \
    \
    /* Tag template handed to unary_two_outputs_strided_impl. */ \
    template <typename T1, typename T2, typename T3, typename T4> \
    class __name__##_strided_kernel; \
    \
    template <typename argTy> \
    sycl::event __name__##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg_p, \
        py::ssize_t arg_offset, char *res1_p, py::ssize_t res1_offset, \
        char *res2_p, py::ssize_t res2_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::unary_two_outputs_strided_impl< \
            argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p, \
            res1_offset, res2_p, res2_offset, depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            using resTy1 = typename OutputType<T>::value_type1; \
            using resTy2 = typename OutputType<T>::value_type2; \
            /* Either result type being void marks an unsupported input. */ \
            if constexpr (std::is_same_v<resTy1, void> || \
                          std::is_same_v<resTy2, void>) { \
                return nullptr; \
            } \
            else { \
                return __name__##_strided_impl<T>; \
            } \
        } \
    }; \
    \
    void populate_##__name__##_dispatch_vectors(void) \
    { \
        ext_ns::init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t, \
                                     ContigFactory>( \
            __name__##_contig_dispatch_vector); \
        ext_ns::init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t, \
                                     StridedFactory>( \
            __name__##_strided_dispatch_vector); \
        ext_ns::init_dispatch_vector<std::pair<int, int>, TypeMapFactory>( \
            __name__##_output_typeid_vector); \
    }
236
/**
 * @brief Generates the dispatch machinery of a binary element-wise function
 * that produces a single output.
 *
 * For the given `__name__` the expansion provides:
 *  - forward declarations of the class templates `<name>_contig_kernel` and
 *    `<name>_strided_kernel`;
 *  - `<name>_contig_impl<argTy1, argTy2>` and
 *    `<name>_strided_impl<argTy1, argTy2>`, which forward their arguments to
 *    `ew_cmn_ns::binary_contig_impl` and `ew_cmn_ns::binary_strided_impl`
 *    respectively;
 *  - `ContigFactory` / `StridedFactory`, whose `get()` yields `nullptr` when
 *    `OutputType<T1, T2>::value_type` is `void` and the matching
 *    `*_impl<T1, T2>` otherwise;
 *  - `TypeMapFactory`, whose `get()` yields
 *    `td_ns::GetTypeid<OutputType<T1, T2>::value_type>{}.get()`;
 *  - `populate_<name>_dispatch_tables()`, which passes the three factories
 *    to `ext_ns::init_dispatch_table`.
 *
 * The macro body is only compiled where it is expanded, so the following
 * names must be visible there: `sycl`, `py`, `ew_cmn_ns`, `td_ns`,
 * `OutputType`, `ContigFunctor`, `StridedFunctor`,
 * `binary_contig_impl_fn_ptr_t`, `binary_strided_impl_fn_ptr_t`,
 * `<name>_contig_dispatch_table`, `<name>_strided_dispatch_table` and
 * `<name>_output_typeid_table`.
 *
 * @note The factory types have fixed names (they do not embed `__name__`),
 * so the macro can be expanded at most once per enclosing scope.
 * @note The expansion ends with the closing brace of the populate function;
 * a trailing `;` at the call site is optional.
 */
#define MACRO_POPULATE_DISPATCH_TABLES(__name__) \
    /* Tag template handed to binary_contig_impl as a template argument. */ \
    template <typename argT1, typename argT2, typename resT, \
              unsigned int vec_sz, unsigned int n_vecs> \
    class __name__##_contig_kernel; \
    \
    template <typename argTy1, typename argTy2> \
    sycl::event __name__##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
                                             ContigFunctor, \
                                             __name__##_contig_kernel>( \
            exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
            res_offset, depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            using resTy = typename OutputType<T1, T2>::value_type; \
            /* A void result type marks an unsupported type pair. */ \
            if constexpr (std::is_same_v<resTy, void>) { \
                return nullptr; \
            } \
            else { \
                return __name__##_contig_impl<T1, T2>; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T1, typename T2> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same_v<fnT, int>, int> get() \
        { \
            using rT = typename OutputType<T1, T2>::value_type; \
            return td_ns::GetTypeid<rT>{}.get(); \
        } \
    }; \
    \
    /* Tag template handed to binary_strided_impl as a template argument. */ \
    template <typename T1, typename T2, typename resT, typename IndexerT> \
    class __name__##_strided_kernel; \
    \
    template <typename argTy1, typename argTy2> \
    sycl::event __name__##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
                                              StridedFunctor, \
                                              __name__##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
            arg2_p, arg2_offset, res_p, res_offset, depends, \
            additional_depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            using resTy = typename OutputType<T1, T2>::value_type; \
            /* A void result type marks an unsupported type pair. */ \
            if constexpr (std::is_same_v<resTy, void>) { \
                return nullptr; \
            } \
            else { \
                return __name__##_strided_impl<T1, T2>; \
            } \
        } \
    }; \
    \
    void populate_##__name__##_dispatch_tables(void) \
    { \
        ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
                                    ContigFactory>( \
            __name__##_contig_dispatch_table); \
        ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
                                    StridedFactory>( \
            __name__##_strided_dispatch_table); \
        ext_ns::init_dispatch_table<int, TypeMapFactory>( \
            __name__##_output_typeid_table); \
    }