33#include <pybind11/numpy.h>
34#include <pybind11/pybind11.h>
36#include <sycl/sycl.hpp>
39#include "utils/math_utils.hpp"
40#include "utils/type_dispatch.hpp"
41#include "utils/type_utils.hpp"
43namespace type_utils = dpctl::tensor::type_utils;
44namespace type_dispatch = dpctl::tensor::type_dispatch;
49template <
typename N,
typename D>
50constexpr auto CeilDiv(N n, D d)
52 return (n + d - 1) / d;
55template <
typename N,
typename D>
56constexpr auto Align(N n, D d)
58 return CeilDiv(n, d) * d;
61template <
typename T, sycl::memory_order Order, sycl::memory_scope Scope>
64 static void add(T &lhs,
const T &value)
66 if constexpr (type_utils::is_complex_v<T>) {
67 using vT =
typename T::value_type;
68 vT *_lhs =
reinterpret_cast<vT(&)[2]
>(lhs);
69 const vT *_val =
reinterpret_cast<const vT(&)[2]
>(value);
75 sycl::atomic_ref<T, Order, Scope> lh(lhs);
84 bool operator()(
const T &lhs,
const T &rhs)
const
86 if constexpr (type_utils::is_complex_v<T>) {
87 return dpctl::tensor::math_utils::less_complex(lhs, rhs);
90 return std::less{}(lhs, rhs);
98 static bool isnan(
const T &v)
100 if constexpr (type_utils::is_complex_v<T>) {
101 using vT =
typename T::value_type;
103 const vT real1 = std::real(v);
104 const vT imag1 = std::imag(v);
108 else if constexpr (std::is_floating_point_v<T> ||
109 std::is_same_v<T, sycl::half>) {
110 return sycl::isnan(v);
117template <
typename T,
bool hasValueType>
129 using type =
typename T::value_type;
138size_t get_max_local_size(
const sycl::device &device);
139size_t get_max_local_size(
const sycl::device &device,
140 int cpu_local_size_limit,
141 int gpu_local_size_limit);
143inline size_t get_max_local_size(
const sycl::queue &queue)
145 return get_max_local_size(queue.get_device());
148inline size_t get_max_local_size(
const sycl::queue &queue,
149 int cpu_local_size_limit,
150 int gpu_local_size_limit)
152 return get_max_local_size(queue.get_device(), cpu_local_size_limit,
153 gpu_local_size_limit);
156size_t get_local_mem_size_in_bytes(
const sycl::device &device);
157size_t get_local_mem_size_in_bytes(
const sycl::device &device,
size_t reserve);
159inline size_t get_local_mem_size_in_bytes(
const sycl::queue &queue)
161 return get_local_mem_size_in_bytes(queue.get_device());
164inline size_t get_local_mem_size_in_bytes(
const sycl::queue &queue,
167 return get_local_mem_size_in_bytes(queue.get_device(), reserve);
171size_t get_local_mem_size_in_items(
const sycl::device &device)
173 return get_local_mem_size_in_bytes(device) /
sizeof(T);
177size_t get_local_mem_size_in_items(
const sycl::device &device,
size_t reserve)
179 return get_local_mem_size_in_bytes(device,
sizeof(T) * reserve) /
sizeof(T);
183inline size_t get_local_mem_size_in_items(
const sycl::queue &queue)
185 return get_local_mem_size_in_items<T>(queue.get_device());
189inline size_t get_local_mem_size_in_items(
const sycl::queue &queue,
192 return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
196sycl::nd_range<Dims> make_ndrange(
const sycl::range<Dims> &global_range,
197 const sycl::range<Dims> &local_range,
198 const sycl::range<Dims> &work_per_item)
200 sycl::range<Dims> aligned_global_range;
202 for (
int i = 0; i < Dims; ++i) {
203 aligned_global_range[i] =
204 Align(CeilDiv(global_range[i], work_per_item[i]), local_range[i]);
207 return sycl::nd_range<Dims>(aligned_global_range, local_range);
211 make_ndrange(
size_t global_size,
size_t local_range,
size_t work_per_item);
215pybind11::dtype dtype_from_typenum(
int dst_typenum);
217template <
typename dispatchT,
218 template <
typename fnT,
typename T>
220 int _num_types = type_dispatch::num_types>
221inline void init_dispatch_vector(dispatchT dispatch_vector[])
223 type_dispatch::DispatchVectorBuilder<dispatchT, factoryT, _num_types> dvb;
224 dvb.populate_dispatch_vector(dispatch_vector);
227template <
typename dispatchT,
228 template <
typename fnT,
typename D,
typename S>
230 int _num_types = type_dispatch::num_types>
231inline void init_dispatch_table(dispatchT dispatch_table[][_num_types])
233 type_dispatch::DispatchTableBuilder<dispatchT, factoryT, _num_types> dtb;
234 dtb.populate_dispatch_table(dispatch_table);
238#include "ext/details/common_internal.hpp"