DPNP C++ backend kernel library 0.18.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
dispatch_table.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28#include <unordered_set>
29#include <vector>
30
31#include "utils/type_dispatch.hpp"
32#include <pybind11/numpy.h>
33#include <pybind11/pybind11.h>
34#include <pybind11/stl.h>
35#include <sycl/sycl.hpp>
36
37#include "common.hpp"
38
39namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
40namespace py = pybind11;
41
42namespace statistics::common
43{
44template <typename T, typename Rest>
45struct one_of
46{
47 static_assert(std::is_same_v<Rest, std::tuple<>>,
48 "one_of: second parameter cannot be empty std::tuple");
49 static_assert(false, "one_of: second parameter must be std::tuple");
50};
51
52template <typename T, typename Top, typename... Rest>
53struct one_of<T, std::tuple<Top, Rest...>>
54{
55 static constexpr bool value =
56 std::is_same_v<T, Top> || one_of<T, std::tuple<Rest...>>::value;
57};
58
59template <typename T, typename Top>
60struct one_of<T, std::tuple<Top>>
61{
62 static constexpr bool value = std::is_same_v<T, Top>;
63};
64
65template <typename T, typename Rest>
66constexpr bool one_of_v = one_of<T, Rest>::value;
67
68template <typename FnT>
69using Table = FnT[dpctl_td_ns::num_types];
70template <typename FnT>
71using Table2 = Table<FnT>[dpctl_td_ns::num_types];
72
73using TypeId = int32_t;
74using TypesPair = std::pair<TypeId, TypeId>;
75
77{
78 inline size_t operator()(const TypesPair &p) const
79 {
80 std::hash<size_t> hasher;
81 return hasher(size_t(p.first) << (8 * sizeof(TypeId)) |
82 size_t(p.second));
83 }
84};
85
86using SupportedTypesList = std::vector<TypeId>;
87using SupportedTypesList2 = std::vector<TypesPair>;
88using SupportedTypesSet = std::unordered_set<TypeId>;
89using SupportedTypesSet2 = std::unordered_set<TypesPair, int_pair_hash>;
90
91using DType = py::dtype;
92using DTypePair = std::pair<DType, DType>;
93
94using SupportedDTypeList = std::vector<DType>;
95using SupportedDTypeList2 = std::vector<DTypePair>;
96
97template <typename FnT,
98 typename SupportedTypes,
99 template <typename>
100 typename Func>
102{
103 template <typename _FnT, typename T>
104 struct impl
105 {
106 static constexpr bool is_defined = one_of_v<T, SupportedTypes>;
107
108 _FnT get()
109 {
110 if constexpr (is_defined) {
111 return Func<T>::impl;
112 }
113 else {
114 return nullptr;
115 }
116 }
117 };
118
119 using type =
120 dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>;
121};
122
123template <typename FnT,
124 typename SupportedTypes,
125 template <typename, typename>
126 typename Func>
128{
129 template <typename _FnT, typename T1, typename T2>
130 struct impl
131 {
132 static constexpr bool is_defined =
133 one_of_v<std::tuple<T1, T2>, SupportedTypes>;
134
135 _FnT get()
136 {
137 if constexpr (is_defined) {
138 return Func<T1, T2>::impl;
139 }
140 else {
141 return nullptr;
142 }
143 }
144 };
145
146 using type =
147 dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
148};
149
150template <typename FnT>
152{
153public:
154 DispatchTable(std::string name) : name(name) {}
155
156 template <typename SupportedTypes, template <typename> typename Func>
157 void populate_dispatch_table()
158 {
159 using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type;
160 TBulder builder;
161
162 builder.populate_dispatch_vector(table);
163 populate_supported_types();
164 }
165
166 FnT get_unsafe(int _typenum) const
167 {
168 auto array_types = dpctl_td_ns::usm_ndarray_types();
169 const int type_id = array_types.typenum_to_lookup_id(_typenum);
170
171 return table[type_id];
172 }
173
174 FnT get(int _typenum) const
175 {
176 auto fn = get_unsafe(_typenum);
177
178 if (fn == nullptr) {
179 auto array_types = dpctl_td_ns::usm_ndarray_types();
180 const int _type_id = array_types.typenum_to_lookup_id(_typenum);
181
182 py::dtype _dtype = dtype_from_typenum(_type_id);
183 auto _type_pos = std::find(supported_types.begin(),
184 supported_types.end(), _dtype);
185 if (_type_pos == supported_types.end()) {
186 py::str types = py::str(py::cast(supported_types));
187 py::str dtype = py::str(_dtype);
188
189 py::str err_msg =
190 py::str("'" + name + "' has unsupported type '") + dtype +
191 py::str("'."
192 " Supported types are: ") +
193 types;
194
195 throw py::value_error(static_cast<std::string>(err_msg));
196 }
197 }
198
199 return fn;
200 }
201
202 const SupportedDTypeList &get_all_supported_types() const
203 {
204 return supported_types;
205 }
206
207private:
208 void populate_supported_types()
209 {
210 for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
211 if (table[i] != nullptr) {
212 supported_types.emplace_back(dtype_from_typenum(i));
213 }
214 }
215 }
216
217 std::string name;
218 SupportedDTypeList supported_types;
219 Table<FnT> table;
220};
221
222template <typename FnT>
224{
225public:
226 DispatchTable2(std::string first_name, std::string second_name)
227 : first_name(first_name), second_name(second_name)
228 {
229 }
230
231 template <typename SupportedTypes,
232 template <typename, typename>
233 typename Func>
234 void populate_dispatch_table()
235 {
236 using TBulder = typename TableBuilder2<FnT, SupportedTypes, Func>::type;
237 TBulder builder;
238
239 builder.populate_dispatch_table(table);
240 populate_supported_types();
241 }
242
243 FnT get_unsafe(int first_typenum, int second_typenum) const
244 {
245 auto array_types = dpctl_td_ns::usm_ndarray_types();
246 const int first_type_id =
247 array_types.typenum_to_lookup_id(first_typenum);
248 const int second_type_id =
249 array_types.typenum_to_lookup_id(second_typenum);
250
251 return table[first_type_id][second_type_id];
252 }
253
254 FnT get(int first_typenum, int second_typenum) const
255 {
256 auto fn = get_unsafe(first_typenum, second_typenum);
257
258 if (fn == nullptr) {
259 auto array_types = dpctl_td_ns::usm_ndarray_types();
260 const int first_type_id =
261 array_types.typenum_to_lookup_id(first_typenum);
262 const int second_type_id =
263 array_types.typenum_to_lookup_id(second_typenum);
264
265 py::dtype first_dtype = dtype_from_typenum(first_type_id);
266 auto first_type_pos =
267 std::find(supported_first_type.begin(),
268 supported_first_type.end(), first_dtype);
269 if (first_type_pos == supported_first_type.end()) {
270 py::str types = py::str(py::cast(supported_first_type));
271 py::str dtype = py::str(first_dtype);
272
273 py::str err_msg =
274 py::str("'" + first_name + "' has unsupported type '") +
275 dtype +
276 py::str("'."
277 " Supported types are: ") +
278 types;
279
280 throw py::value_error(static_cast<std::string>(err_msg));
281 }
282
283 py::dtype second_dtype = dtype_from_typenum(second_type_id);
284 auto second_type_pos =
285 std::find(supported_second_type.begin(),
286 supported_second_type.end(), second_dtype);
287 if (second_type_pos == supported_second_type.end()) {
288 py::str types = py::str(py::cast(supported_second_type));
289 py::str dtype = py::str(second_dtype);
290
291 py::str err_msg =
292 py::str("'" + second_name + "' has unsupported type '") +
293 dtype +
294 py::str("'."
295 " Supported types are: ") +
296 types;
297
298 throw py::value_error(static_cast<std::string>(err_msg));
299 }
300
301 py::str first_dtype_str = py::str(first_dtype);
302 py::str second_dtype_str = py::str(second_dtype);
303 py::str types = py::str(py::cast(all_supported_types));
304
305 py::str err_msg =
306 py::str("'" + first_name + "' and '" + second_name +
307 "' has unsupported types combination: ('") +
308 first_dtype_str + py::str("', '") + second_dtype_str +
309 py::str("')."
310 " Supported types combinations are: ") +
311 types;
312
313 throw py::value_error(static_cast<std::string>(err_msg));
314 }
315
316 return fn;
317 }
318
319 const SupportedDTypeList &get_supported_first_type() const
320 {
321 return supported_first_type;
322 }
323
324 const SupportedDTypeList &get_supported_second_type() const
325 {
326 return supported_second_type;
327 }
328
329 const SupportedDTypeList2 &get_all_supported_types() const
330 {
331 return all_supported_types;
332 }
333
334private:
335 void populate_supported_types()
336 {
337 SupportedTypesSet first_supported_types_set;
338 SupportedTypesSet second_supported_types_set;
339 SupportedTypesSet2 all_supported_types_set;
340
341 for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
342 for (int j = 0; j < dpctl_td_ns::num_types; ++j) {
343 if (table[i][j] != nullptr) {
344 all_supported_types_set.emplace(i, j);
345 first_supported_types_set.emplace(i);
346 second_supported_types_set.emplace(j);
347 }
348 }
349 }
350
351 auto to_supported_dtype_list = [](const auto &supported_set,
352 auto &supported_list) {
353 SupportedTypesList lst(supported_set.begin(), supported_set.end());
354 std::sort(lst.begin(), lst.end());
355 supported_list.resize(supported_set.size());
356 std::transform(lst.begin(), lst.end(), supported_list.begin(),
357 [](TypeId i) { return dtype_from_typenum(i); });
358 };
359
360 to_supported_dtype_list(first_supported_types_set,
361 supported_first_type);
362 to_supported_dtype_list(second_supported_types_set,
363 supported_second_type);
364
365 SupportedTypesList2 lst(all_supported_types_set.begin(),
366 all_supported_types_set.end());
367 std::sort(lst.begin(), lst.end());
368 all_supported_types.resize(all_supported_types_set.size());
369 std::transform(lst.begin(), lst.end(), all_supported_types.begin(),
370 [](TypesPair p) {
371 return DTypePair(dtype_from_typenum(p.first),
372 dtype_from_typenum(p.second));
373 });
374 }
375
376 std::string first_name;
377 std::string second_name;
378
379 SupportedDTypeList supported_first_type;
380 SupportedDTypeList supported_second_type;
381 SupportedDTypeList2 all_supported_types;
382
383 Table2<FnT> table;
384};
385
386} // namespace statistics::common