33#include "ext/common.hpp"
34#include "utils/type_dispatch.hpp"
35#include <pybind11/pybind11.h>
37namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
41inline size_t get_max_local_size(
const sycl::device &device)
43 constexpr const int default_max_cpu_local_size = 256;
44 constexpr const int default_max_gpu_local_size = 0;
46 return get_max_local_size(device, default_max_cpu_local_size,
47 default_max_gpu_local_size);
50inline size_t get_max_local_size(
const sycl::device &device,
51 int cpu_local_size_limit,
52 int gpu_local_size_limit)
54 int max_work_group_size =
55 device.get_info<sycl::info::device::max_work_group_size>();
56 if (device.is_cpu() && cpu_local_size_limit > 0) {
57 return std::min(cpu_local_size_limit, max_work_group_size);
59 else if (device.is_gpu() && gpu_local_size_limit > 0) {
60 return std::min(gpu_local_size_limit, max_work_group_size);
63 return max_work_group_size;
66inline sycl::nd_range<1>
67 make_ndrange(
size_t global_size,
size_t local_range,
size_t work_per_item)
69 return make_ndrange(sycl::range<1>(global_size),
70 sycl::range<1>(local_range),
71 sycl::range<1>(work_per_item));
74inline size_t get_local_mem_size_in_bytes(
const sycl::device &device)
77 constexpr const size_t reserve = 1024;
79 return get_local_mem_size_in_bytes(device, reserve);
82inline size_t get_local_mem_size_in_bytes(
const sycl::device &device,
85 size_t local_mem_size =
86 device.get_info<sycl::info::device::local_mem_size>();
87 return local_mem_size - reserve;
90inline pybind11::dtype dtype_from_typenum(
int dst_typenum)
92 dpctl_td_ns::typenum_t dst_typenum_t =
93 static_cast<dpctl_td_ns::typenum_t
>(dst_typenum);
94 switch (dst_typenum_t) {
95 case dpctl_td_ns::typenum_t::BOOL:
96 return py::dtype(
"?");
97 case dpctl_td_ns::typenum_t::INT8:
98 return py::dtype(
"i1");
99 case dpctl_td_ns::typenum_t::UINT8:
100 return py::dtype(
"u1");
101 case dpctl_td_ns::typenum_t::INT16:
102 return py::dtype(
"i2");
103 case dpctl_td_ns::typenum_t::UINT16:
104 return py::dtype(
"u2");
105 case dpctl_td_ns::typenum_t::INT32:
106 return py::dtype(
"i4");
107 case dpctl_td_ns::typenum_t::UINT32:
108 return py::dtype(
"u4");
109 case dpctl_td_ns::typenum_t::INT64:
110 return py::dtype(
"i8");
111 case dpctl_td_ns::typenum_t::UINT64:
112 return py::dtype(
"u8");
113 case dpctl_td_ns::typenum_t::HALF:
114 return py::dtype(
"f2");
115 case dpctl_td_ns::typenum_t::FLOAT:
116 return py::dtype(
"f4");
117 case dpctl_td_ns::typenum_t::DOUBLE:
118 return py::dtype(
"f8");
119 case dpctl_td_ns::typenum_t::CFLOAT:
120 return py::dtype(
"c8");
121 case dpctl_td_ns::typenum_t::CDOUBLE:
122 return py::dtype(
"c16");
124 throw py::value_error(
"Unrecognized dst_typeid");