DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <complex>
32#include <pybind11/numpy.h>
33#include <pybind11/pybind11.h>
34#include <sycl/sycl.hpp>
35
36// dpctl tensor headers
37#include "utils/math_utils.hpp"
38#include "utils/type_dispatch.hpp"
39#include "utils/type_utils.hpp"
40
41namespace type_utils = dpctl::tensor::type_utils;
42namespace type_dispatch = dpctl::tensor::type_dispatch;
43
44namespace ext::common
45{
46
47template <typename N, typename D>
48constexpr auto CeilDiv(N n, D d)
49{
50 return (n + d - 1) / d;
51}
52
53template <typename N, typename D>
54constexpr auto Align(N n, D d)
55{
56 return CeilDiv(n, d) * d;
57}
58
59template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
61{
62 static void add(T &lhs, const T &value)
63 {
64 if constexpr (type_utils::is_complex_v<T>) {
65 using vT = typename T::value_type;
66 vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
67 const vT *_val = reinterpret_cast<const vT(&)[2]>(value);
68
69 AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
70 AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
71 }
72 else {
73 sycl::atomic_ref<T, Order, Scope> lh(lhs);
74 lh += value;
75 }
76 }
77};
78
79template <typename T>
80struct Less
81{
82 bool operator()(const T &lhs, const T &rhs) const
83 {
84 if constexpr (type_utils::is_complex_v<T>) {
85 return dpctl::tensor::math_utils::less_complex(lhs, rhs);
86 }
87 else {
88 return std::less{}(lhs, rhs);
89 }
90 }
91};
92
93template <typename T>
94struct IsNan
95{
96 static bool isnan(const T &v)
97 {
98 if constexpr (type_utils::is_complex_v<T>) {
99 using vT = typename T::value_type;
100
101 const vT real1 = std::real(v);
102 const vT imag1 = std::imag(v);
103
104 return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
105 }
106 else if constexpr (std::is_floating_point_v<T> ||
107 std::is_same_v<T, sycl::half>) {
108 return sycl::isnan(v);
109 }
110
111 return false;
112 }
113};
114
115template <typename T, bool hasValueType>
117
118template <typename T>
119struct value_type_of_impl<T, false>
120{
121 using type = T;
122};
123
124template <typename T>
125struct value_type_of_impl<T, true>
126{
127 using type = typename T::value_type;
128};
129
130template <typename T>
132
133template <typename T>
134using value_type_of_t = typename value_type_of<T>::type;
135
136size_t get_max_local_size(const sycl::device &device);
137size_t get_max_local_size(const sycl::device &device,
138 int cpu_local_size_limit,
139 int gpu_local_size_limit);
140
141inline size_t get_max_local_size(const sycl::queue &queue)
142{
143 return get_max_local_size(queue.get_device());
144}
145
146inline size_t get_max_local_size(const sycl::queue &queue,
147 int cpu_local_size_limit,
148 int gpu_local_size_limit)
149{
150 return get_max_local_size(queue.get_device(), cpu_local_size_limit,
151 gpu_local_size_limit);
152}
153
154size_t get_local_mem_size_in_bytes(const sycl::device &device);
155size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve);
156
157inline size_t get_local_mem_size_in_bytes(const sycl::queue &queue)
158{
159 return get_local_mem_size_in_bytes(queue.get_device());
160}
161
162inline size_t get_local_mem_size_in_bytes(const sycl::queue &queue,
163 size_t reserve)
164{
165 return get_local_mem_size_in_bytes(queue.get_device(), reserve);
166}
167
168template <typename T>
169size_t get_local_mem_size_in_items(const sycl::device &device)
170{
171 return get_local_mem_size_in_bytes(device) / sizeof(T);
172}
173
174template <typename T>
175size_t get_local_mem_size_in_items(const sycl::device &device, size_t reserve)
176{
177 return get_local_mem_size_in_bytes(device, sizeof(T) * reserve) / sizeof(T);
178}
179
180template <typename T>
181inline size_t get_local_mem_size_in_items(const sycl::queue &queue)
182{
183 return get_local_mem_size_in_items<T>(queue.get_device());
184}
185
186template <typename T>
187inline size_t get_local_mem_size_in_items(const sycl::queue &queue,
188 size_t reserve)
189{
190 return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
191}
192
193template <int Dims>
194sycl::nd_range<Dims> make_ndrange(const sycl::range<Dims> &global_range,
195 const sycl::range<Dims> &local_range,
196 const sycl::range<Dims> &work_per_item)
197{
198 sycl::range<Dims> aligned_global_range;
199
200 for (int i = 0; i < Dims; ++i) {
201 aligned_global_range[i] =
202 Align(CeilDiv(global_range[i], work_per_item[i]), local_range[i]);
203 }
204
205 return sycl::nd_range<Dims>(aligned_global_range, local_range);
206}
207
208sycl::nd_range<1>
209 make_ndrange(size_t global_size, size_t local_range, size_t work_per_item);
210
211// This function is a copy from dpctl because it is not available in the public
212// headers of dpctl.
213pybind11::dtype dtype_from_typenum(int dst_typenum);
214
215template <typename dispatchT,
216 template <typename fnT, typename T>
217 typename factoryT,
218 int _num_types = type_dispatch::num_types>
219inline void init_dispatch_vector(dispatchT dispatch_vector[])
220{
221 type_dispatch::DispatchVectorBuilder<dispatchT, factoryT, _num_types> dvb;
222 dvb.populate_dispatch_vector(dispatch_vector);
223}
224
225template <typename dispatchT,
226 template <typename fnT, typename D, typename S>
227 typename factoryT,
228 int _num_types = type_dispatch::num_types>
229inline void init_dispatch_table(dispatchT dispatch_table[][_num_types])
230{
231 type_dispatch::DispatchTableBuilder<dispatchT, factoryT, _num_types> dtb;
232 dtb.populate_dispatch_table(dispatch_table);
233}
234} // namespace ext::common
235
236#include "ext/details/common_internal.hpp"