33#include <pybind11/numpy.h>
34#include <pybind11/pybind11.h>
36#include "ext/common.hpp"
37#include "utils/type_dispatch.hpp"
39namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
43inline size_t get_max_local_size(
const sycl::device &device)
45 constexpr const int default_max_cpu_local_size = 256;
46 constexpr const int default_max_gpu_local_size = 0;
48 return get_max_local_size(device, default_max_cpu_local_size,
49 default_max_gpu_local_size);
52inline size_t get_max_local_size(
const sycl::device &device,
53 int cpu_local_size_limit,
54 int gpu_local_size_limit)
56 int max_work_group_size =
57 device.get_info<sycl::info::device::max_work_group_size>();
58 if (device.is_cpu() && cpu_local_size_limit > 0) {
59 return std::min(cpu_local_size_limit, max_work_group_size);
61 else if (device.is_gpu() && gpu_local_size_limit > 0) {
62 return std::min(gpu_local_size_limit, max_work_group_size);
65 return max_work_group_size;
68inline sycl::nd_range<1>
69 make_ndrange(
size_t global_size,
size_t local_range,
size_t work_per_item)
71 return make_ndrange(sycl::range<1>(global_size),
72 sycl::range<1>(local_range),
73 sycl::range<1>(work_per_item));
76inline size_t get_local_mem_size_in_bytes(
const sycl::device &device)
79 constexpr const size_t reserve = 1024;
81 return get_local_mem_size_in_bytes(device, reserve);
84inline size_t get_local_mem_size_in_bytes(
const sycl::device &device,
87 size_t local_mem_size =
88 device.get_info<sycl::info::device::local_mem_size>();
89 return local_mem_size - reserve;
92inline pybind11::dtype dtype_from_typenum(
int dst_typenum)
94 dpctl_td_ns::typenum_t dst_typenum_t =
95 static_cast<dpctl_td_ns::typenum_t
>(dst_typenum);
96 switch (dst_typenum_t) {
97 case dpctl_td_ns::typenum_t::BOOL:
98 return py::dtype(
"?");
99 case dpctl_td_ns::typenum_t::INT8:
100 return py::dtype(
"i1");
101 case dpctl_td_ns::typenum_t::UINT8:
102 return py::dtype(
"u1");
103 case dpctl_td_ns::typenum_t::INT16:
104 return py::dtype(
"i2");
105 case dpctl_td_ns::typenum_t::UINT16:
106 return py::dtype(
"u2");
107 case dpctl_td_ns::typenum_t::INT32:
108 return py::dtype(
"i4");
109 case dpctl_td_ns::typenum_t::UINT32:
110 return py::dtype(
"u4");
111 case dpctl_td_ns::typenum_t::INT64:
112 return py::dtype(
"i8");
113 case dpctl_td_ns::typenum_t::UINT64:
114 return py::dtype(
"u8");
115 case dpctl_td_ns::typenum_t::HALF:
116 return py::dtype(
"f2");
117 case dpctl_td_ns::typenum_t::FLOAT:
118 return py::dtype(
"f4");
119 case dpctl_td_ns::typenum_t::DOUBLE:
120 return py::dtype(
"f8");
121 case dpctl_td_ns::typenum_t::CFLOAT:
122 return py::dtype(
"c8");
123 case dpctl_td_ns::typenum_t::CDOUBLE:
124 return py::dtype(
"c16");
126 throw py::value_error(
"Unrecognized dst_typeid");