DPNP C++ backend kernel library 0.18.0rc1
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28#include <complex>
29#include <pybind11/numpy.h>
30#include <pybind11/pybind11.h>
31#include <sycl/sycl.hpp>
32
33#include "utils/math_utils.hpp"
34#include "utils/type_utils.hpp"
35
36namespace type_utils = dpctl::tensor::type_utils;
37
38namespace ext::common
39{
40
// Divides n by d, rounding the quotient up.
//
// The result is computed as (n + d - 1) / d in the operands' common
// arithmetic type, so it equals the ceiling of n / d for non-negative n and
// positive d. The intermediate sum n + d - 1 must be representable in that
// type.
template <typename N, typename D>
constexpr auto CeilDiv(N n, D d)
{
    // Bias the numerator so that truncating division rounds up.
    const auto biased = n + d - 1;
    return biased / d;
}
46
// Rounds n up to a multiple of d.
//
// Computes CeilDiv(n, d) * d, i.e. the number of d-sized chunks needed to
// cover n, multiplied back by the chunk size.
template <typename N, typename D>
constexpr auto Align(N n, D d)
{
    const auto chunks = CeilDiv(n, d);
    return chunks * d;
}
52
53template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
55{
56 static void add(T &lhs, const T &value)
57 {
58 if constexpr (type_utils::is_complex_v<T>) {
59 using vT = typename T::value_type;
60 vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
61 const vT *_val = reinterpret_cast<const vT(&)[2]>(value);
62
63 AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
64 AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
65 }
66 else {
67 sycl::atomic_ref<T, Order, Scope> lh(lhs);
68 lh += value;
69 }
70 }
71};
72
73template <typename T>
74struct Less
75{
76 bool operator()(const T &lhs, const T &rhs) const
77 {
78 if constexpr (type_utils::is_complex_v<T>) {
79 return dpctl::tensor::math_utils::less_complex(lhs, rhs);
80 }
81 else {
82 return std::less{}(lhs, rhs);
83 }
84 }
85};
86
87template <typename T>
88struct IsNan
89{
90 static bool isnan(const T &v)
91 {
92 if constexpr (type_utils::is_complex_v<T>) {
93 using vT = typename T::value_type;
94
95 const vT real1 = std::real(v);
96 const vT imag1 = std::imag(v);
97
98 return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
99 }
100 else if constexpr (std::is_floating_point_v<T> ||
101 std::is_same_v<T, sycl::half>) {
102 return sycl::isnan(v);
103 }
104
105 return false;
106 }
107};
108
109template <typename T, bool hasValueType>
111
112template <typename T>
113struct value_type_of_impl<T, false>
114{
115 using type = T;
116};
117
118template <typename T>
119struct value_type_of_impl<T, true>
120{
121 using type = typename T::value_type;
122};
123
124template <typename T>
126
127template <typename T>
128using value_type_of_t = typename value_type_of<T>::type;
129
130size_t get_max_local_size(const sycl::device &device);
131size_t get_max_local_size(const sycl::device &device,
132 int cpu_local_size_limit,
133 int gpu_local_size_limit);
134
135inline size_t get_max_local_size(const sycl::queue &queue)
136{
137 return get_max_local_size(queue.get_device());
138}
139
140inline size_t get_max_local_size(const sycl::queue &queue,
141 int cpu_local_size_limit,
142 int gpu_local_size_limit)
143{
144 return get_max_local_size(queue.get_device(), cpu_local_size_limit,
145 gpu_local_size_limit);
146}
147
148size_t get_local_mem_size_in_bytes(const sycl::device &device);
149size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve);
150
151inline size_t get_local_mem_size_in_bytes(const sycl::queue &queue)
152{
153 return get_local_mem_size_in_bytes(queue.get_device());
154}
155
156inline size_t get_local_mem_size_in_bytes(const sycl::queue &queue,
157 size_t reserve)
158{
159 return get_local_mem_size_in_bytes(queue.get_device(), reserve);
160}
161
162template <typename T>
163size_t get_local_mem_size_in_items(const sycl::device &device)
164{
165 return get_local_mem_size_in_bytes(device) / sizeof(T);
166}
167
168template <typename T>
169size_t get_local_mem_size_in_items(const sycl::device &device, size_t reserve)
170{
171 return get_local_mem_size_in_bytes(device, sizeof(T) * reserve) / sizeof(T);
172}
173
174template <typename T>
175inline size_t get_local_mem_size_in_items(const sycl::queue &queue)
176{
177 return get_local_mem_size_in_items<T>(queue.get_device());
178}
179
180template <typename T>
181inline size_t get_local_mem_size_in_items(const sycl::queue &queue,
182 size_t reserve)
183{
184 return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
185}
186
187template <int Dims>
188sycl::nd_range<Dims> make_ndrange(const sycl::range<Dims> &global_range,
189 const sycl::range<Dims> &local_range,
190 const sycl::range<Dims> &work_per_item)
191{
192 sycl::range<Dims> aligned_global_range;
193
194 for (int i = 0; i < Dims; ++i) {
195 aligned_global_range[i] =
196 Align(CeilDiv(global_range[i], work_per_item[i]), local_range[i]);
197 }
198
199 return sycl::nd_range<Dims>(aligned_global_range, local_range);
200}
201
202sycl::nd_range<1>
203 make_ndrange(size_t global_size, size_t local_range, size_t work_per_item);
204
205// This function is a copy from dpctl because it is not available in the public
206// headers of dpctl.
207pybind11::dtype dtype_from_typenum(int dst_typenum);
208
209} // namespace ext::common
210
211#include "ext/details/common_internal.hpp"