DPNP C++ backend kernel library
0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
// this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
// this list of conditions and the following disclaimer in the documentation
11
// and/or other materials provided with the distribution.
12
// - Neither the name of the copyright holder nor the names of its contributors
13
// may be used to endorse or promote products derived from this software
14
// without specific prior written permission.
15
//
16
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26
// THE POSSIBILITY OF SUCH DAMAGE.
27
//*****************************************************************************
28
29
#pragma once
30
31
// utils extension header
32
#include "ext/common.hpp"
33
34
namespace
ext_ns = ext::common;
35
40
#define MACRO_POPULATE_DISPATCH_VECTORS(__name__) \
41
template <typename T1, typename T2, unsigned int vec_sz, \
42
unsigned int n_vecs> \
43
class __name__##_contig_kernel; \
44
\
45
template <typename argTy> \
46
sycl::event __name__##_contig_impl( \
47
sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
48
const std::vector<sycl::event> &depends = {}) \
49
{ \
50
return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
51
__name__##_contig_kernel>( \
52
exec_q, nelems, arg_p, res_p, depends); \
53
} \
54
\
55
template <typename fnT, typename T> \
56
struct ContigFactory \
57
{ \
58
fnT get() \
59
{ \
60
if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
61
void>) { \
62
fnT fn = nullptr; \
63
return fn; \
64
} \
65
else { \
66
fnT fn = __name__##_contig_impl<T>; \
67
return fn; \
68
} \
69
} \
70
}; \
71
\
72
template <typename fnT, typename T> \
73
struct TypeMapFactory \
74
{ \
75
std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
76
{ \
77
using rT = typename OutputType<T>::value_type; \
78
return td_ns::GetTypeid<rT>{}.get(); \
79
} \
80
}; \
81
\
82
template <typename T1, typename T2, typename T3> \
83
class __name__##_strided_kernel; \
84
\
85
template <typename argTy> \
86
sycl::event __name__##_strided_impl( \
87
sycl::queue &exec_q, size_t nelems, int nd, \
88
const py::ssize_t *shape_and_strides, const char *arg_p, \
89
py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
90
const std::vector<sycl::event> &depends, \
91
const std::vector<sycl::event> &additional_depends) \
92
{ \
93
return ew_cmn_ns::unary_strided_impl< \
94
argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
95
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
96
res_offset, depends, additional_depends); \
97
} \
98
\
99
template <typename fnT, typename T> \
100
struct StridedFactory \
101
{ \
102
fnT get() \
103
{ \
104
if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
105
void>) { \
106
fnT fn = nullptr; \
107
return fn; \
108
} \
109
else { \
110
fnT fn = __name__##_strided_impl<T>; \
111
return fn; \
112
} \
113
} \
114
}; \
115
\
116
void populate_##__name__##_dispatch_vectors(void) \
117
{ \
118
ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
119
ContigFactory>( \
120
__name__##_contig_dispatch_vector); \
121
ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
122
StridedFactory>( \
123
__name__##_strided_dispatch_vector); \
124
ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
125
__name__##_output_typeid_vector); \
126
};
127
132
#define MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(__name__) \
133
template <typename T1, typename T2, typename T3, unsigned int vec_sz, \
134
unsigned int n_vecs> \
135
class __name__##_contig_kernel; \
136
\
137
template <typename argTy> \
138
sycl::event __name__##_contig_impl( \
139
sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res1_p, \
140
char *res2_p, const std::vector<sycl::event> &depends = {}) \
141
{ \
142
return ew_cmn_ns::unary_two_outputs_contig_impl< \
143
argTy, OutputType, ContigFunctor, __name__##_contig_kernel>( \
144
exec_q, nelems, arg_p, res1_p, res2_p, depends); \
145
} \
146
\
147
template <typename fnT, typename T> \
148
struct ContigFactory \
149
{ \
150
fnT get() \
151
{ \
152
if constexpr (std::is_same_v<typename OutputType<T>::value_type1, \
153
void> || \
154
std::is_same_v<typename OutputType<T>::value_type2, \
155
void>) \
156
{ \
157
fnT fn = nullptr; \
158
return fn; \
159
} \
160
else { \
161
fnT fn = __name__##_contig_impl<T>; \
162
return fn; \
163
} \
164
} \
165
}; \
166
\
167
template <typename fnT, typename T> \
168
struct TypeMapFactory \
169
{ \
170
std::enable_if_t<std::is_same<fnT, std::pair<int, int>>::value, \
171
std::pair<int, int>> \
172
get() \
173
{ \
174
using rT1 = typename OutputType<T>::value_type1; \
175
using rT2 = typename OutputType<T>::value_type2; \
176
return std::make_pair(td_ns::GetTypeid<rT1>{}.get(), \
177
td_ns::GetTypeid<rT2>{}.get()); \
178
} \
179
}; \
180
\
181
template <typename T1, typename T2, typename T3, typename T4> \
182
class __name__##_strided_kernel; \
183
\
184
template <typename argTy> \
185
sycl::event __name__##_strided_impl( \
186
sycl::queue &exec_q, size_t nelems, int nd, \
187
const py::ssize_t *shape_and_strides, const char *arg_p, \
188
py::ssize_t arg_offset, char *res1_p, py::ssize_t res1_offset, \
189
char *res2_p, py::ssize_t res2_offset, \
190
const std::vector<sycl::event> &depends, \
191
const std::vector<sycl::event> &additional_depends) \
192
{ \
193
return ew_cmn_ns::unary_two_outputs_strided_impl< \
194
argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
195
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p, \
196
res1_offset, res2_p, res2_offset, depends, additional_depends); \
197
} \
198
\
199
template <typename fnT, typename T> \
200
struct StridedFactory \
201
{ \
202
fnT get() \
203
{ \
204
if constexpr (std::is_same_v<typename OutputType<T>::value_type1, \
205
void> || \
206
std::is_same_v<typename OutputType<T>::value_type2, \
207
void>) \
208
{ \
209
fnT fn = nullptr; \
210
return fn; \
211
} \
212
else { \
213
fnT fn = __name__##_strided_impl<T>; \
214
return fn; \
215
} \
216
} \
217
}; \
218
\
219
void populate_##__name__##_dispatch_vectors(void) \
220
{ \
221
ext_ns::init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t, \
222
ContigFactory>( \
223
__name__##_contig_dispatch_vector); \
224
ext_ns::init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t, \
225
StridedFactory>( \
226
__name__##_strided_dispatch_vector); \
227
ext_ns::init_dispatch_vector<std::pair<int, int>, TypeMapFactory>( \
228
__name__##_output_typeid_vector); \
229
};
230
235
#define MACRO_POPULATE_DISPATCH_TABLES(__name__) \
236
template <typename argT1, typename argT2, typename resT, \
237
unsigned int vec_sz, unsigned int n_vecs> \
238
class __name__##_contig_kernel; \
239
\
240
template <typename argTy1, typename argTy2> \
241
sycl::event __name__##_contig_impl( \
242
sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
243
py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
244
char *res_p, py::ssize_t res_offset, \
245
const std::vector<sycl::event> &depends = {}) \
246
{ \
247
return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
248
ContigFunctor, \
249
__name__##_contig_kernel>( \
250
exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
251
res_offset, depends); \
252
} \
253
\
254
template <typename fnT, typename T1, typename T2> \
255
struct ContigFactory \
256
{ \
257
fnT get() \
258
{ \
259
if constexpr (std::is_same_v< \
260
typename OutputType<T1, T2>::value_type, void>) \
261
{ \
262
\
263
fnT fn = nullptr; \
264
return fn; \
265
} \
266
else { \
267
fnT fn = __name__##_contig_impl<T1, T2>; \
268
return fn; \
269
} \
270
} \
271
}; \
272
\
273
template <typename fnT, typename T1, typename T2> \
274
struct TypeMapFactory \
275
{ \
276
std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
277
{ \
278
using rT = typename OutputType<T1, T2>::value_type; \
279
return td_ns::GetTypeid<rT>{}.get(); \
280
} \
281
}; \
282
\
283
template <typename T1, typename T2, typename resT, typename IndexerT> \
284
class __name__##_strided_kernel; \
285
\
286
template <typename argTy1, typename argTy2> \
287
sycl::event __name__##_strided_impl( \
288
sycl::queue &exec_q, size_t nelems, int nd, \
289
const py::ssize_t *shape_and_strides, const char *arg1_p, \
290
py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
291
char *res_p, py::ssize_t res_offset, \
292
const std::vector<sycl::event> &depends, \
293
const std::vector<sycl::event> &additional_depends) \
294
{ \
295
return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
296
StridedFunctor, \
297
__name__##_strided_kernel>( \
298
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
299
arg2_p, arg2_offset, res_p, res_offset, depends, \
300
additional_depends); \
301
} \
302
\
303
template <typename fnT, typename T1, typename T2> \
304
struct StridedFactory \
305
{ \
306
fnT get() \
307
{ \
308
if constexpr (std::is_same_v< \
309
typename OutputType<T1, T2>::value_type, void>) \
310
{ \
311
fnT fn = nullptr; \
312
return fn; \
313
} \
314
else { \
315
fnT fn = __name__##_strided_impl<T1, T2>; \
316
return fn; \
317
} \
318
} \
319
}; \
320
\
321
void populate_##__name__##_dispatch_tables(void) \
322
{ \
323
ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
324
ContigFactory>( \
325
__name__##_contig_dispatch_table); \
326
ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
327
StridedFactory>( \
328
__name__##_strided_dispatch_table); \
329
ext_ns::init_dispatch_table<int, TypeMapFactory>( \
330
__name__##_output_typeid_table); \
331
};
extensions
ufunc
elementwise_functions
populate.hpp
Generated by
1.12.0