DPNP C++ backend kernel library 0.19.0dev6
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28#include <complex>
29#include <pybind11/numpy.h>
30#include <pybind11/pybind11.h>
31#include <sycl/sycl.hpp>
32
33// dpctl tensor headers
34#include "utils/math_utils.hpp"
35#include "utils/type_dispatch.hpp"
36#include "utils/type_utils.hpp"
37
38namespace type_utils = dpctl::tensor::type_utils;
39namespace type_dispatch = dpctl::tensor::type_dispatch;
40
41namespace ext::common
42{
43
// Integer division that rounds the quotient up instead of truncating.
//
// The numerator is biased by (d - 1) before the division, so for a
// non-negative `n` and a positive `d` the result is the smallest integer `q`
// with `q * d >= n`. The result type is whatever `(n + d - 1) / d` yields
// after the usual arithmetic conversions.
template <typename N, typename D>
constexpr auto CeilDiv(N n, D d)
{
    const auto biased_numerator = n + d - 1;
    return biased_numerator / d;
}
49
// Rounds `n` up to a multiple of `d`.
//
// Computes how many chunks of size `d` are needed to cover `n` (via CeilDiv)
// and scales that count back by `d`.
template <typename N, typename D>
constexpr auto Align(N n, D d)
{
    const auto chunk_count = CeilDiv(n, d);
    return chunk_count * d;
}
55
56template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
58{
59 static void add(T &lhs, const T &value)
60 {
61 if constexpr (type_utils::is_complex_v<T>) {
62 using vT = typename T::value_type;
63 vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
64 const vT *_val = reinterpret_cast<const vT(&)[2]>(value);
65
66 AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
67 AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
68 }
69 else {
70 sycl::atomic_ref<T, Order, Scope> lh(lhs);
71 lh += value;
72 }
73 }
74};
75
76template <typename T>
77struct Less
78{
79 bool operator()(const T &lhs, const T &rhs) const
80 {
81 if constexpr (type_utils::is_complex_v<T>) {
82 return dpctl::tensor::math_utils::less_complex(lhs, rhs);
83 }
84 else {
85 return std::less{}(lhs, rhs);
86 }
87 }
88};
89
90template <typename T>
91struct IsNan
92{
93 static bool isnan(const T &v)
94 {
95 if constexpr (type_utils::is_complex_v<T>) {
96 using vT = typename T::value_type;
97
98 const vT real1 = std::real(v);
99 const vT imag1 = std::imag(v);
100
101 return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
102 }
103 else if constexpr (std::is_floating_point_v<T> ||
104 std::is_same_v<T, sycl::half>) {
105 return sycl::isnan(v);
106 }
107
108 return false;
109 }
110};
111
// Selects the "underlying value type" of T, driven by a compile-time flag.
//
// The primary template is only declared; the two partial specializations
// below cover both values of `hasValueType`:
//  - false: `type` is T itself;
//  - true:  `type` is the nested `T::value_type`.
template <typename T, bool hasValueType>
struct value_type_of_impl;

// T does not expose a usable `value_type`: the result is T unchanged.
template <typename T>
struct value_type_of_impl<T, false>
{
    using type = T;
};

// T exposes `value_type`: the result is that nested type.
template <typename T>
struct value_type_of_impl<T, true>
{
    using type = typename T::value_type;
};
126
127template <typename T>
129
// Shorthand for the nested `type` member of `value_type_of<T>`.
130template <typename T>
131using value_type_of_t = typename value_type_of<T>::type;
132
133size_t get_max_local_size(const sycl::device &device);
134size_t get_max_local_size(const sycl::device &device,
135 int cpu_local_size_limit,
136 int gpu_local_size_limit);
137
138inline size_t get_max_local_size(const sycl::queue &queue)
139{
140 return get_max_local_size(queue.get_device());
141}
142
143inline size_t get_max_local_size(const sycl::queue &queue,
144 int cpu_local_size_limit,
145 int gpu_local_size_limit)
146{
147 return get_max_local_size(queue.get_device(), cpu_local_size_limit,
148 gpu_local_size_limit);
149}
150
151size_t get_local_mem_size_in_bytes(const sycl::device &device);
152size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve);
153
154inline size_t get_local_mem_size_in_bytes(const sycl::queue &queue)
155{
156 return get_local_mem_size_in_bytes(queue.get_device());
157}
158
159inline size_t get_local_mem_size_in_bytes(const sycl::queue &queue,
160 size_t reserve)
161{
162 return get_local_mem_size_in_bytes(queue.get_device(), reserve);
163}
164
165template <typename T>
166size_t get_local_mem_size_in_items(const sycl::device &device)
167{
168 return get_local_mem_size_in_bytes(device) / sizeof(T);
169}
170
171template <typename T>
172size_t get_local_mem_size_in_items(const sycl::device &device, size_t reserve)
173{
174 return get_local_mem_size_in_bytes(device, sizeof(T) * reserve) / sizeof(T);
175}
176
177template <typename T>
178inline size_t get_local_mem_size_in_items(const sycl::queue &queue)
179{
180 return get_local_mem_size_in_items<T>(queue.get_device());
181}
182
183template <typename T>
184inline size_t get_local_mem_size_in_items(const sycl::queue &queue,
185 size_t reserve)
186{
187 return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
188}
189
190template <int Dims>
191sycl::nd_range<Dims> make_ndrange(const sycl::range<Dims> &global_range,
192 const sycl::range<Dims> &local_range,
193 const sycl::range<Dims> &work_per_item)
194{
195 sycl::range<Dims> aligned_global_range;
196
197 for (int i = 0; i < Dims; ++i) {
198 aligned_global_range[i] =
199 Align(CeilDiv(global_range[i], work_per_item[i]), local_range[i]);
200 }
201
202 return sycl::nd_range<Dims>(aligned_global_range, local_range);
203}
204
205sycl::nd_range<1>
206 make_ndrange(size_t global_size, size_t local_range, size_t work_per_item);
207
208// This function is a copy from dpctl because it is not available in the public
209// headers of dpctl.
// Declaration only: takes an integer type number and returns a pybind11::dtype.
// NOTE(review): presumably maps a dpctl type number to the matching NumPy
// dtype - confirm against the definition.
210pybind11::dtype dtype_from_typenum(int dst_typenum);
211
212template <typename dispatchT,
213 template <typename fnT, typename T>
214 typename factoryT,
215 int _num_types = type_dispatch::num_types>
216inline void init_dispatch_vector(dispatchT dispatch_vector[])
217{
218 type_dispatch::DispatchVectorBuilder<dispatchT, factoryT, _num_types> dvb;
219 dvb.populate_dispatch_vector(dispatch_vector);
220}
221
222template <typename dispatchT,
223 template <typename fnT, typename D, typename S>
224 typename factoryT,
225 int _num_types = type_dispatch::num_types>
226inline void init_dispatch_table(dispatchT dispatch_table[][_num_types])
227{
228 type_dispatch::DispatchTableBuilder<dispatchT, factoryT, _num_types> dtb;
229 dtb.populate_dispatch_table(dispatch_table);
230}
231} // namespace ext::common
232
233#include "ext/details/common_internal.hpp"