DPNP C++ backend kernel library 0.20.0dev1
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <complex>
32
33#include <pybind11/numpy.h>
34#include <pybind11/pybind11.h>
35
36#include <sycl/sycl.hpp>
37
38// dpctl tensor headers
39#include "utils/math_utils.hpp"
40#include "utils/type_dispatch.hpp"
41#include "utils/type_utils.hpp"
42
43namespace type_utils = dpctl::tensor::type_utils;
44namespace type_dispatch = dpctl::tensor::type_dispatch;
45
46namespace ext::common
47{
48
// Ceiling division for integral operands: returns (n + d - 1) / d, i.e. the
// number of chunks of size d needed to cover n items.
//
// The round-up identity only holds for n >= 0 and d > 0, and the intermediate
// sum n + d - 1 must be representable in the promoted operand type.
// The result type is whatever the arithmetic expression yields (usual
// arithmetic conversions of N, D and int).
template <typename N, typename D>
constexpr auto CeilDiv(N n, D d)
{
    return (n + d - 1) / d;
}
54
// Rounds n up to the nearest multiple of d: CeilDiv(n, d) * d.
// Inherits the preconditions of CeilDiv (meaningful for n >= 0, d > 0).
template <typename N, typename D>
constexpr auto Align(N n, D d)
{
    return CeilDiv(n, d) * d;
}
60
61template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
63{
64 static void add(T &lhs, const T &value)
65 {
66 if constexpr (type_utils::is_complex_v<T>) {
67 using vT = typename T::value_type;
68 vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
69 const vT *_val = reinterpret_cast<const vT(&)[2]>(value);
70
71 AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
72 AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
73 }
74 else {
75 sycl::atomic_ref<T, Order, Scope> lh(lhs);
76 lh += value;
77 }
78 }
79};
80
81template <typename T>
82struct Less
83{
84 bool operator()(const T &lhs, const T &rhs) const
85 {
86 if constexpr (type_utils::is_complex_v<T>) {
87 return dpctl::tensor::math_utils::less_complex(lhs, rhs);
88 }
89 else {
90 return std::less{}(lhs, rhs);
91 }
92 }
93};
94
95template <typename T>
96struct IsNan
97{
98 static bool isnan(const T &v)
99 {
100 if constexpr (type_utils::is_complex_v<T>) {
101 using vT = typename T::value_type;
102
103 const vT real1 = std::real(v);
104 const vT imag1 = std::imag(v);
105
106 return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
107 }
108 else if constexpr (std::is_floating_point_v<T> ||
109 std::is_same_v<T, sycl::half>) {
110 return sycl::isnan(v);
111 }
112
113 return false;
114 }
115};
116
// Selects the "underlying value type" of T, driven by a boolean flag.
// Primary template is declared only; the two partial specializations below
// provide the definitions.
template <typename T, bool hasValueType>
struct value_type_of_impl;

// hasValueType == false: T is its own value type.
template <typename T>
struct value_type_of_impl<T, false>
{
    using type = T;
};

// hasValueType == true: use the nested T::value_type.
template <typename T>
struct value_type_of_impl<T, true>
{
    using type = typename T::value_type;
};
131
132template <typename T>
134
135template <typename T>
136using value_type_of_t = typename value_type_of<T>::type;
137
138size_t get_max_local_size(const sycl::device &device);
139size_t get_max_local_size(const sycl::device &device,
140 int cpu_local_size_limit,
141 int gpu_local_size_limit);
142
143inline size_t get_max_local_size(const sycl::queue &queue)
144{
145 return get_max_local_size(queue.get_device());
146}
147
148inline size_t get_max_local_size(const sycl::queue &queue,
149 int cpu_local_size_limit,
150 int gpu_local_size_limit)
151{
152 return get_max_local_size(queue.get_device(), cpu_local_size_limit,
153 gpu_local_size_limit);
154}
155
156size_t get_local_mem_size_in_bytes(const sycl::device &device);
157size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve);
158
159inline size_t get_local_mem_size_in_bytes(const sycl::queue &queue)
160{
161 return get_local_mem_size_in_bytes(queue.get_device());
162}
163
164inline size_t get_local_mem_size_in_bytes(const sycl::queue &queue,
165 size_t reserve)
166{
167 return get_local_mem_size_in_bytes(queue.get_device(), reserve);
168}
169
170template <typename T>
171size_t get_local_mem_size_in_items(const sycl::device &device)
172{
173 return get_local_mem_size_in_bytes(device) / sizeof(T);
174}
175
176template <typename T>
177size_t get_local_mem_size_in_items(const sycl::device &device, size_t reserve)
178{
179 return get_local_mem_size_in_bytes(device, sizeof(T) * reserve) / sizeof(T);
180}
181
182template <typename T>
183inline size_t get_local_mem_size_in_items(const sycl::queue &queue)
184{
185 return get_local_mem_size_in_items<T>(queue.get_device());
186}
187
188template <typename T>
189inline size_t get_local_mem_size_in_items(const sycl::queue &queue,
190 size_t reserve)
191{
192 return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
193}
194
195template <int Dims>
196sycl::nd_range<Dims> make_ndrange(const sycl::range<Dims> &global_range,
197 const sycl::range<Dims> &local_range,
198 const sycl::range<Dims> &work_per_item)
199{
200 sycl::range<Dims> aligned_global_range;
201
202 for (int i = 0; i < Dims; ++i) {
203 aligned_global_range[i] =
204 Align(CeilDiv(global_range[i], work_per_item[i]), local_range[i]);
205 }
206
207 return sycl::nd_range<Dims>(aligned_global_range, local_range);
208}
209
210sycl::nd_range<1>
211 make_ndrange(size_t global_size, size_t local_range, size_t work_per_item);
212
213// This function is a copy from dpctl because it is not available in the public
214// headers of dpctl.
215pybind11::dtype dtype_from_typenum(int dst_typenum);
216
217template <typename dispatchT,
218 template <typename fnT, typename T>
219 typename factoryT,
220 int _num_types = type_dispatch::num_types>
221inline void init_dispatch_vector(dispatchT dispatch_vector[])
222{
223 type_dispatch::DispatchVectorBuilder<dispatchT, factoryT, _num_types> dvb;
224 dvb.populate_dispatch_vector(dispatch_vector);
225}
226
227template <typename dispatchT,
228 template <typename fnT, typename D, typename S>
229 typename factoryT,
230 int _num_types = type_dispatch::num_types>
231inline void init_dispatch_table(dispatchT dispatch_table[][_num_types])
232{
233 type_dispatch::DispatchTableBuilder<dispatchT, factoryT, _num_types> dtb;
234 dtb.populate_dispatch_table(dispatch_table);
235}
236} // namespace ext::common
237
238#include "ext/details/common_internal.hpp"