156 template <
typename SupportedTypes,
template <
typename>
typename Func>
157 void populate_dispatch_table()
159 using TBulder =
typename TableBuilder<FnT, SupportedTypes, Func>::type;
162 builder.populate_dispatch_vector(table);
163 populate_supported_types();
166 FnT get_unsafe(
int _typenum)
const
168 auto array_types = dpctl_td_ns::usm_ndarray_types();
169 const int type_id = array_types.typenum_to_lookup_id(_typenum);
171 return table[type_id];
174 FnT get(
int _typenum)
const
176 auto fn = get_unsafe(_typenum);
179 auto array_types = dpctl_td_ns::usm_ndarray_types();
180 const int _type_id = array_types.typenum_to_lookup_id(_typenum);
182 py::dtype _dtype = dtype_from_typenum(_type_id);
183 auto _type_pos = std::find(supported_types.begin(),
184 supported_types.end(), _dtype);
185 if (_type_pos == supported_types.end()) {
186 py::str types = py::str(py::cast(supported_types));
187 py::str dtype = py::str(_dtype);
190 py::str(
"'" + name +
"' has unsupported type '") + dtype +
192 " Supported types are: ") +
195 throw py::value_error(
static_cast<std::string
>(err_msg));
202 const SupportedDTypeList &get_all_supported_types()
const
204 return supported_types;
208 void populate_supported_types()
210 for (
int i = 0; i < dpctl_td_ns::num_types; ++i) {
211 if (table[i] !=
nullptr) {
212 supported_types.emplace_back(dtype_from_typenum(i));
218 SupportedDTypeList supported_types;
227 : first_name(first_name), second_name(second_name)
231 template <
typename SupportedTypes,
232 template <
typename,
typename>
234 void populate_dispatch_table()
236 using TBulder =
typename TableBuilder2<FnT, SupportedTypes, Func>::type;
239 builder.populate_dispatch_table(table);
240 populate_supported_types();
243 FnT get_unsafe(
int first_typenum,
int second_typenum)
const
245 auto array_types = dpctl_td_ns::usm_ndarray_types();
246 const int first_type_id =
247 array_types.typenum_to_lookup_id(first_typenum);
248 const int second_type_id =
249 array_types.typenum_to_lookup_id(second_typenum);
251 return table[first_type_id][second_type_id];
254 FnT get(
int first_typenum,
int second_typenum)
const
256 auto fn = get_unsafe(first_typenum, second_typenum);
259 auto array_types = dpctl_td_ns::usm_ndarray_types();
260 const int first_type_id =
261 array_types.typenum_to_lookup_id(first_typenum);
262 const int second_type_id =
263 array_types.typenum_to_lookup_id(second_typenum);
265 py::dtype first_dtype = dtype_from_typenum(first_type_id);
266 auto first_type_pos =
267 std::find(supported_first_type.begin(),
268 supported_first_type.end(), first_dtype);
269 if (first_type_pos == supported_first_type.end()) {
270 py::str types = py::str(py::cast(supported_first_type));
271 py::str dtype = py::str(first_dtype);
274 py::str(
"'" + first_name +
"' has unsupported type '") +
277 " Supported types are: ") +
280 throw py::value_error(
static_cast<std::string
>(err_msg));
283 py::dtype second_dtype = dtype_from_typenum(second_type_id);
284 auto second_type_pos =
285 std::find(supported_second_type.begin(),
286 supported_second_type.end(), second_dtype);
287 if (second_type_pos == supported_second_type.end()) {
288 py::str types = py::str(py::cast(supported_second_type));
289 py::str dtype = py::str(second_dtype);
292 py::str(
"'" + second_name +
"' has unsupported type '") +
295 " Supported types are: ") +
298 throw py::value_error(
static_cast<std::string
>(err_msg));
301 py::str first_dtype_str = py::str(first_dtype);
302 py::str second_dtype_str = py::str(second_dtype);
303 py::str types = py::str(py::cast(all_supported_types));
306 py::str(
"'" + first_name +
"' and '" + second_name +
307 "' has unsupported types combination: ('") +
308 first_dtype_str + py::str(
"', '") + second_dtype_str +
310 " Supported types combinations are: ") +
313 throw py::value_error(
static_cast<std::string
>(err_msg));
319 const SupportedDTypeList &get_supported_first_type()
const
321 return supported_first_type;
324 const SupportedDTypeList &get_supported_second_type()
const
326 return supported_second_type;
329 const SupportedDTypeList2 &get_all_supported_types()
const
331 return all_supported_types;
335 void populate_supported_types()
337 SupportedTypesSet first_supported_types_set;
338 SupportedTypesSet second_supported_types_set;
339 SupportedTypesSet2 all_supported_types_set;
341 for (
int i = 0; i < dpctl_td_ns::num_types; ++i) {
342 for (
int j = 0; j < dpctl_td_ns::num_types; ++j) {
343 if (table[i][j] !=
nullptr) {
344 all_supported_types_set.emplace(i, j);
345 first_supported_types_set.emplace(i);
346 second_supported_types_set.emplace(j);
351 auto to_supported_dtype_list = [](
const auto &supported_set,
352 auto &supported_list) {
353 SupportedTypesList lst(supported_set.begin(), supported_set.end());
354 std::sort(lst.begin(), lst.end());
355 supported_list.resize(supported_set.size());
356 std::transform(lst.begin(), lst.end(), supported_list.begin(),
357 [](TypeId i) { return dtype_from_typenum(i); });
360 to_supported_dtype_list(first_supported_types_set,
361 supported_first_type);
362 to_supported_dtype_list(second_supported_types_set,
363 supported_second_type);
365 SupportedTypesList2 lst(all_supported_types_set.begin(),
366 all_supported_types_set.end());
367 std::sort(lst.begin(), lst.end());
368 all_supported_types.resize(all_supported_types_set.size());
369 std::transform(lst.begin(), lst.end(), all_supported_types.begin(),
371 return DTypePair(dtype_from_typenum(p.first),
372 dtype_from_typenum(p.second));
376 std::string first_name;
377 std::string second_name;
379 SupportedDTypeList supported_first_type;
380 SupportedDTypeList supported_second_type;
381 SupportedDTypeList2 all_supported_types;