29#include <pybind11/numpy.h> 
   30#include <pybind11/pybind11.h> 
   31#include <sycl/sycl.hpp> 
   34#include "utils/math_utils.hpp" 
   35#include "utils/type_dispatch.hpp" 
   36#include "utils/type_utils.hpp" 
   38namespace type_utils = dpctl::tensor::type_utils;
 
   39namespace type_dispatch = dpctl::tensor::type_dispatch;
 
   /**
    * @brief Integer division rounded up.
    *
    * Evaluates `(n + d - 1) / d`. With truncating integer division this is
    * ceil(n / d) when `n` is non-negative and `d` is positive, i.e. the number
    * of chunks of size `d` needed to cover `n` items.
    *
    * @param n Dividend (number of items).
    * @param d Divisor (chunk size); must be non-zero.
    * @return The quotient, in the common type of the arithmetic expression.
    *
    * @note The intermediate `n + d - 1` is computed in that common type, so
    *       callers must keep it within range.
    */
   template <typename N, typename D>
   constexpr auto CeilDiv(N n, D d)
   {
       return (n + d - 1) / d;
   }
 
   /**
    * @brief Rounds `n` up to a multiple of `d`.
    *
    * Computed as `CeilDiv(n, d) * d`, so the same requirements as CeilDiv
    * apply (non-zero `d`; intended for non-negative `n` and positive `d`).
    *
    * @param n Value to round up.
    * @param d Alignment step.
    * @return The smallest multiple of `d` that is not less than `n`, under
    *         the conditions above.
    */
   template <typename N, typename D>
   constexpr auto Align(N n, D d)
   {
       return CeilDiv(n, d) * d;
   }
 
   56template <
typename T, sycl::memory_order Order, sycl::memory_scope Scope>
 
   59    static void add(T &lhs, 
const T &value)
 
   61        if constexpr (type_utils::is_complex_v<T>) {
 
   62            using vT = 
typename T::value_type;
 
   63            vT *_lhs = 
reinterpret_cast<vT(&)[2]
>(lhs);
 
   64            const vT *_val = 
reinterpret_cast<const vT(&)[2]
>(value);
 
   70            sycl::atomic_ref<T, Order, Scope> lh(lhs);
 
 
   79    bool operator()(
const T &lhs, 
const T &rhs)
 const 
   81        if constexpr (type_utils::is_complex_v<T>) {
 
   82            return dpctl::tensor::math_utils::less_complex(lhs, rhs);
 
   85            return std::less{}(lhs, rhs);
 
 
   93    static bool isnan(
const T &v)
 
   95        if constexpr (type_utils::is_complex_v<T>) {
 
   96            using vT = 
typename T::value_type;
 
   98            const vT real1 = std::real(v);
 
   99            const vT imag1 = std::imag(v);
 
  103        else if constexpr (std::is_floating_point_v<T> ||
 
  104                           std::is_same_v<T, sycl::half>) {
 
  105            return sycl::isnan(v);
 
 
  112template <
typename T, 
bool hasValueType>
 
  124    using type = 
typename T::value_type;
 
 
  133size_t get_max_local_size(
const sycl::device &device);
 
  134size_t get_max_local_size(
const sycl::device &device,
 
  135                          int cpu_local_size_limit,
 
  136                          int gpu_local_size_limit);
 
  138inline size_t get_max_local_size(
const sycl::queue &queue)
 
  140    return get_max_local_size(queue.get_device());
 
  143inline size_t get_max_local_size(
const sycl::queue &queue,
 
  144                                 int cpu_local_size_limit,
 
  145                                 int gpu_local_size_limit)
 
  147    return get_max_local_size(queue.get_device(), cpu_local_size_limit,
 
  148                              gpu_local_size_limit);
 
  151size_t get_local_mem_size_in_bytes(
const sycl::device &device);
 
  152size_t get_local_mem_size_in_bytes(
const sycl::device &device, 
size_t reserve);
 
  154inline size_t get_local_mem_size_in_bytes(
const sycl::queue &queue)
 
  156    return get_local_mem_size_in_bytes(queue.get_device());
 
  159inline size_t get_local_mem_size_in_bytes(
const sycl::queue &queue,
 
  162    return get_local_mem_size_in_bytes(queue.get_device(), reserve);
 
  166size_t get_local_mem_size_in_items(
const sycl::device &device)
 
  168    return get_local_mem_size_in_bytes(device) / 
sizeof(T);
 
  172size_t get_local_mem_size_in_items(
const sycl::device &device, 
size_t reserve)
 
  174    return get_local_mem_size_in_bytes(device, 
sizeof(T) * reserve) / 
sizeof(T);
 
  178inline size_t get_local_mem_size_in_items(
const sycl::queue &queue)
 
  180    return get_local_mem_size_in_items<T>(queue.get_device());
 
  184inline size_t get_local_mem_size_in_items(
const sycl::queue &queue,
 
  187    return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
 
  191sycl::nd_range<Dims> make_ndrange(
const sycl::range<Dims> &global_range,
 
  192                                  const sycl::range<Dims> &local_range,
 
  193                                  const sycl::range<Dims> &work_per_item)
 
  195    sycl::range<Dims> aligned_global_range;
 
  197    for (
int i = 0; i < Dims; ++i) {
 
  198        aligned_global_range[i] =
 
  199            Align(CeilDiv(global_range[i], work_per_item[i]), local_range[i]);
 
  202    return sycl::nd_range<Dims>(aligned_global_range, local_range);
 
  206    make_ndrange(
size_t global_size, 
size_t local_range, 
size_t work_per_item);
 
  210pybind11::dtype dtype_from_typenum(
int dst_typenum);
 
  212template <
typename dispatchT,
 
  213          template <
typename fnT, 
typename T>
 
  215          int _num_types = type_dispatch::num_types>
 
  216inline void init_dispatch_vector(dispatchT dispatch_vector[])
 
  218    type_dispatch::DispatchVectorBuilder<dispatchT, factoryT, _num_types> dvb;
 
  219    dvb.populate_dispatch_vector(dispatch_vector);
 
  222template <
typename dispatchT,
 
  223          template <
typename fnT, 
typename D, 
typename S>
 
  225          int _num_types = type_dispatch::num_types>
 
  226inline void init_dispatch_table(dispatchT dispatch_table[][_num_types])
 
  228    type_dispatch::DispatchTableBuilder<dispatchT, factoryT, _num_types> dtb;
 
  229    dtb.populate_dispatch_table(dispatch_table);
 
  233#include "ext/details/common_internal.hpp"