DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dispatch_table.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <unordered_set>
32#include <vector>
33
34#include "utils/type_dispatch.hpp"
35#include <pybind11/numpy.h>
36#include <pybind11/pybind11.h>
37#include <pybind11/stl.h>
38#include <sycl/sycl.hpp>
39
40#include "ext/common.hpp"
41
42namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
43namespace py = pybind11;
44
45namespace ext::common
46{
47template <typename T, typename Rest>
48struct one_of
49{
50 static_assert(std::is_same_v<Rest, std::tuple<>>,
51 "one_of: second parameter cannot be empty std::tuple");
52 static_assert(false, "one_of: second parameter must be std::tuple");
53};
54
55template <typename T, typename Top, typename... Rest>
56struct one_of<T, std::tuple<Top, Rest...>>
57{
58 static constexpr bool value =
59 std::is_same_v<T, Top> || one_of<T, std::tuple<Rest...>>::value;
60};
61
62template <typename T, typename Top>
63struct one_of<T, std::tuple<Top>>
64{
65 static constexpr bool value = std::is_same_v<T, Top>;
66};
67
68template <typename T, typename Rest>
69constexpr bool one_of_v = one_of<T, Rest>::value;
70
71template <typename FnT>
72using Table = FnT[dpctl_td_ns::num_types];
73template <typename FnT>
74using Table2 = Table<FnT>[dpctl_td_ns::num_types];
75
76using TypeId = int32_t;
77using TypesPair = std::pair<TypeId, TypeId>;
78
80{
81 inline size_t operator()(const TypesPair &p) const
82 {
83 std::hash<size_t> hasher;
84 return hasher(size_t(p.first) << (8 * sizeof(TypeId)) |
85 size_t(p.second));
86 }
87};
88
89using SupportedTypesList = std::vector<TypeId>;
90using SupportedTypesList2 = std::vector<TypesPair>;
91using SupportedTypesSet = std::unordered_set<TypeId>;
92using SupportedTypesSet2 = std::unordered_set<TypesPair, int_pair_hash>;
93
94using DType = py::dtype;
95using DTypePair = std::pair<DType, DType>;
96
97using SupportedDTypeList = std::vector<DType>;
98using SupportedDTypeList2 = std::vector<DTypePair>;
99
100template <typename FnT,
101 typename SupportedTypes,
102 template <typename>
103 typename Func>
105{
106 template <typename _FnT, typename T>
107 struct impl
108 {
109 static constexpr bool is_defined = one_of_v<T, SupportedTypes>;
110
111 _FnT get()
112 {
113 if constexpr (is_defined) {
114 return Func<T>::impl;
115 }
116 else {
117 return nullptr;
118 }
119 }
120 };
121
122 using type =
123 dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>;
124};
125
126template <typename FnT,
127 typename SupportedTypes,
128 template <typename, typename>
129 typename Func>
131{
132 template <typename _FnT, typename T1, typename T2>
133 struct impl
134 {
135 static constexpr bool is_defined =
136 one_of_v<std::tuple<T1, T2>, SupportedTypes>;
137
138 _FnT get()
139 {
140 if constexpr (is_defined) {
141 return Func<T1, T2>::impl;
142 }
143 else {
144 return nullptr;
145 }
146 }
147 };
148
149 using type =
150 dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
151};
152
153template <typename FnT>
155{
156public:
157 DispatchTable(std::string name) : name(name) {}
158
159 template <typename SupportedTypes, template <typename> typename Func>
160 void populate_dispatch_table()
161 {
162 using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type;
163 TBulder builder;
164
165 builder.populate_dispatch_vector(table);
166 populate_supported_types();
167 }
168
169 FnT get_unsafe(int _typenum) const
170 {
171 auto array_types = dpctl_td_ns::usm_ndarray_types();
172 const int type_id = array_types.typenum_to_lookup_id(_typenum);
173
174 return table[type_id];
175 }
176
177 FnT get(int _typenum) const
178 {
179 auto fn = get_unsafe(_typenum);
180
181 if (fn == nullptr) {
182 auto array_types = dpctl_td_ns::usm_ndarray_types();
183 const int _type_id = array_types.typenum_to_lookup_id(_typenum);
184
185 py::dtype _dtype = dtype_from_typenum(_type_id);
186 auto _type_pos = std::find(supported_types.begin(),
187 supported_types.end(), _dtype);
188 if (_type_pos == supported_types.end()) {
189 py::str types = py::str(py::cast(supported_types));
190 py::str dtype = py::str(_dtype);
191
192 py::str err_msg =
193 py::str("'" + name + "' has unsupported type '") + dtype +
194 py::str("'."
195 " Supported types are: ") +
196 types;
197
198 throw py::value_error(static_cast<std::string>(err_msg));
199 }
200 }
201
202 return fn;
203 }
204
205 const SupportedDTypeList &get_all_supported_types() const
206 {
207 return supported_types;
208 }
209
210private:
211 void populate_supported_types()
212 {
213 for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
214 if (table[i] != nullptr) {
215 supported_types.emplace_back(dtype_from_typenum(i));
216 }
217 }
218 }
219
220 std::string name;
221 SupportedDTypeList supported_types;
222 Table<FnT> table;
223};
224
225template <typename FnT>
227{
228public:
229 DispatchTable2(std::string first_name, std::string second_name)
230 : first_name(first_name), second_name(second_name)
231 {
232 }
233
234 template <typename SupportedTypes,
235 template <typename, typename>
236 typename Func>
237 void populate_dispatch_table()
238 {
239 using TBulder = typename TableBuilder2<FnT, SupportedTypes, Func>::type;
240 TBulder builder;
241
242 builder.populate_dispatch_table(table);
243 populate_supported_types();
244 }
245
246 FnT get_unsafe(int first_typenum, int second_typenum) const
247 {
248 auto array_types = dpctl_td_ns::usm_ndarray_types();
249 const int first_type_id =
250 array_types.typenum_to_lookup_id(first_typenum);
251 const int second_type_id =
252 array_types.typenum_to_lookup_id(second_typenum);
253
254 return table[first_type_id][second_type_id];
255 }
256
257 FnT get(int first_typenum, int second_typenum) const
258 {
259 auto fn = get_unsafe(first_typenum, second_typenum);
260
261 if (fn == nullptr) {
262 auto array_types = dpctl_td_ns::usm_ndarray_types();
263 const int first_type_id =
264 array_types.typenum_to_lookup_id(first_typenum);
265 const int second_type_id =
266 array_types.typenum_to_lookup_id(second_typenum);
267
268 py::dtype first_dtype = dtype_from_typenum(first_type_id);
269 auto first_type_pos =
270 std::find(supported_first_type.begin(),
271 supported_first_type.end(), first_dtype);
272 if (first_type_pos == supported_first_type.end()) {
273 py::str types = py::str(py::cast(supported_first_type));
274 py::str dtype = py::str(first_dtype);
275
276 py::str err_msg =
277 py::str("'" + first_name + "' has unsupported type '") +
278 dtype +
279 py::str("'."
280 " Supported types are: ") +
281 types;
282
283 throw py::value_error(static_cast<std::string>(err_msg));
284 }
285
286 py::dtype second_dtype = dtype_from_typenum(second_type_id);
287 auto second_type_pos =
288 std::find(supported_second_type.begin(),
289 supported_second_type.end(), second_dtype);
290 if (second_type_pos == supported_second_type.end()) {
291 py::str types = py::str(py::cast(supported_second_type));
292 py::str dtype = py::str(second_dtype);
293
294 py::str err_msg =
295 py::str("'" + second_name + "' has unsupported type '") +
296 dtype +
297 py::str("'."
298 " Supported types are: ") +
299 types;
300
301 throw py::value_error(static_cast<std::string>(err_msg));
302 }
303
304 py::str first_dtype_str = py::str(first_dtype);
305 py::str second_dtype_str = py::str(second_dtype);
306 py::str types = py::str(py::cast(all_supported_types));
307
308 py::str err_msg =
309 py::str("'" + first_name + "' and '" + second_name +
310 "' has unsupported types combination: ('") +
311 first_dtype_str + py::str("', '") + second_dtype_str +
312 py::str("')."
313 " Supported types combinations are: ") +
314 types;
315
316 throw py::value_error(static_cast<std::string>(err_msg));
317 }
318
319 return fn;
320 }
321
322 const SupportedDTypeList &get_supported_first_type() const
323 {
324 return supported_first_type;
325 }
326
327 const SupportedDTypeList &get_supported_second_type() const
328 {
329 return supported_second_type;
330 }
331
332 const SupportedDTypeList2 &get_all_supported_types() const
333 {
334 return all_supported_types;
335 }
336
337private:
338 void populate_supported_types()
339 {
340 SupportedTypesSet first_supported_types_set;
341 SupportedTypesSet second_supported_types_set;
342 SupportedTypesSet2 all_supported_types_set;
343
344 for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
345 for (int j = 0; j < dpctl_td_ns::num_types; ++j) {
346 if (table[i][j] != nullptr) {
347 all_supported_types_set.emplace(i, j);
348 first_supported_types_set.emplace(i);
349 second_supported_types_set.emplace(j);
350 }
351 }
352 }
353
354 auto to_supported_dtype_list = [](const auto &supported_set,
355 auto &supported_list) {
356 SupportedTypesList lst(supported_set.begin(), supported_set.end());
357 std::sort(lst.begin(), lst.end());
358 supported_list.resize(supported_set.size());
359 std::transform(lst.begin(), lst.end(), supported_list.begin(),
360 [](TypeId i) { return dtype_from_typenum(i); });
361 };
362
363 to_supported_dtype_list(first_supported_types_set,
364 supported_first_type);
365 to_supported_dtype_list(second_supported_types_set,
366 supported_second_type);
367
368 SupportedTypesList2 lst(all_supported_types_set.begin(),
369 all_supported_types_set.end());
370 std::sort(lst.begin(), lst.end());
371 all_supported_types.resize(all_supported_types_set.size());
372 std::transform(lst.begin(), lst.end(), all_supported_types.begin(),
373 [](TypesPair p) {
374 return DTypePair(dtype_from_typenum(p.first),
375 dtype_from_typenum(p.second));
376 });
377 }
378
379 std::string first_name;
380 std::string second_name;
381
382 SupportedDTypeList supported_first_type;
383 SupportedDTypeList supported_second_type;
384 SupportedDTypeList2 all_supported_types;
385
386 Table2<FnT> table;
387};
388
389} // namespace ext::common