26#include "ext/common.hpp" 
   27#include "utils/type_dispatch.hpp" 
   28#include <pybind11/pybind11.h> 
   30namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
 
   34inline size_t get_max_local_size(
const sycl::device &device)
 
   36    constexpr const int default_max_cpu_local_size = 256;
 
   37    constexpr const int default_max_gpu_local_size = 0;
 
   39    return get_max_local_size(device, default_max_cpu_local_size,
 
   40                              default_max_gpu_local_size);
 
   43inline size_t get_max_local_size(
const sycl::device &device,
 
   44                                 int cpu_local_size_limit,
 
   45                                 int gpu_local_size_limit)
 
   47    int max_work_group_size =
 
   48        device.get_info<sycl::info::device::max_work_group_size>();
 
   49    if (device.is_cpu() && cpu_local_size_limit > 0) {
 
   50        return std::min(cpu_local_size_limit, max_work_group_size);
 
   52    else if (device.is_gpu() && gpu_local_size_limit > 0) {
 
   53        return std::min(gpu_local_size_limit, max_work_group_size);
 
   56    return max_work_group_size;
 
   59inline sycl::nd_range<1>
 
   60    make_ndrange(
size_t global_size, 
size_t local_range, 
size_t work_per_item)
 
   62    return make_ndrange(sycl::range<1>(global_size),
 
   63                        sycl::range<1>(local_range),
 
   64                        sycl::range<1>(work_per_item));
 
   67inline size_t get_local_mem_size_in_bytes(
const sycl::device &device)
 
   70    constexpr const size_t reserve = 1024;
 
   72    return get_local_mem_size_in_bytes(device, reserve);
 
   75inline size_t get_local_mem_size_in_bytes(
const sycl::device &device,
 
   78    size_t local_mem_size =
 
   79        device.get_info<sycl::info::device::local_mem_size>();
 
   80    return local_mem_size - reserve;
 
   83inline pybind11::dtype dtype_from_typenum(
int dst_typenum)
 
   85    dpctl_td_ns::typenum_t dst_typenum_t =
 
   86        static_cast<dpctl_td_ns::typenum_t
>(dst_typenum);
 
   87    switch (dst_typenum_t) {
 
   88    case dpctl_td_ns::typenum_t::BOOL:
 
   89        return py::dtype(
"?");
 
   90    case dpctl_td_ns::typenum_t::INT8:
 
   91        return py::dtype(
"i1");
 
   92    case dpctl_td_ns::typenum_t::UINT8:
 
   93        return py::dtype(
"u1");
 
   94    case dpctl_td_ns::typenum_t::INT16:
 
   95        return py::dtype(
"i2");
 
   96    case dpctl_td_ns::typenum_t::UINT16:
 
   97        return py::dtype(
"u2");
 
   98    case dpctl_td_ns::typenum_t::INT32:
 
   99        return py::dtype(
"i4");
 
  100    case dpctl_td_ns::typenum_t::UINT32:
 
  101        return py::dtype(
"u4");
 
  102    case dpctl_td_ns::typenum_t::INT64:
 
  103        return py::dtype(
"i8");
 
  104    case dpctl_td_ns::typenum_t::UINT64:
 
  105        return py::dtype(
"u8");
 
  106    case dpctl_td_ns::typenum_t::HALF:
 
  107        return py::dtype(
"f2");
 
  108    case dpctl_td_ns::typenum_t::FLOAT:
 
  109        return py::dtype(
"f4");
 
  110    case dpctl_td_ns::typenum_t::DOUBLE:
 
  111        return py::dtype(
"f8");
 
  112    case dpctl_td_ns::typenum_t::CFLOAT:
 
  113        return py::dtype(
"c8");
 
  114    case dpctl_td_ns::typenum_t::CDOUBLE:
 
  115        return py::dtype(
"c16");
 
  117        throw py::value_error(
"Unrecognized dst_typeid");