DPNP C++ backend kernel library 0.19.0dev6
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
populate.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28// utils extension header
29#include "ext/common.hpp"
30
31namespace ext_ns = ext::common;
32
37#define MACRO_POPULATE_DISPATCH_VECTORS(__name__) \
38 template <typename T1, typename T2, unsigned int vec_sz, \
39 unsigned int n_vecs> \
40 class __name__##_contig_kernel; \
41 \
42 template <typename argTy> \
43 sycl::event __name__##_contig_impl( \
44 sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, \
45 const std::vector<sycl::event> &depends = {}) \
46 { \
47 return ew_cmn_ns::unary_contig_impl<argTy, OutputType, ContigFunctor, \
48 __name__##_contig_kernel>( \
49 exec_q, nelems, arg_p, res_p, depends); \
50 } \
51 \
52 template <typename fnT, typename T> \
53 struct ContigFactory \
54 { \
55 fnT get() \
56 { \
57 if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
58 void>) { \
59 fnT fn = nullptr; \
60 return fn; \
61 } \
62 else { \
63 fnT fn = __name__##_contig_impl<T>; \
64 return fn; \
65 } \
66 } \
67 }; \
68 \
69 template <typename fnT, typename T> \
70 struct TypeMapFactory \
71 { \
72 std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
73 { \
74 using rT = typename OutputType<T>::value_type; \
75 return td_ns::GetTypeid<rT>{}.get(); \
76 } \
77 }; \
78 \
79 template <typename T1, typename T2, typename T3> \
80 class __name__##_strided_kernel; \
81 \
82 template <typename argTy> \
83 sycl::event __name__##_strided_impl( \
84 sycl::queue &exec_q, size_t nelems, int nd, \
85 const py::ssize_t *shape_and_strides, const char *arg_p, \
86 py::ssize_t arg_offset, char *res_p, py::ssize_t res_offset, \
87 const std::vector<sycl::event> &depends, \
88 const std::vector<sycl::event> &additional_depends) \
89 { \
90 return ew_cmn_ns::unary_strided_impl< \
91 argTy, OutputType, StridedFunctor, __name__##_strided_kernel>( \
92 exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, \
93 res_offset, depends, additional_depends); \
94 } \
95 \
96 template <typename fnT, typename T> \
97 struct StridedFactory \
98 { \
99 fnT get() \
100 { \
101 if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
102 void>) { \
103 fnT fn = nullptr; \
104 return fn; \
105 } \
106 else { \
107 fnT fn = __name__##_strided_impl<T>; \
108 return fn; \
109 } \
110 } \
111 }; \
112 \
113 void populate_##__name__##_dispatch_vectors(void) \
114 { \
115 ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
116 ContigFactory>( \
117 __name__##_contig_dispatch_vector); \
118 ext_ns::init_dispatch_vector<unary_strided_impl_fn_ptr_t, \
119 StridedFactory>( \
120 __name__##_strided_dispatch_vector); \
121 ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
122 __name__##_output_typeid_vector); \
123 };
124
129#define MACRO_POPULATE_DISPATCH_TABLES(__name__) \
130 template <typename argT1, typename argT2, typename resT, \
131 unsigned int vec_sz, unsigned int n_vecs> \
132 class __name__##_contig_kernel; \
133 \
134 template <typename argTy1, typename argTy2> \
135 sycl::event __name__##_contig_impl( \
136 sycl::queue &exec_q, size_t nelems, const char *arg1_p, \
137 py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
138 char *res_p, py::ssize_t res_offset, \
139 const std::vector<sycl::event> &depends = {}) \
140 { \
141 return ew_cmn_ns::binary_contig_impl<argTy1, argTy2, OutputType, \
142 ContigFunctor, \
143 __name__##_contig_kernel>( \
144 exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, \
145 res_offset, depends); \
146 } \
147 \
148 template <typename fnT, typename T1, typename T2> \
149 struct ContigFactory \
150 { \
151 fnT get() \
152 { \
153 if constexpr (std::is_same_v< \
154 typename OutputType<T1, T2>::value_type, void>) \
155 { \
156 \
157 fnT fn = nullptr; \
158 return fn; \
159 } \
160 else { \
161 fnT fn = __name__##_contig_impl<T1, T2>; \
162 return fn; \
163 } \
164 } \
165 }; \
166 \
167 template <typename fnT, typename T1, typename T2> \
168 struct TypeMapFactory \
169 { \
170 std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
171 { \
172 using rT = typename OutputType<T1, T2>::value_type; \
173 return td_ns::GetTypeid<rT>{}.get(); \
174 } \
175 }; \
176 \
177 template <typename T1, typename T2, typename resT, typename IndexerT> \
178 class __name__##_strided_kernel; \
179 \
180 template <typename argTy1, typename argTy2> \
181 sycl::event __name__##_strided_impl( \
182 sycl::queue &exec_q, size_t nelems, int nd, \
183 const py::ssize_t *shape_and_strides, const char *arg1_p, \
184 py::ssize_t arg1_offset, const char *arg2_p, py::ssize_t arg2_offset, \
185 char *res_p, py::ssize_t res_offset, \
186 const std::vector<sycl::event> &depends, \
187 const std::vector<sycl::event> &additional_depends) \
188 { \
189 return ew_cmn_ns::binary_strided_impl<argTy1, argTy2, OutputType, \
190 StridedFunctor, \
191 __name__##_strided_kernel>( \
192 exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, \
193 arg2_p, arg2_offset, res_p, res_offset, depends, \
194 additional_depends); \
195 } \
196 \
197 template <typename fnT, typename T1, typename T2> \
198 struct StridedFactory \
199 { \
200 fnT get() \
201 { \
202 if constexpr (std::is_same_v< \
203 typename OutputType<T1, T2>::value_type, void>) \
204 { \
205 fnT fn = nullptr; \
206 return fn; \
207 } \
208 else { \
209 fnT fn = __name__##_strided_impl<T1, T2>; \
210 return fn; \
211 } \
212 } \
213 }; \
214 \
215 void populate_##__name__##_dispatch_tables(void) \
216 { \
217 ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
218 ContigFactory>( \
219 __name__##_contig_dispatch_table); \
220 ext_ns::init_dispatch_table<binary_strided_impl_fn_ptr_t, \
221 StridedFactory>( \
222 __name__##_strided_dispatch_table); \
223 ext_ns::init_dispatch_table<int, TypeMapFactory>( \
224 __name__##_output_typeid_table); \
225 };