DPNP C++ backend kernel library 0.20.0dev4
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dispatch_table.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <unordered_set>
32#include <vector>
33
34#include "utils/type_dispatch.hpp"
35#include <pybind11/numpy.h>
36#include <pybind11/pybind11.h>
37#include <pybind11/stl.h>
38#include <sycl/sycl.hpp>
39
40#include "ext/common.hpp"
41
42namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
43namespace py = pybind11;
44
45namespace ext::common
46{
47template <typename T, typename Rest>
48struct one_of
49{
50 static_assert(std::is_same_v<Rest, std::tuple<>>,
51 "one_of: second parameter cannot be empty std::tuple");
52 static_assert(false, "one_of: second parameter must be std::tuple");
53};
54
55template <typename T, typename Top, typename... Rest>
56struct one_of<T, std::tuple<Top, Rest...>>
57{
58 static constexpr bool value =
59 std::is_same_v<T, Top> || one_of<T, std::tuple<Rest...>>::value;
60};
61
62template <typename T, typename Top>
63struct one_of<T, std::tuple<Top>>
64{
65 static constexpr bool value = std::is_same_v<T, Top>;
66};
67
68template <typename T, typename Rest>
69constexpr bool one_of_v = one_of<T, Rest>::value;
70
71template <typename FnT>
72using Table = FnT[dpctl_td_ns::num_types];
73template <typename FnT>
74using Table2 = Table<FnT>[dpctl_td_ns::num_types];
75
76using TypeId = int32_t;
77using TypesPair = std::pair<TypeId, TypeId>;
78
80{
81 inline size_t operator()(const TypesPair &p) const
82 {
83 std::hash<size_t> hasher;
84 return hasher(size_t(p.first) << (8 * sizeof(TypeId)) |
85 size_t(p.second));
86 }
87};
88
89using SupportedTypesList = std::vector<TypeId>;
90using SupportedTypesList2 = std::vector<TypesPair>;
91using SupportedTypesSet = std::unordered_set<TypeId>;
92using SupportedTypesSet2 = std::unordered_set<TypesPair, int_pair_hash>;
93
94using DType = py::dtype;
95using DTypePair = std::pair<DType, DType>;
96
97using SupportedDTypeList = std::vector<DType>;
98using SupportedDTypeList2 = std::vector<DTypePair>;
99
100template <typename FnT,
101 typename SupportedTypes,
102 template <typename> typename Func>
104{
105 template <typename _FnT, typename T>
106 struct impl
107 {
108 static constexpr bool is_defined = one_of_v<T, SupportedTypes>;
109
110 _FnT get()
111 {
112 if constexpr (is_defined) {
113 return Func<T>::impl;
114 }
115 else {
116 return nullptr;
117 }
118 }
119 };
120
121 using type =
122 dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>;
123};
124
125template <typename FnT,
126 typename SupportedTypes,
127 template <typename, typename> typename Func>
129{
130 template <typename _FnT, typename T1, typename T2>
131 struct impl
132 {
133 static constexpr bool is_defined =
134 one_of_v<std::tuple<T1, T2>, SupportedTypes>;
135
136 _FnT get()
137 {
138 if constexpr (is_defined) {
139 return Func<T1, T2>::impl;
140 }
141 else {
142 return nullptr;
143 }
144 }
145 };
146
147 using type =
148 dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
149};
150
151template <typename FnT>
153{
154public:
155 DispatchTable(std::string name) : name(name) {}
156
157 template <typename SupportedTypes, template <typename> typename Func>
158 void populate_dispatch_table()
159 {
160 using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type;
161 TBulder builder;
162
163 builder.populate_dispatch_vector(table);
164 populate_supported_types();
165 }
166
167 FnT get_unsafe(int _typenum) const
168 {
169 auto array_types = dpctl_td_ns::usm_ndarray_types();
170 const int type_id = array_types.typenum_to_lookup_id(_typenum);
171
172 return table[type_id];
173 }
174
175 FnT get(int _typenum) const
176 {
177 auto fn = get_unsafe(_typenum);
178
179 if (fn == nullptr) {
180 auto array_types = dpctl_td_ns::usm_ndarray_types();
181 const int _type_id = array_types.typenum_to_lookup_id(_typenum);
182
183 py::dtype _dtype = dtype_from_typenum(_type_id);
184 auto _type_pos = std::find(supported_types.begin(),
185 supported_types.end(), _dtype);
186 if (_type_pos == supported_types.end()) {
187 py::str types = py::str(py::cast(supported_types));
188 py::str dtype = py::str(_dtype);
189
190 py::str err_msg =
191 py::str("'" + name + "' has unsupported type '") + dtype +
192 py::str("'."
193 " Supported types are: ") +
194 types;
195
196 throw py::value_error(static_cast<std::string>(err_msg));
197 }
198 }
199
200 return fn;
201 }
202
203 const SupportedDTypeList &get_all_supported_types() const
204 {
205 return supported_types;
206 }
207
208private:
209 void populate_supported_types()
210 {
211 for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
212 if (table[i] != nullptr) {
213 supported_types.emplace_back(dtype_from_typenum(i));
214 }
215 }
216 }
217
218 std::string name;
219 SupportedDTypeList supported_types;
220 Table<FnT> table;
221};
222
223template <typename FnT>
225{
226public:
227 DispatchTable2(std::string first_name, std::string second_name)
228 : first_name(first_name), second_name(second_name)
229 {
230 }
231
232 template <typename SupportedTypes,
233 template <typename, typename> typename Func>
234 void populate_dispatch_table()
235 {
236 using TBulder = typename TableBuilder2<FnT, SupportedTypes, Func>::type;
237 TBulder builder;
238
239 builder.populate_dispatch_table(table);
240 populate_supported_types();
241 }
242
243 FnT get_unsafe(int first_typenum, int second_typenum) const
244 {
245 auto array_types = dpctl_td_ns::usm_ndarray_types();
246 const int first_type_id =
247 array_types.typenum_to_lookup_id(first_typenum);
248 const int second_type_id =
249 array_types.typenum_to_lookup_id(second_typenum);
250
251 return table[first_type_id][second_type_id];
252 }
253
254 FnT get(int first_typenum, int second_typenum) const
255 {
256 auto fn = get_unsafe(first_typenum, second_typenum);
257
258 if (fn == nullptr) {
259 auto array_types = dpctl_td_ns::usm_ndarray_types();
260 const int first_type_id =
261 array_types.typenum_to_lookup_id(first_typenum);
262 const int second_type_id =
263 array_types.typenum_to_lookup_id(second_typenum);
264
265 py::dtype first_dtype = dtype_from_typenum(first_type_id);
266 auto first_type_pos =
267 std::find(supported_first_type.begin(),
268 supported_first_type.end(), first_dtype);
269 if (first_type_pos == supported_first_type.end()) {
270 py::str types = py::str(py::cast(supported_first_type));
271 py::str dtype = py::str(first_dtype);
272
273 py::str err_msg =
274 py::str("'" + first_name + "' has unsupported type '") +
275 dtype +
276 py::str("'."
277 " Supported types are: ") +
278 types;
279
280 throw py::value_error(static_cast<std::string>(err_msg));
281 }
282
283 py::dtype second_dtype = dtype_from_typenum(second_type_id);
284 auto second_type_pos =
285 std::find(supported_second_type.begin(),
286 supported_second_type.end(), second_dtype);
287 if (second_type_pos == supported_second_type.end()) {
288 py::str types = py::str(py::cast(supported_second_type));
289 py::str dtype = py::str(second_dtype);
290
291 py::str err_msg =
292 py::str("'" + second_name + "' has unsupported type '") +
293 dtype +
294 py::str("'."
295 " Supported types are: ") +
296 types;
297
298 throw py::value_error(static_cast<std::string>(err_msg));
299 }
300
301 py::str first_dtype_str = py::str(first_dtype);
302 py::str second_dtype_str = py::str(second_dtype);
303 py::str types = py::str(py::cast(all_supported_types));
304
305 py::str err_msg =
306 py::str("'" + first_name + "' and '" + second_name +
307 "' has unsupported types combination: ('") +
308 first_dtype_str + py::str("', '") + second_dtype_str +
309 py::str("')."
310 " Supported types combinations are: ") +
311 types;
312
313 throw py::value_error(static_cast<std::string>(err_msg));
314 }
315
316 return fn;
317 }
318
319 const SupportedDTypeList &get_supported_first_type() const
320 {
321 return supported_first_type;
322 }
323
324 const SupportedDTypeList &get_supported_second_type() const
325 {
326 return supported_second_type;
327 }
328
329 const SupportedDTypeList2 &get_all_supported_types() const
330 {
331 return all_supported_types;
332 }
333
334private:
335 void populate_supported_types()
336 {
337 SupportedTypesSet first_supported_types_set;
338 SupportedTypesSet second_supported_types_set;
339 SupportedTypesSet2 all_supported_types_set;
340
341 for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
342 for (int j = 0; j < dpctl_td_ns::num_types; ++j) {
343 if (table[i][j] != nullptr) {
344 all_supported_types_set.emplace(i, j);
345 first_supported_types_set.emplace(i);
346 second_supported_types_set.emplace(j);
347 }
348 }
349 }
350
351 auto to_supported_dtype_list = [](const auto &supported_set,
352 auto &supported_list) {
353 SupportedTypesList lst(supported_set.begin(), supported_set.end());
354 std::sort(lst.begin(), lst.end());
355 supported_list.resize(supported_set.size());
356 std::transform(lst.begin(), lst.end(), supported_list.begin(),
357 [](TypeId i) { return dtype_from_typenum(i); });
358 };
359
360 to_supported_dtype_list(first_supported_types_set,
361 supported_first_type);
362 to_supported_dtype_list(second_supported_types_set,
363 supported_second_type);
364
365 SupportedTypesList2 lst(all_supported_types_set.begin(),
366 all_supported_types_set.end());
367 std::sort(lst.begin(), lst.end());
368 all_supported_types.resize(all_supported_types_set.size());
369 std::transform(lst.begin(), lst.end(), all_supported_types.begin(),
370 [](TypesPair p) {
371 return DTypePair(dtype_from_typenum(p.first),
372 dtype_from_typenum(p.second));
373 });
374 }
375
376 std::string first_name;
377 std::string second_name;
378
379 SupportedDTypeList supported_first_type;
380 SupportedDTypeList supported_second_type;
381 SupportedDTypeList2 all_supported_types;
382
383 Table2<FnT> table;
384};
385
386} // namespace ext::common