DPNP C++ backend kernel library
0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
// this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
// this list of conditions and the following disclaimer in the documentation
11
// and/or other materials provided with the distribution.
12
// - Neither the name of the copyright holder nor the names of its contributors
13
// may be used to endorse or promote products derived from this software
14
// without specific prior written permission.
15
//
16
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26
// THE POSSIBILITY OF SUCH DAMAGE.
27
//*****************************************************************************
28
29
#pragma once
30
31
// utils extension header
32
#include "ext/common.hpp"
33
34
namespace
ext_ns = ext::common;
35
40
#define MACRO_POPULATE_DISPATCH_VECTORS(__name__) \
41
template <typename T1, typename T2, unsigned int vec_sz, \
42
unsigned int n_vecs> \
43
class __name__##_contig_kernel; \
44
\
45
template <typename argTy> \
46
sycl::event __name__##_contig_impl( \
47
sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
48
const std::vector<sycl::event> &depends = {}) \
49
{ \
50
return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
51
__name__##_contig_kernel>( \
52
exec_q, nelems, arg_p, res_p, depends); \
53
} \
54
\
55
template <typename fnT, typename T> \
56
struct ContigFactory \
57
{ \
58
fnT get() \
59
{ \
60
if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
61
void>) { \
62
fnT fn = nullptr; \
63
return fn; \
64
} \
65
else { \
66
fnT fn = __name__##_contig_impl<T>; \
67
return fn; \
68
} \
69
} \
70
}; \
71
\
72
template <typename fnT, typename T> \
73
struct TypeMapFactory \
74
{ \
75
std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
76
{ \
77
using rT = typename OutputType<T>::value_type; \
78
return td_ns::GetTypeid<rT>{}.get(); \
79
} \
80
}; \
81
\
82
template <typename T1, typename T2, typename T3> \
83
class __name__##_strided_kernel; \
84
\
85
template <typename argTy> \
86
sycl::event __name__##_strided_impl( \
87
sycl::queue &exec_q, size_t nelems, int nd, \
88
const py::ssize_t *shape_and_strides, const char *arg_p, \
89
py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
90
const std::vector<sycl::event> &depends, \
91
const std::vector<sycl::event> &additional_depends) \
92
{ \
93
return ew_cmn_ns::unary_strided_impl< \
94
argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
95
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
96
res_offset, depends, additional_depends); \
97
} \
98
\
99
template <typename fnT, typename T> \
100
struct StridedFactory \
101
{ \
102
fnT get() \
103
{ \
104
if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
105
void>) { \
106
fnT fn = nullptr; \
107
return fn; \
108
} \
109
else { \
110
fnT fn = __name__##_strided_impl<T>; \
111
return fn; \
112
} \
113
} \
114
}; \
115
\
116
void populate_##__name__##_dispatch_vectors(void) \
117
{ \
118
ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
119
ContigFactory>( \
120
__name__##_contig_dispatch_vector); \
121
ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
122
StridedFactory>( \
123
__name__##_strided_dispatch_vector); \
124
ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
125
__name__##_output_typeid_vector); \
126
};
127
132
#define MACRO_POPULATE_DISPATCH_TABLES(__name__) \
133
template <typename argT1, typename argT2, typename resT, \
134
unsigned int vec_sz, unsigned int n_vecs> \
135
class __name__##_contig_kernel; \
136
\
137
template <typename argTy1, typename argTy2> \
138
sycl::event __name__##_contig_impl( \
139
sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
140
py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
141
char *res_p, py::ssize_t res_offset, \
142
const std::vector<sycl::event> &depends = {}) \
143
{ \
144
return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
145
ContigFunctor, \
146
__name__##_contig_kernel>( \
147
exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
148
res_offset, depends); \
149
} \
150
\
151
template <typename fnT, typename T1, typename T2> \
152
struct ContigFactory \
153
{ \
154
fnT get() \
155
{ \
156
if constexpr (std::is_same_v< \
157
typename OutputType<T1, T2>::value_type, void>) \
158
{ \
159
\
160
fnT fn = nullptr; \
161
return fn; \
162
} \
163
else { \
164
fnT fn = __name__##_contig_impl<T1, T2>; \
165
return fn; \
166
} \
167
} \
168
}; \
169
\
170
template <typename fnT, typename T1, typename T2> \
171
struct TypeMapFactory \
172
{ \
173
std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
174
{ \
175
using rT = typename OutputType<T1, T2>::value_type; \
176
return td_ns::GetTypeid<rT>{}.get(); \
177
} \
178
}; \
179
\
180
template <typename T1, typename T2, typename resT, typename IndexerT> \
181
class __name__##_strided_kernel; \
182
\
183
template <typename argTy1, typename argTy2> \
184
sycl::event __name__##_strided_impl( \
185
sycl::queue &exec_q, size_t nelems, int nd, \
186
const py::ssize_t *shape_and_strides, const char *arg1_p, \
187
py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
188
char *res_p, py::ssize_t res_offset, \
189
const std::vector<sycl::event> &depends, \
190
const std::vector<sycl::event> &additional_depends) \
191
{ \
192
return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
193
StridedFunctor, \
194
__name__##_strided_kernel>( \
195
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
196
arg2_p, arg2_offset, res_p, res_offset, depends, \
197
additional_depends); \
198
} \
199
\
200
template <typename fnT, typename T1, typename T2> \
201
struct StridedFactory \
202
{ \
203
fnT get() \
204
{ \
205
if constexpr (std::is_same_v< \
206
typename OutputType<T1, T2>::value_type, void>) \
207
{ \
208
fnT fn = nullptr; \
209
return fn; \
210
} \
211
else { \
212
fnT fn = __name__##_strided_impl<T1, T2>; \
213
return fn; \
214
} \
215
} \
216
}; \
217
\
218
void populate_##__name__##_dispatch_tables(void) \
219
{ \
220
ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
221
ContigFactory>( \
222
__name__##_contig_dispatch_table); \
223
ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
224
StridedFactory>( \
225
__name__##_strided_dispatch_table); \
226
ext_ns::init_dispatch_table<int, TypeMapFactory>( \
227
__name__##_output_typeid_table); \
228
};
extensions
ufunc
elementwise_functions
populate.hpp
Generated by
1.12.0