159 template <
typename SupportedTypes,
template <
typename>
typename Func>
160 void populate_dispatch_table()
162 using TBulder =
typename TableBuilder<FnT, SupportedTypes, Func>::type;
165 builder.populate_dispatch_vector(table);
166 populate_supported_types();
169 FnT get_unsafe(
int _typenum)
const
171 auto array_types = dpctl_td_ns::usm_ndarray_types();
172 const int type_id = array_types.typenum_to_lookup_id(_typenum);
174 return table[type_id];
177 FnT get(
int _typenum)
const
179 auto fn = get_unsafe(_typenum);
182 auto array_types = dpctl_td_ns::usm_ndarray_types();
183 const int _type_id = array_types.typenum_to_lookup_id(_typenum);
185 py::dtype _dtype = dtype_from_typenum(_type_id);
186 auto _type_pos = std::find(supported_types.begin(),
187 supported_types.end(), _dtype);
188 if (_type_pos == supported_types.end()) {
189 py::str types = py::str(py::cast(supported_types));
190 py::str dtype = py::str(_dtype);
193 py::str(
"'" + name +
"' has unsupported type '") + dtype +
195 " Supported types are: ") +
198 throw py::value_error(
static_cast<std::string
>(err_msg));
205 const SupportedDTypeList &get_all_supported_types()
const
207 return supported_types;
211 void populate_supported_types()
213 for (
int i = 0; i < dpctl_td_ns::num_types; ++i) {
214 if (table[i] !=
nullptr) {
215 supported_types.emplace_back(dtype_from_typenum(i));
221 SupportedDTypeList supported_types;
230 : first_name(first_name), second_name(second_name)
234 template <
typename SupportedTypes,
235 template <
typename,
typename>
237 void populate_dispatch_table()
239 using TBulder =
typename TableBuilder2<FnT, SupportedTypes, Func>::type;
242 builder.populate_dispatch_table(table);
243 populate_supported_types();
246 FnT get_unsafe(
int first_typenum,
int second_typenum)
const
248 auto array_types = dpctl_td_ns::usm_ndarray_types();
249 const int first_type_id =
250 array_types.typenum_to_lookup_id(first_typenum);
251 const int second_type_id =
252 array_types.typenum_to_lookup_id(second_typenum);
254 return table[first_type_id][second_type_id];
257 FnT get(
int first_typenum,
int second_typenum)
const
259 auto fn = get_unsafe(first_typenum, second_typenum);
262 auto array_types = dpctl_td_ns::usm_ndarray_types();
263 const int first_type_id =
264 array_types.typenum_to_lookup_id(first_typenum);
265 const int second_type_id =
266 array_types.typenum_to_lookup_id(second_typenum);
268 py::dtype first_dtype = dtype_from_typenum(first_type_id);
269 auto first_type_pos =
270 std::find(supported_first_type.begin(),
271 supported_first_type.end(), first_dtype);
272 if (first_type_pos == supported_first_type.end()) {
273 py::str types = py::str(py::cast(supported_first_type));
274 py::str dtype = py::str(first_dtype);
277 py::str(
"'" + first_name +
"' has unsupported type '") +
280 " Supported types are: ") +
283 throw py::value_error(
static_cast<std::string
>(err_msg));
286 py::dtype second_dtype = dtype_from_typenum(second_type_id);
287 auto second_type_pos =
288 std::find(supported_second_type.begin(),
289 supported_second_type.end(), second_dtype);
290 if (second_type_pos == supported_second_type.end()) {
291 py::str types = py::str(py::cast(supported_second_type));
292 py::str dtype = py::str(second_dtype);
295 py::str(
"'" + second_name +
"' has unsupported type '") +
298 " Supported types are: ") +
301 throw py::value_error(
static_cast<std::string
>(err_msg));
304 py::str first_dtype_str = py::str(first_dtype);
305 py::str second_dtype_str = py::str(second_dtype);
306 py::str types = py::str(py::cast(all_supported_types));
309 py::str(
"'" + first_name +
"' and '" + second_name +
310 "' has unsupported types combination: ('") +
311 first_dtype_str + py::str(
"', '") + second_dtype_str +
313 " Supported types combinations are: ") +
316 throw py::value_error(
static_cast<std::string
>(err_msg));
322 const SupportedDTypeList &get_supported_first_type()
const
324 return supported_first_type;
327 const SupportedDTypeList &get_supported_second_type()
const
329 return supported_second_type;
332 const SupportedDTypeList2 &get_all_supported_types()
const
334 return all_supported_types;
338 void populate_supported_types()
340 SupportedTypesSet first_supported_types_set;
341 SupportedTypesSet second_supported_types_set;
342 SupportedTypesSet2 all_supported_types_set;
344 for (
int i = 0; i < dpctl_td_ns::num_types; ++i) {
345 for (
int j = 0; j < dpctl_td_ns::num_types; ++j) {
346 if (table[i][j] !=
nullptr) {
347 all_supported_types_set.emplace(i, j);
348 first_supported_types_set.emplace(i);
349 second_supported_types_set.emplace(j);
354 auto to_supported_dtype_list = [](
const auto &supported_set,
355 auto &supported_list) {
356 SupportedTypesList lst(supported_set.begin(), supported_set.end());
357 std::sort(lst.begin(), lst.end());
358 supported_list.resize(supported_set.size());
359 std::transform(lst.begin(), lst.end(), supported_list.begin(),
360 [](TypeId i) { return dtype_from_typenum(i); });
363 to_supported_dtype_list(first_supported_types_set,
364 supported_first_type);
365 to_supported_dtype_list(second_supported_types_set,
366 supported_second_type);
368 SupportedTypesList2 lst(all_supported_types_set.begin(),
369 all_supported_types_set.end());
370 std::sort(lst.begin(), lst.end());
371 all_supported_types.resize(all_supported_types_set.size());
372 std::transform(lst.begin(), lst.end(), all_supported_types.begin(),
374 return DTypePair(dtype_from_typenum(p.first),
375 dtype_from_typenum(p.second));
379 std::string first_name;
380 std::string second_name;
382 SupportedDTypeList supported_first_type;
383 SupportedDTypeList supported_second_type;
384 SupportedDTypeList2 all_supported_types;