DPNP C++ backend kernel library
0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
// this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
// this list of conditions and the following disclaimer in the documentation
11
// and/or other materials provided with the distribution.
12
// - Neither the name of the copyright holder nor the names of its contributors
13
// may be used to endorse or promote products derived from this software
14
// without specific prior written permission.
15
//
16
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26
// THE POSSIBILITY OF SUCH DAMAGE.
27
//*****************************************************************************
28
29
#pragma once
30
31
#include <type_traits>
32
#include <utility>
33
#include <vector>
34
35
#include <pybind11/pybind11.h>
36
37
// utils extension header
38
#include "ext/common.hpp"
39
40
// Short alias for the shared extension helpers; the macros in this header
// reach init_dispatch_vector / init_dispatch_table through it.
namespace ext_ns = ext::common;
41
46
/**
 * @brief Generates the dispatch boilerplate for a function taking one input
 *        buffer (`arg_p`) and writing one output buffer (`res_p`).
 *
 * For a given NAME the expansion produces, in the scope where it is used:
 *  - forward declarations of the class templates `NAME_contig_kernel` and
 *    `NAME_strided_kernel`, which are handed to the `ew_cmn_ns` helpers as a
 *    template argument;
 *  - `NAME_contig_impl<argTy>` and `NAME_strided_impl<argTy>`, thin wrappers
 *    forwarding to `ew_cmn_ns::unary_contig_impl` and
 *    `ew_cmn_ns::unary_strided_impl`;
 *  - `ContigFactory` / `StridedFactory`, whose `get()` yields `nullptr` when
 *    `OutputType<T>::value_type` is `void` and the matching `*_impl<T>`
 *    otherwise;
 *  - `TypeMapFactory`, whose `get()` yields
 *    `td_ns::GetTypeid<OutputType<T>::value_type>{}.get()` (usable only with
 *    `fnT == int`);
 *  - `void populate_NAME_dispatch_vectors(void)`, which runs
 *    `ext_ns::init_dispatch_vector` over `NAME_contig_dispatch_vector`,
 *    `NAME_strided_dispatch_vector` and `NAME_output_typeid_vector`.
 *
 * The expansion site must already provide: `sycl`, `py::ssize_t`,
 * `ew_cmn_ns`, `td_ns`, `OutputType`, `ContigFunctor`, `StridedFunctor`,
 * `unary_contig_impl_fn_ptr_t`, `unary_strided_impl_fn_ptr_t` and the three
 * vectors listed above.
 *
 * The factory struct names do not contain NAME, so two expansions cannot share
 * one scope. The expansion ends with the closing brace of the populate
 * function; a trailing `;` at the call site is optional.
 *
 * @param NAME identifier prefix pasted into every generated name.
 */
#define MACRO_POPULATE_DISPATCH_VECTORS(NAME) \
    template <typename T1, typename T2, unsigned int vec_sz, \
              unsigned int n_vecs> \
    class NAME##_contig_kernel; \
    \
    template <typename argTy> \
    sycl::event NAME##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
        const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
                                            NAME##_contig_kernel>( \
            exec_q, nelems, arg_p, res_p, depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
                                         void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = NAME##_contig_impl<T>; \
                return fn; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
        { \
            using rT = typename OutputType<T>::value_type; \
            return td_ns::GetTypeid<rT>{}.get(); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename T3> \
    class NAME##_strided_kernel; \
    \
    template <typename argTy> \
    sycl::event NAME##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg_p, \
        py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::unary_strided_impl< \
            argTy, OutputType, StridedFunctor, NAME##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
            res_offset, depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
                                         void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = NAME##_strided_impl<T>; \
                return fn; \
            } \
        } \
    }; \
    \
    void populate_##NAME##_dispatch_vectors(void) \
    { \
        ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
                                     ContigFactory>( \
            NAME##_contig_dispatch_vector); \
        ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
                                     StridedFactory>( \
            NAME##_strided_dispatch_vector); \
        ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
            NAME##_output_typeid_vector); \
    }
133
138
/**
 * @brief Generates the dispatch boilerplate for a function taking one input
 *        buffer (`arg_p`) and writing two output buffers (`res1_p`, `res2_p`).
 *
 * For a given NAME the expansion produces, in the scope where it is used:
 *  - forward declarations of the class templates `NAME_contig_kernel` and
 *    `NAME_strided_kernel`, which are handed to the `ew_cmn_ns` helpers as a
 *    template argument;
 *  - `NAME_contig_impl<argTy>` and `NAME_strided_impl<argTy>`, thin wrappers
 *    forwarding to `ew_cmn_ns::unary_two_outputs_contig_impl` and
 *    `ew_cmn_ns::unary_two_outputs_strided_impl`;
 *  - `ContigFactory` / `StridedFactory`, whose `get()` yields `nullptr` when
 *    either `OutputType<T>::value_type1` or `OutputType<T>::value_type2` is
 *    `void`, and the matching `*_impl<T>` otherwise;
 *  - `TypeMapFactory`, whose `get()` yields the pair of
 *    `td_ns::GetTypeid<...>{}.get()` results for `value_type1` and
 *    `value_type2` (usable only with `fnT == std::pair<int, int>`);
 *  - `void populate_NAME_dispatch_vectors(void)`, which runs
 *    `ext_ns::init_dispatch_vector` over `NAME_contig_dispatch_vector`,
 *    `NAME_strided_dispatch_vector` and `NAME_output_typeid_vector`.
 *
 * The expansion site must already provide: `sycl`, `py::ssize_t`,
 * `ew_cmn_ns`, `td_ns`, `OutputType`, `ContigFunctor`, `StridedFunctor`,
 * `unary_two_outputs_contig_impl_fn_ptr_t`,
 * `unary_two_outputs_strided_impl_fn_ptr_t` and the three vectors listed
 * above.
 *
 * The factory struct names do not contain NAME, so two expansions cannot share
 * one scope. The expansion ends with the closing brace of the populate
 * function; a trailing `;` at the call site is optional.
 *
 * @param NAME identifier prefix pasted into every generated name.
 */
#define MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(NAME) \
    template <typename T1, typename T2, typename T3, unsigned int vec_sz, \
              unsigned int n_vecs> \
    class NAME##_contig_kernel; \
    \
    template <typename argTy> \
    sycl::event NAME##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res1_p, \
        char *res2_p, const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::unary_two_outputs_contig_impl< \
            argTy, OutputType, ContigFunctor, NAME##_contig_kernel>( \
            exec_q, nelems, arg_p, res1_p, res2_p, depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type1, \
                                         void> || \
                          std::is_same_v<typename OutputType<T>::value_type2, \
                                         void>) \
            { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = NAME##_contig_impl<T>; \
                return fn; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value, \
                         std::pair<int, int>> \
            get() \
        { \
            using rT1 = typename OutputType<T>::value_type1; \
            using rT2 = typename OutputType<T>::value_type2; \
            return std::make_pair(td_ns::GetTypeid<rT1>{}.get(), \
                                  td_ns::GetTypeid<rT2>{}.get()); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename T3, typename T4> \
    class NAME##_strided_kernel; \
    \
    template <typename argTy> \
    sycl::event NAME##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg_p, \
        py::ssize_t arg_offset, char *res1_p, py::ssize_t res1_offset, \
        char *res2_p, py::ssize_t res2_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::unary_two_outputs_strided_impl< \
            argTy, OutputType, StridedFunctor, NAME##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p, \
            res1_offset, res2_p, res2_offset, depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type1, \
                                         void> || \
                          std::is_same_v<typename OutputType<T>::value_type2, \
                                         void>) \
            { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = NAME##_strided_impl<T>; \
                return fn; \
            } \
        } \
    }; \
    \
    void populate_##NAME##_dispatch_vectors(void) \
    { \
        ext_ns::init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t, \
                                     ContigFactory>( \
            NAME##_contig_dispatch_vector); \
        ext_ns::init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t, \
                                     StridedFactory>( \
            NAME##_strided_dispatch_vector); \
        ext_ns::init_dispatch_vector<std::pair<int, int>, TypeMapFactory>( \
            NAME##_output_typeid_vector); \
    }
236
241
/**
 * @brief Generates the dispatch boilerplate for a function taking two input
 *        buffers (`arg1_p`, `arg2_p`) and writing one output buffer (`res_p`).
 *
 * For a given NAME the expansion produces, in the scope where it is used:
 *  - forward declarations of the class templates `NAME_contig_kernel` and
 *    `NAME_strided_kernel`, which are handed to the `ew_cmn_ns` helpers as a
 *    template argument;
 *  - `NAME_contig_impl<argTy1, argTy2>` and
 *    `NAME_strided_impl<argTy1, argTy2>`, thin wrappers forwarding to
 *    `ew_cmn_ns::binary_contig_impl` and `ew_cmn_ns::binary_strided_impl`;
 *  - `ContigFactory` / `StridedFactory`, whose `get()` yields `nullptr` when
 *    `OutputType<T1, T2>::value_type` is `void` and the matching
 *    `*_impl<T1, T2>` otherwise;
 *  - `TypeMapFactory`, whose `get()` yields
 *    `td_ns::GetTypeid<OutputType<T1, T2>::value_type>{}.get()` (usable only
 *    with `fnT == int`);
 *  - `void populate_NAME_dispatch_tables(void)`, which runs
 *    `ext_ns::init_dispatch_table` over `NAME_contig_dispatch_table`,
 *    `NAME_strided_dispatch_table` and `NAME_output_typeid_table`.
 *
 * The expansion site must already provide: `sycl`, `py::ssize_t`,
 * `ew_cmn_ns`, `td_ns`, `OutputType`, `ContigFunctor`, `StridedFunctor`,
 * `binary_contig_impl_fn_ptr_t`, `binary_strided_impl_fn_ptr_t` and the three
 * tables listed above.
 *
 * The factory struct names do not contain NAME, so two expansions cannot share
 * one scope. The expansion ends with the closing brace of the populate
 * function; a trailing `;` at the call site is optional.
 *
 * @param NAME identifier prefix pasted into every generated name.
 */
#define MACRO_POPULATE_DISPATCH_TABLES(NAME) \
    template <typename argT1, typename argT2, typename resT, \
              unsigned int vec_sz, unsigned int n_vecs> \
    class NAME##_contig_kernel; \
    \
    template <typename argTy1, typename argTy2> \
    sycl::event NAME##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
                                             ContigFunctor, \
                                             NAME##_contig_kernel>( \
            exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
            res_offset, depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            if constexpr (std::is_same_v< \
                              typename OutputType<T1, T2>::value_type, void>) \
            { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = NAME##_contig_impl<T1, T2>; \
                return fn; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T1, typename T2> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
        { \
            using rT = typename OutputType<T1, T2>::value_type; \
            return td_ns::GetTypeid<rT>{}.get(); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename resT, typename IndexerT> \
    class NAME##_strided_kernel; \
    \
    template <typename argTy1, typename argTy2> \
    sycl::event NAME##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
                                              StridedFunctor, \
                                              NAME##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
            arg2_p, arg2_offset, res_p, res_offset, depends, \
            additional_depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            if constexpr (std::is_same_v< \
                              typename OutputType<T1, T2>::value_type, void>) \
            { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = NAME##_strided_impl<T1, T2>; \
                return fn; \
            } \
        } \
    }; \
    \
    void populate_##NAME##_dispatch_tables(void) \
    { \
        ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
                                    ContigFactory>( \
            NAME##_contig_dispatch_table); \
        ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
                                    StridedFactory>( \
            NAME##_strided_dispatch_table); \
        ext_ns::init_dispatch_table<int, TypeMapFactory>( \
            NAME##_output_typeid_table); \
    }
extensions
ufunc
elementwise_functions
populate.hpp
Generated by
1.12.0