DPNP C++ backend kernel library
0.20.0dev4
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
// this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
// this list of conditions and the following disclaimer in the documentation
11
// and/or other materials provided with the distribution.
12
// - Neither the name of the copyright holder nor the names of its contributors
13
// may be used to endorse or promote products derived from this software
14
// without specific prior written permission.
15
//
16
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26
// THE POSSIBILITY OF SUCH DAMAGE.
27
//*****************************************************************************
28
29
#pragma once
30
31
#include <type_traits>
32
#include <utility>
33
#include <vector>
34
35
#include <pybind11/pybind11.h>
36
37
// utils extension header
38
#include "ext/common.hpp"
39
40
// Alias for the ext::common utilities; the macros below use it to reach
// init_dispatch_vector and init_dispatch_table.
namespace
ext_ns = ext::common;
41
46
/**
 * @brief Generates the implementation wrappers, factories and the
 * dispatch-vector populating function for a unary elementwise function with a
 * single output.
 *
 * Expanding `MACRO_POPULATE_DISPATCH_VECTORS(foo)` defines:
 *  - forward declarations of `foo_contig_kernel` and `foo_strided_kernel`;
 *  - `foo_contig_impl<argTy>` and `foo_strided_impl<argTy>`, which forward
 *    their arguments to `ew_cmn_ns::unary_contig_impl` and
 *    `ew_cmn_ns::unary_strided_impl`;
 *  - `ContigFactory`, `StridedFactory` and `TypeMapFactory`;
 *  - `populate_foo_dispatch_vectors()`, which fills
 *    `foo_contig_dispatch_vector`, `foo_strided_dispatch_vector` and
 *    `foo_output_typeid_vector` through `ext_ns::init_dispatch_vector`.
 *
 * The expansion site must already provide `sycl`, `py`, `ew_cmn_ns`, `td_ns`,
 * `OutputType`, `ContigFunctor`, `StridedFunctor`,
 * `unary_contig_impl_fn_ptr_t`, `unary_strided_impl_fn_ptr_t` and the three
 * dispatch vectors listed above.
 *
 * @note The factory structs are not prefixed with the function name, so the
 * macro can be expanded at most once per namespace scope.
 * @note The expansion ends with a complete function definition; a trailing
 * semicolon at the call site is optional.
 *
 * @param func_name Identifier used as the prefix/infix of every generated name.
 */
#define MACRO_POPULATE_DISPATCH_VECTORS(func_name) \
    template <typename T1, typename T2, unsigned int vec_sz, \
              unsigned int n_vecs> \
    class func_name##_contig_kernel; \
    \
    template <typename argTy> \
    sycl::event func_name##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
        const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
                                            func_name##_contig_kernel>( \
            exec_q, nelems, arg_p, res_p, depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            /* OutputType maps T to void: store a null function pointer. */ \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
                                         void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = func_name##_contig_impl<T>; \
                return fn; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        /* Rejects any fnT other than int with a readable diagnostic. */ \
        static_assert(std::is_same_v<fnT, int>, \
                      "TypeMapFactory requires fnT to be int"); \
        \
        int get() \
        { \
            using rT = typename OutputType<T>::value_type; \
            return td_ns::GetTypeid<rT>{}.get(); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename T3> \
    class func_name##_strided_kernel; \
    \
    template <typename argTy> \
    sycl::event func_name##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg_p, \
        py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::unary_strided_impl< \
            argTy, OutputType, StridedFunctor, func_name##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
            res_offset, depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            /* OutputType maps T to void: store a null function pointer. */ \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
                                         void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = func_name##_strided_impl<T>; \
                return fn; \
            } \
        } \
    }; \
    \
    void populate_##func_name##_dispatch_vectors(void) \
    { \
        ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
                                     ContigFactory>( \
            func_name##_contig_dispatch_vector); \
        ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
                                     StridedFactory>( \
            func_name##_strided_dispatch_vector); \
        ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
            func_name##_output_typeid_vector); \
    }
133
138
/**
 * @brief Generates the implementation wrappers, factories and the
 * dispatch-vector populating function for a unary elementwise function with
 * two outputs.
 *
 * Expanding `MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(foo)` defines:
 *  - forward declarations of `foo_contig_kernel` and `foo_strided_kernel`;
 *  - `foo_contig_impl<argTy>` and `foo_strided_impl<argTy>`, which forward
 *    their arguments to `ew_cmn_ns::unary_two_outputs_contig_impl` and
 *    `ew_cmn_ns::unary_two_outputs_strided_impl`;
 *  - `ContigFactory`, `StridedFactory` and `TypeMapFactory`;
 *  - `populate_foo_dispatch_vectors()`, which fills
 *    `foo_contig_dispatch_vector`, `foo_strided_dispatch_vector` and
 *    `foo_output_typeid_vector` through `ext_ns::init_dispatch_vector`.
 *
 * The expansion site must already provide `sycl`, `py`, `ew_cmn_ns`, `td_ns`,
 * `OutputType` (exposing `value_type1` and `value_type2`), `ContigFunctor`,
 * `StridedFunctor`, `unary_two_outputs_contig_impl_fn_ptr_t`,
 * `unary_two_outputs_strided_impl_fn_ptr_t` and the three dispatch vectors
 * listed above.
 *
 * @note The factory structs are not prefixed with the function name, so the
 * macro can be expanded at most once per namespace scope.
 * @note The expansion ends with a complete function definition; a trailing
 * semicolon at the call site is optional.
 *
 * @param func_name Identifier used as the prefix/infix of every generated name.
 */
#define MACRO_POPULATE_DISPATCH_2OUTS_VECTORS(func_name) \
    template <typename T1, typename T2, typename T3, unsigned int vec_sz, \
              unsigned int n_vecs> \
    class func_name##_contig_kernel; \
    \
    template <typename argTy> \
    sycl::event func_name##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res1_p, \
        char *res2_p, const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::unary_two_outputs_contig_impl< \
            argTy, OutputType, ContigFunctor, func_name##_contig_kernel>( \
            exec_q, nelems, arg_p, res1_p, res2_p, depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            /* Either output type is void: store a null function pointer. */ \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type1, \
                                         void> || \
                          std::is_same_v<typename OutputType<T>::value_type2, \
                                         void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = func_name##_contig_impl<T>; \
                return fn; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        /* Rejects any other fnT with a readable diagnostic. */ \
        static_assert(std::is_same_v<fnT, std::pair<int, int>>, \
                      "TypeMapFactory requires fnT to be std::pair<int, int>"); \
        \
        std::pair<int, int> get() \
        { \
            using rT1 = typename OutputType<T>::value_type1; \
            using rT2 = typename OutputType<T>::value_type2; \
            return std::make_pair(td_ns::GetTypeid<rT1>{}.get(), \
                                  td_ns::GetTypeid<rT2>{}.get()); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename T3, typename T4> \
    class func_name##_strided_kernel; \
    \
    template <typename argTy> \
    sycl::event func_name##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg_p, \
        py::ssize_t arg_offset, char *res1_p, py::ssize_t res1_offset, \
        char *res2_p, py::ssize_t res2_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::unary_two_outputs_strided_impl< \
            argTy, OutputType, StridedFunctor, func_name##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res1_p, \
            res1_offset, res2_p, res2_offset, depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            /* Either output type is void: store a null function pointer. */ \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type1, \
                                         void> || \
                          std::is_same_v<typename OutputType<T>::value_type2, \
                                         void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = func_name##_strided_impl<T>; \
                return fn; \
            } \
        } \
    }; \
    \
    void populate_##func_name##_dispatch_vectors(void) \
    { \
        ext_ns::init_dispatch_vector<unary_two_outputs_contig_impl_fn_ptr_t, \
                                     ContigFactory>( \
            func_name##_contig_dispatch_vector); \
        ext_ns::init_dispatch_vector<unary_two_outputs_strided_impl_fn_ptr_t, \
                                     StridedFactory>( \
            func_name##_strided_dispatch_vector); \
        ext_ns::init_dispatch_vector<std::pair<int, int>, TypeMapFactory>( \
            func_name##_output_typeid_vector); \
    }
234
239
/**
 * @brief Generates the implementation wrappers, factories and the
 * dispatch-table populating function for a binary elementwise function with a
 * single output.
 *
 * Expanding `MACRO_POPULATE_DISPATCH_TABLES(foo)` defines:
 *  - forward declarations of `foo_contig_kernel` and `foo_strided_kernel`;
 *  - `foo_contig_impl<argTy1, argTy2>` and `foo_strided_impl<argTy1, argTy2>`,
 *    which forward their arguments to `ew_cmn_ns::binary_contig_impl` and
 *    `ew_cmn_ns::binary_strided_impl`;
 *  - `ContigFactory`, `StridedFactory` and `TypeMapFactory`;
 *  - `populate_foo_dispatch_tables()`, which fills
 *    `foo_contig_dispatch_table`, `foo_strided_dispatch_table` and
 *    `foo_output_typeid_table` through `ext_ns::init_dispatch_table`.
 *
 * The expansion site must already provide `sycl`, `py`, `ew_cmn_ns`, `td_ns`,
 * `OutputType`, `ContigFunctor`, `StridedFunctor`,
 * `binary_contig_impl_fn_ptr_t`, `binary_strided_impl_fn_ptr_t` and the three
 * dispatch tables listed above.
 *
 * @note The factory structs are not prefixed with the function name, so the
 * macro can be expanded at most once per namespace scope.
 * @note The expansion ends with a complete function definition; a trailing
 * semicolon at the call site is optional.
 *
 * @param func_name Identifier used as the prefix/infix of every generated name.
 */
#define MACRO_POPULATE_DISPATCH_TABLES(func_name) \
    template <typename argT1, typename argT2, typename resT, \
              unsigned int vec_sz, unsigned int n_vecs> \
    class func_name##_contig_kernel; \
    \
    template <typename argTy1, typename argTy2> \
    sycl::event func_name##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
                                             ContigFunctor, \
                                             func_name##_contig_kernel>( \
            exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
            res_offset, depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            /* OutputType maps (T1, T2) to void: store a null pointer. */ \
            if constexpr (std::is_same_v< \
                              typename OutputType<T1, T2>::value_type, \
                              void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = func_name##_contig_impl<T1, T2>; \
                return fn; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T1, typename T2> \
    struct TypeMapFactory \
    { \
        /* Rejects any fnT other than int with a readable diagnostic. */ \
        static_assert(std::is_same_v<fnT, int>, \
                      "TypeMapFactory requires fnT to be int"); \
        \
        int get() \
        { \
            using rT = typename OutputType<T1, T2>::value_type; \
            return td_ns::GetTypeid<rT>{}.get(); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename resT, typename IndexerT> \
    class func_name##_strided_kernel; \
    \
    template <typename argTy1, typename argTy2> \
    sycl::event func_name##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res_p, py::ssize_t res_offset, \
        const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
                                              StridedFunctor, \
                                              func_name##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
            arg2_p, arg2_offset, res_p, res_offset, depends, \
            additional_depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            /* OutputType maps (T1, T2) to void: store a null pointer. */ \
            if constexpr (std::is_same_v< \
                              typename OutputType<T1, T2>::value_type, \
                              void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = func_name##_strided_impl<T1, T2>; \
                return fn; \
            } \
        } \
    }; \
    \
    void populate_##func_name##_dispatch_tables(void) \
    { \
        ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
                                    ContigFactory>( \
            func_name##_contig_dispatch_table); \
        ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
                                    StridedFactory>( \
            func_name##_strided_dispatch_table); \
        ext_ns::init_dispatch_table<int, TypeMapFactory>( \
            func_name##_output_typeid_table); \
    }
336
341
/**
 * @brief Generates the implementation wrappers, factories and the
 * dispatch-table populating function for a binary elementwise function with
 * two outputs.
 *
 * Expanding `MACRO_POPULATE_DISPATCH_2OUTS_TABLES(foo)` defines:
 *  - forward declarations of `foo_contig_kernel` and `foo_strided_kernel`;
 *  - `foo_contig_impl<argTy1, argTy2>` and `foo_strided_impl<argTy1, argTy2>`,
 *    which forward their arguments to
 *    `ew_cmn_ns::binary_two_outputs_contig_impl` and
 *    `ew_cmn_ns::binary_two_outputs_strided_impl`;
 *  - `ContigFactory`, `StridedFactory` and `TypeMapFactory`;
 *  - `populate_foo_dispatch_tables()`, which fills
 *    `foo_contig_dispatch_table`, `foo_strided_dispatch_table` and
 *    `foo_output_typeid_table` through `ext_ns::init_dispatch_table`.
 *
 * The expansion site must already provide `sycl`, `py`, `ew_cmn_ns`, `td_ns`,
 * `OutputType` (exposing `value_type1` and `value_type2`), `ContigFunctor`,
 * `StridedFunctor`, `binary_two_outputs_contig_impl_fn_ptr_t`,
 * `binary_two_outputs_strided_impl_fn_ptr_t` and the three dispatch tables
 * listed above.
 *
 * @note The factory structs are not prefixed with the function name, so the
 * macro can be expanded at most once per namespace scope.
 * @note The expansion ends with a complete function definition; a trailing
 * semicolon at the call site is optional.
 *
 * @param func_name Identifier used as the prefix/infix of every generated name.
 */
#define MACRO_POPULATE_DISPATCH_2OUTS_TABLES(func_name) \
    template <typename argT1, typename argT2, typename resT1, typename resT2, \
              unsigned int vec_sz, unsigned int n_vecs> \
    class func_name##_contig_kernel; \
    \
    template <typename argTy1, typename argTy2> \
    sycl::event func_name##_contig_impl( \
        sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res1_p, py::ssize_t res1_offset, char *res2_p, \
        py::ssize_t res2_offset, const std::vector<sycl::event> &depends = {}) \
    { \
        return ew_cmn_ns::binary_two_outputs_contig_impl< \
            argTy1, argTy2, OutputType, ContigFunctor, \
            func_name##_contig_kernel>( \
            exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res1_p, \
            res1_offset, res2_p, res2_offset, depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            /* Either output type is void: store a null function pointer. */ \
            if constexpr (std::is_same_v< \
                              typename OutputType<T1, T2>::value_type1, \
                              void> || \
                          std::is_same_v< \
                              typename OutputType<T1, T2>::value_type2, \
                              void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = func_name##_contig_impl<T1, T2>; \
                return fn; \
            } \
        } \
    }; \
    \
    template <typename fnT, typename T1, typename T2> \
    struct TypeMapFactory \
    { \
        /* Rejects any other fnT with a readable diagnostic. */ \
        static_assert(std::is_same_v<fnT, std::pair<int, int>>, \
                      "TypeMapFactory requires fnT to be std::pair<int, int>"); \
        \
        std::pair<int, int> get() \
        { \
            using rT1 = typename OutputType<T1, T2>::value_type1; \
            using rT2 = typename OutputType<T1, T2>::value_type2; \
            return std::make_pair(td_ns::GetTypeid<rT1>{}.get(), \
                                  td_ns::GetTypeid<rT2>{}.get()); \
        } \
    }; \
    \
    template <typename T1, typename T2, typename resT1, typename resT2, \
              typename IndexerT> \
    class func_name##_strided_kernel; \
    \
    template <typename argTy1, typename argTy2> \
    sycl::event func_name##_strided_impl( \
        sycl::queue &exec_q, size_t nelems, int nd, \
        const py::ssize_t *shape_and_strides, const char *arg1_p, \
        py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
        char *res1_p, py::ssize_t res1_offset, char *res2_p, \
        py::ssize_t res2_offset, const std::vector<sycl::event> &depends, \
        const std::vector<sycl::event> &additional_depends) \
    { \
        return ew_cmn_ns::binary_two_outputs_strided_impl< \
            argTy1, argTy2, OutputType, StridedFunctor, \
            func_name##_strided_kernel>( \
            exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
            arg2_p, arg2_offset, res1_p, res1_offset, res2_p, res2_offset, \
            depends, additional_depends); \
    } \
    \
    template <typename fnT, typename T1, typename T2> \
    struct StridedFactory \
    { \
        fnT get() \
        { \
            /* Either output type is void: store a null function pointer. */ \
            if constexpr (std::is_same_v< \
                              typename OutputType<T1, T2>::value_type1, \
                              void> || \
                          std::is_same_v< \
                              typename OutputType<T1, T2>::value_type2, \
                              void>) { \
                fnT fn = nullptr; \
                return fn; \
            } \
            else { \
                fnT fn = func_name##_strided_impl<T1, T2>; \
                return fn; \
            } \
        } \
    }; \
    \
    void populate_##func_name##_dispatch_tables(void) \
    { \
        ext_ns::init_dispatch_table<binary_two_outputs_contig_impl_fn_ptr_t, \
                                    ContigFactory>( \
            func_name##_contig_dispatch_table); \
        ext_ns::init_dispatch_table<binary_two_outputs_strided_impl_fn_ptr_t, \
                                    StridedFactory>( \
            func_name##_strided_dispatch_table); \
        ext_ns::init_dispatch_table<std::pair<int, int>, TypeMapFactory>( \
            func_name##_output_typeid_table); \
    }
extensions
ufunc
elementwise_functions
populate.hpp
Generated by
1.12.0