29#include "ext/common.hpp"
30#include "utils/type_dispatch.hpp"
31#include <pybind11/pybind11.h>
33namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
37inline size_t get_max_local_size(
const sycl::device &device)
39 constexpr const int default_max_cpu_local_size = 256;
40 constexpr const int default_max_gpu_local_size = 0;
42 return get_max_local_size(device, default_max_cpu_local_size,
43 default_max_gpu_local_size);
46inline size_t get_max_local_size(
const sycl::device &device,
47 int cpu_local_size_limit,
48 int gpu_local_size_limit)
50 int max_work_group_size =
51 device.get_info<sycl::info::device::max_work_group_size>();
52 if (device.is_cpu() && cpu_local_size_limit > 0) {
53 return std::min(cpu_local_size_limit, max_work_group_size);
55 else if (device.is_gpu() && gpu_local_size_limit > 0) {
56 return std::min(gpu_local_size_limit, max_work_group_size);
59 return max_work_group_size;
62inline sycl::nd_range<1>
63 make_ndrange(
size_t global_size,
size_t local_range,
size_t work_per_item)
65 return make_ndrange(sycl::range<1>(global_size),
66 sycl::range<1>(local_range),
67 sycl::range<1>(work_per_item));
70inline size_t get_local_mem_size_in_bytes(
const sycl::device &device)
73 constexpr const size_t reserve = 1024;
75 return get_local_mem_size_in_bytes(device, reserve);
78inline size_t get_local_mem_size_in_bytes(
const sycl::device &device,
81 size_t local_mem_size =
82 device.get_info<sycl::info::device::local_mem_size>();
83 return local_mem_size - reserve;
86inline pybind11::dtype dtype_from_typenum(
int dst_typenum)
88 dpctl_td_ns::typenum_t dst_typenum_t =
89 static_cast<dpctl_td_ns::typenum_t
>(dst_typenum);
90 switch (dst_typenum_t) {
91 case dpctl_td_ns::typenum_t::BOOL:
92 return py::dtype(
"?");
93 case dpctl_td_ns::typenum_t::INT8:
94 return py::dtype(
"i1");
95 case dpctl_td_ns::typenum_t::UINT8:
96 return py::dtype(
"u1");
97 case dpctl_td_ns::typenum_t::INT16:
98 return py::dtype(
"i2");
99 case dpctl_td_ns::typenum_t::UINT16:
100 return py::dtype(
"u2");
101 case dpctl_td_ns::typenum_t::INT32:
102 return py::dtype(
"i4");
103 case dpctl_td_ns::typenum_t::UINT32:
104 return py::dtype(
"u4");
105 case dpctl_td_ns::typenum_t::INT64:
106 return py::dtype(
"i8");
107 case dpctl_td_ns::typenum_t::UINT64:
108 return py::dtype(
"u8");
109 case dpctl_td_ns::typenum_t::HALF:
110 return py::dtype(
"f2");
111 case dpctl_td_ns::typenum_t::FLOAT:
112 return py::dtype(
"f4");
113 case dpctl_td_ns::typenum_t::DOUBLE:
114 return py::dtype(
"f8");
115 case dpctl_td_ns::typenum_t::CFLOAT:
116 return py::dtype(
"c8");
117 case dpctl_td_ns::typenum_t::CDOUBLE:
118 return py::dtype(
"c16");
120 throw py::value_error(
"Unrecognized dst_typeid");