29#include <pybind11/numpy.h>
30#include <pybind11/pybind11.h>
31#include <sycl/sycl.hpp>
33#include "utils/math_utils.hpp"
34#include "utils/type_utils.hpp"
36namespace type_utils = dpctl::tensor::type_utils;
/**
 * @brief Divides `n` by `d`, rounding the quotient up.
 *
 * Computed as `(n + d - 1) / d`, so the result type follows the usual
 * arithmetic conversions of `N` and `D`.
 *
 * @param n Dividend.
 * @param d Divisor.
 * @return The quotient rounded up, for non-negative `n` and positive `d`.
 *
 * @note The intermediate `n + d - 1` can overflow when `n` is close to the
 *       maximum of the common type.
 */
template <typename N, typename D>
constexpr auto CeilDiv(N n, D d)
{
    return (n + d - 1) / d;
}
/**
 * @brief Rounds `n` up to a multiple of `d`.
 *
 * Computed as `CeilDiv(n, d) * d`.
 *
 * @param n Value to round up.
 * @param d Alignment step.
 * @return `CeilDiv(n, d) * d`; equal to `n` when `n` is already a multiple
 *         of `d`.
 */
template <typename N, typename D>
constexpr auto Align(N n, D d)
{
    return CeilDiv(n, d) * d;
}
// NOTE(review): this is a fragment. The declaration between the template
// header and `add`, the braces, and part of the body are not visible in this
// excerpt; the leading numbers look like line-number residue from extraction.
//
// Template parameters: the value type `T` plus the SYCL memory order and
// memory scope that are passed on to sycl::atomic_ref below.
53template <
typename T, sycl::memory_order Order, sycl::memory_scope Scope>
// Adds `value` into `lhs`. Presumably done through sycl::atomic_ref; only the
// construction of the atomic reference is visible here -- TODO confirm.
56 static void add(T &lhs,
const T &value)
// Complex branch: both operands are viewed as arrays of two `vT` components.
// How the components are then updated is not visible in this excerpt.
58 if constexpr (type_utils::is_complex_v<T>) {
59 using vT =
typename T::value_type;
60 vT *_lhs =
reinterpret_cast<vT(&)[2]
>(lhs);
61 const vT *_val =
reinterpret_cast<const vT(&)[2]
>(value);
// Visible part of the remaining path: an atomic reference is bound to `lhs`.
67 sycl::atomic_ref<T, Order, Scope> lh(lhs);
// NOTE(review): fragment. The enclosing declaration, the braces, the result
// of the complex branch and any fallback return are not visible here.
//
// NaN predicate for a value of type `T`.
76 static bool isnan(
const T &v)
// Complex branch: extracts the real and imaginary parts as `vT` values.
// How they are combined into the result is not shown in this excerpt.
78 if constexpr (type_utils::is_complex_v<T>) {
79 using vT =
typename T::value_type;
81 const vT real1 = std::real(v);
82 const vT imag1 = std::imag(v);
// Standard floating-point types and sycl::half defer to sycl::isnan.
86 else if constexpr (std::is_floating_point_v<T> ||
87 std::is_same_v<T, sycl::half>) {
88 return sycl::isnan(v);
// NOTE(review): fragment of a const call operator taking two `T` values and
// returning bool. The enclosing type, the braces and the non-complex path are
// not visible in this excerpt.
98 bool operator()(
const T &lhs,
const T &rhs)
const
// Complex branch: the visible tail of the expression calls
// dpctl::tensor::math_utils::less_complex(lhs, rhs). The start of that
// statement is missing here, so the full condition cannot be stated.
100 if constexpr (type_utils::is_complex_v<T>) {
102 dpctl::tensor::math_utils::less_complex(lhs, rhs);
// NOTE(review): fragment. A template header parameterised on a type `T` and a
// bool `hasValueType`, followed by an alias to `T::value_type`. The
// declarations between the two are missing from this excerpt, so they may
// belong to different declarations -- TODO confirm against the full file.
110template <
typename T,
bool hasValueType>
// Member alias exposing T's nested value_type.
122 using type =
typename T::value_type;
131size_t get_max_local_size(
const sycl::device &device);
132size_t get_max_local_size(
const sycl::device &device,
133 int cpu_local_size_limit,
134 int gpu_local_size_limit);
136inline size_t get_max_local_size(
const sycl::queue &queue)
138 return get_max_local_size(queue.get_device());
141inline size_t get_max_local_size(
const sycl::queue &queue,
142 int cpu_local_size_limit,
143 int gpu_local_size_limit)
145 return get_max_local_size(queue.get_device(), cpu_local_size_limit,
146 gpu_local_size_limit);
149size_t get_local_mem_size_in_bytes(
const sycl::device &device);
150size_t get_local_mem_size_in_bytes(
const sycl::device &device,
size_t reserve);
152inline size_t get_local_mem_size_in_bytes(
const sycl::queue &queue)
154 return get_local_mem_size_in_bytes(queue.get_device());
157inline size_t get_local_mem_size_in_bytes(
const sycl::queue &queue,
160 return get_local_mem_size_in_bytes(queue.get_device(), reserve);
164size_t get_local_mem_size_in_items(
const sycl::device &device)
166 return get_local_mem_size_in_bytes(device) /
sizeof(T);
170size_t get_local_mem_size_in_items(
const sycl::device &device,
size_t reserve)
172 return get_local_mem_size_in_bytes(device,
sizeof(T) * reserve) /
sizeof(T);
176inline size_t get_local_mem_size_in_items(
const sycl::queue &queue)
178 return get_local_mem_size_in_items<T>(queue.get_device());
182inline size_t get_local_mem_size_in_items(
const sycl::queue &queue,
185 return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
189sycl::nd_range<Dims> make_ndrange(
const sycl::range<Dims> &global_range,
190 const sycl::range<Dims> &local_range,
191 const sycl::range<Dims> &work_per_item)
193 sycl::range<Dims> aligned_global_range;
195 for (
int i = 0; i < Dims; ++i) {
196 aligned_global_range[i] =
197 Align(CeilDiv(global_range[i], work_per_item[i]), local_range[i]);
200 return sycl::nd_range<Dims>(aligned_global_range, local_range);
// Declaration of a make_ndrange overload taking scalar sizes instead of
// sycl::range objects; the definition is not part of this excerpt.
// NOTE(review): the line carrying the return type is missing from this
// excerpt -- restore it from the full file rather than guessing.
204 make_ndrange(
size_t global_size,
size_t local_range,
size_t work_per_item);
208pybind11::dtype dtype_from_typenum(
int dst_typenum);
212#include "ext/details/common_internal.hpp"