26#include "ext/common.hpp"
27#include "utils/type_dispatch.hpp"
28#include <pybind11/pybind11.h>
30namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
34inline size_t get_max_local_size(
const sycl::device &device)
36 constexpr const int default_max_cpu_local_size = 256;
37 constexpr const int default_max_gpu_local_size = 0;
39 return get_max_local_size(device, default_max_cpu_local_size,
40 default_max_gpu_local_size);
43inline size_t get_max_local_size(
const sycl::device &device,
44 int cpu_local_size_limit,
45 int gpu_local_size_limit)
47 int max_work_group_size =
48 device.get_info<sycl::info::device::max_work_group_size>();
49 if (device.is_cpu() && cpu_local_size_limit > 0) {
50 return std::min(cpu_local_size_limit, max_work_group_size);
52 else if (device.is_gpu() && gpu_local_size_limit > 0) {
53 return std::min(gpu_local_size_limit, max_work_group_size);
56 return max_work_group_size;
59inline sycl::nd_range<1>
60 make_ndrange(
size_t global_size,
size_t local_range,
size_t work_per_item)
62 return make_ndrange(sycl::range<1>(global_size),
63 sycl::range<1>(local_range),
64 sycl::range<1>(work_per_item));
67inline size_t get_local_mem_size_in_bytes(
const sycl::device &device)
70 constexpr const size_t reserve = 1024;
72 return get_local_mem_size_in_bytes(device, reserve);
75inline size_t get_local_mem_size_in_bytes(
const sycl::device &device,
78 size_t local_mem_size =
79 device.get_info<sycl::info::device::local_mem_size>();
80 return local_mem_size - reserve;
83inline pybind11::dtype dtype_from_typenum(
int dst_typenum)
85 dpctl_td_ns::typenum_t dst_typenum_t =
86 static_cast<dpctl_td_ns::typenum_t
>(dst_typenum);
87 switch (dst_typenum_t) {
88 case dpctl_td_ns::typenum_t::BOOL:
89 return py::dtype(
"?");
90 case dpctl_td_ns::typenum_t::INT8:
91 return py::dtype(
"i1");
92 case dpctl_td_ns::typenum_t::UINT8:
93 return py::dtype(
"u1");
94 case dpctl_td_ns::typenum_t::INT16:
95 return py::dtype(
"i2");
96 case dpctl_td_ns::typenum_t::UINT16:
97 return py::dtype(
"u2");
98 case dpctl_td_ns::typenum_t::INT32:
99 return py::dtype(
"i4");
100 case dpctl_td_ns::typenum_t::UINT32:
101 return py::dtype(
"u4");
102 case dpctl_td_ns::typenum_t::INT64:
103 return py::dtype(
"i8");
104 case dpctl_td_ns::typenum_t::UINT64:
105 return py::dtype(
"u8");
106 case dpctl_td_ns::typenum_t::HALF:
107 return py::dtype(
"f2");
108 case dpctl_td_ns::typenum_t::FLOAT:
109 return py::dtype(
"f4");
110 case dpctl_td_ns::typenum_t::DOUBLE:
111 return py::dtype(
"f8");
112 case dpctl_td_ns::typenum_t::CFLOAT:
113 return py::dtype(
"c8");
114 case dpctl_td_ns::typenum_t::CDOUBLE:
115 return py::dtype(
"c16");
117 throw py::value_error(
"Unrecognized dst_typeid");