DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31// utils extension header
32#include "ext/common.hpp"
33
34namespace ext_ns = ext::common;
35
40#define MACRO_POPULATE_DISPATCH_VECTORS(__name__) \
41 template <typename T1, typename T2, unsigned int vec_sz, \
42 unsigned int n_vecs> \
43 class __name__##_contig_kernel; \
44 \
45 template <typename argTy> \
46 sycl::event __name__##_contig_impl( \
47 sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
48 const std::vector<sycl::event> &depends = {}) \
49 { \
50 return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
51 __name__##_contig_kernel>( \
52 exec_q, nelems, arg_p, res_p, depends); \
53 } \
54 \
55 template <typename fnT, typename T> \
56 struct ContigFactory \
57 { \
58 fnT get() \
59 { \
60 if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
61 void>) { \
62 fnT fn = nullptr; \
63 return fn; \
64 } \
65 else { \
66 fnT fn = __name__##_contig_impl<T>; \
67 return fn; \
68 } \
69 } \
70 }; \
71 \
72 template <typename fnT, typename T> \
73 struct TypeMapFactory \
74 { \
75 std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
76 { \
77 using rT = typename OutputType<T>::value_type; \
78 return td_ns::GetTypeid<rT>{}.get(); \
79 } \
80 }; \
81 \
82 template <typename T1, typename T2, typename T3> \
83 class __name__##_strided_kernel; \
84 \
85 template <typename argTy> \
86 sycl::event __name__##_strided_impl( \
87 sycl::queue &exec_q, size_t nelems, int nd, \
88 const py::ssize_t *shape_and_strides, const char *arg_p, \
89 py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
90 const std::vector<sycl::event> &depends, \
91 const std::vector<sycl::event> &additional_depends) \
92 { \
93 return ew_cmn_ns::unary_strided_impl< \
94 argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
95 exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
96 res_offset, depends, additional_depends); \
97 } \
98 \
99 template <typename fnT, typename T> \
100 struct StridedFactory \
101 { \
102 fnT get() \
103 { \
104 if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
105 void>) { \
106 fnT fn = nullptr; \
107 return fn; \
108 } \
109 else { \
110 fnT fn = __name__##_strided_impl<T>; \
111 return fn; \
112 } \
113 } \
114 }; \
115 \
116 void populate_##__name__##_dispatch_vectors(void) \
117 { \
118 ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
119 ContigFactory>( \
120 __name__##_contig_dispatch_vector); \
121 ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
122 StridedFactory>( \
123 __name__##_strided_dispatch_vector); \
124 ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
125 __name__##_output_typeid_vector); \
126 };
127
132#define MACRO_POPULATE_DISPATCH_TABLES(__name__) \
133 template <typename argT1, typename argT2, typename resT, \
134 unsigned int vec_sz, unsigned int n_vecs> \
135 class __name__##_contig_kernel; \
136 \
137 template <typename argTy1, typename argTy2> \
138 sycl::event __name__##_contig_impl( \
139 sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
140 py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
141 char *res_p, py::ssize_t res_offset, \
142 const std::vector<sycl::event> &depends = {}) \
143 { \
144 return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
145 ContigFunctor, \
146 __name__##_contig_kernel>( \
147 exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
148 res_offset, depends); \
149 } \
150 \
151 template <typename fnT, typename T1, typename T2> \
152 struct ContigFactory \
153 { \
154 fnT get() \
155 { \
156 if constexpr (std::is_same_v< \
157 typename OutputType<T1, T2>::value_type, void>) \
158 { \
159 \
160 fnT fn = nullptr; \
161 return fn; \
162 } \
163 else { \
164 fnT fn = __name__##_contig_impl<T1, T2>; \
165 return fn; \
166 } \
167 } \
168 }; \
169 \
170 template <typename fnT, typename T1, typename T2> \
171 struct TypeMapFactory \
172 { \
173 std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
174 { \
175 using rT = typename OutputType<T1, T2>::value_type; \
176 return td_ns::GetTypeid<rT>{}.get(); \
177 } \
178 }; \
179 \
180 template <typename T1, typename T2, typename resT, typename IndexerT> \
181 class __name__##_strided_kernel; \
182 \
183 template <typename argTy1, typename argTy2> \
184 sycl::event __name__##_strided_impl( \
185 sycl::queue &exec_q, size_t nelems, int nd, \
186 const py::ssize_t *shape_and_strides, const char *arg1_p, \
187 py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
188 char *res_p, py::ssize_t res_offset, \
189 const std::vector<sycl::event> &depends, \
190 const std::vector<sycl::event> &additional_depends) \
191 { \
192 return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
193 StridedFunctor, \
194 __name__##_strided_kernel>( \
195 exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
196 arg2_p, arg2_offset, res_p, res_offset, depends, \
197 additional_depends); \
198 } \
199 \
200 template <typename fnT, typename T1, typename T2> \
201 struct StridedFactory \
202 { \
203 fnT get() \
204 { \
205 if constexpr (std::is_same_v< \
206 typename OutputType<T1, T2>::value_type, void>) \
207 { \
208 fnT fn = nullptr; \
209 return fn; \
210 } \
211 else { \
212 fnT fn = __name__##_strided_impl<T1, T2>; \
213 return fn; \
214 } \
215 } \
216 }; \
217 \
218 void populate_##__name__##_dispatch_tables(void) \
219 { \
220 ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
221 ContigFactory>( \
222 __name__##_contig_dispatch_table); \
223 ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
224 StridedFactory>( \
225 __name__##_strided_dispatch_table); \
226 ext_ns::init_dispatch_table<int, TypeMapFactory>( \
227 __name__##_output_typeid_table); \
228 };