DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
common.hpp
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
//   may be used to endorse or promote products derived from this software
//   without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <type_traits>

#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>

#include <dpctl4pybind11.hpp>
#include <pybind11/pybind11.h>

// utils extension header
#include "ext/common.hpp"

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/type_dispatch.hpp"

#ifndef __INTEL_MKL_2023_2_0_VERSION_REQUIRED
#define __INTEL_MKL_2023_2_0_VERSION_REQUIRED 20230002L
#endif

static_assert(INTEL_MKL_VERSION >= __INTEL_MKL_2023_2_0_VERSION_REQUIRED,
              "OneMKL does not meet minimum version requirement");

namespace ext_ns = ext::common;
namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

namespace dpnp::extensions::vm::py_internal
{
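/**
 * @brief Checks whether the OneMKL VM implementation can be used for a unary
 * element-wise function.
 *
 * Returns true only when the destination type matches the mapped output type,
 * the device supports fp64, the execution queue is compatible with both
 * arrays, src and dst have identical non-zero shapes, dst is large enough and
 * does not overlap src, both arrays are C-contiguous, and a contiguous
 * implementation exists for the source type.
 */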
template <typename output_typesT, typename contig_dispatchT>
bool need_to_call_unary_ufunc(sycl::queue &exec_q,
                              const dpctl::tensor::usm_ndarray &src,
                              const dpctl::tensor::usm_ndarray &dst,
                              const output_typesT &output_type_vec,
                              const contig_dispatchT &contig_dispatch_vector)
{
    // check type_nums
    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    // check that types are supported
    int func_output_typeid = output_type_vec[src_typeid];
    if (dst_typeid != func_output_typeid) {
        return false;
    }

    // OneMKL VM functions perform a copy on host if no double type support
    if (!exec_q.get_device().has(sycl::aspect::fp64)) {
        return false;
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
        return false;
    }

    // dimensions must be the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src.get_ndim()) {
        return false;
    }
    else if (dst_nd == 0) {
        // don't call OneMKL for 0d arrays
        return false;
    }

    // shapes must be the same
    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        return false;
    }

    // if nelems is zero, return false
    if (src_nelems == 0) {
        return false;
    }

    // destination must be ample enough to accommodate all elements
    auto dst_offsets = dst.get_minmax_offsets();
    {
        size_t range =
            static_cast<size_t>(dst_offsets.second - dst_offsets.first);
        if (range + 1 < src_nelems) {
            return false;
        }
    }

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src, dst)) {
        return false;
    }

    // support only contiguous inputs
    bool is_src_c_contig = src.is_c_contiguous();
    bool is_dst_c_contig = dst.is_c_contiguous();

    bool all_c_contig = (is_src_c_contig && is_dst_c_contig);
    if (!all_c_contig) {
        return false;
    }

    // MKL function is not defined for the type
    if (contig_dispatch_vector[src_typeid] == nullptr) {
        return false;
    }
    return true;
}

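/**
 * @brief Checks whether the OneMKL VM implementation can be used for a binary
 * element-wise function.
 *
 * Applies the same checks as the unary variant to src1, src2 and dst, and
 * additionally requires both inputs to have the same type.
 */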
template <typename output_typesT, typename contig_dispatchT>
bool need_to_call_binary_ufunc(sycl::queue &exec_q,
                               const dpctl::tensor::usm_ndarray &src1,
                               const dpctl::tensor::usm_ndarray &src2,
                               const dpctl::tensor::usm_ndarray &dst,
                               const output_typesT &output_type_table,
                               const contig_dispatchT &contig_dispatch_table)
{
    // check type_nums
    int src1_typenum = src1.get_typenum();
    int src2_typenum = src2.get_typenum();
    int dst_typenum = dst.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
    int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

    // check that types are supported
    int output_typeid = output_type_table[src1_typeid][src2_typeid];
    if (output_typeid != dst_typeid) {
        return false;
    }

    // types must be the same
    if (src1_typeid != src2_typeid) {
        return false;
    }

    // OneMKL VM functions perform a copy on host if no double type support
    if (!exec_q.get_device().has(sycl::aspect::fp64)) {
        return false;
    }

    // check that queues are compatible
    if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) {
        return false;
    }

    // dimensions must be the same
    int dst_nd = dst.get_ndim();
    if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) {
        return false;
    }
    else if (dst_nd == 0) {
        // don't call OneMKL for 0d arrays
        return false;
    }

    // shapes must be the same
    const py::ssize_t *src1_shape = src1.get_shape_raw();
    const py::ssize_t *src2_shape = src2.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    bool shapes_equal(true);
    size_t src_nelems(1);

    for (int i = 0; i < dst_nd; ++i) {
        src_nelems *= static_cast<size_t>(src1_shape[i]);
        shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                        src2_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        return false;
    }

    // if nelems is zero, return false
    if (src_nelems == 0) {
        return false;
    }

    // destination must be ample enough to accommodate all elements
    auto dst_offsets = dst.get_minmax_offsets();
    {
        size_t range =
            static_cast<size_t>(dst_offsets.second - dst_offsets.first);
        if (range + 1 < src_nelems) {
            return false;
        }
    }

    // check memory overlap
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src1, dst) || overlap(src2, dst)) {
        return false;
    }

    // support only contiguous inputs
    bool is_src1_c_contig = src1.is_c_contiguous();
    bool is_src2_c_contig = src2.is_c_contiguous();
    bool is_dst_c_contig = dst.is_c_contiguous();

    bool all_c_contig =
        (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig);
    if (!all_c_contig) {
        return false;
    }

    // MKL function is not defined for the type
    if (contig_dispatch_table[src1_typeid] == nullptr) {
        return false;
    }
    return true;
}

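/**
 * @brief Defines the factories used to populate the dispatch data of a unary
 * VM function.
 *
 * ContigFactory maps an element type T to __name__##_contig_impl<T> (or to
 * nullptr when OutputType<T>::value_type is void), TypeMapFactory maps T to
 * the type id of the result type, and populate_dispatch_vectors() fills
 * output_typeid_vector and contig_dispatch_vector accordingly.
 */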
#define MACRO_POPULATE_DISPATCH_VECTORS(__name__) \
    template <typename fnT, typename T> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            if constexpr (std::is_same_v<typename OutputType<T>::value_type, \
                                         void>) { \
                return nullptr; \
            } \
            else { \
                return __name__##_contig_impl<T>; \
            } \
        } \
    }; \
\
    template <typename fnT, typename T> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
        { \
            using rT = typename OutputType<T>::value_type; \
            return td_ns::GetTypeid<rT>{}.get(); \
        } \
    }; \
\
    static void populate_dispatch_vectors(void) \
    { \
        ext_ns::init_dispatch_vector<int, TypeMapFactory>( \
            output_typeid_vector); \
        ext_ns::init_dispatch_vector<unary_contig_impl_fn_ptr_t, \
                                     ContigFactory>(contig_dispatch_vector); \
    };

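/**
 * @brief Binary counterpart of MACRO_POPULATE_DISPATCH_VECTORS.
 *
 * The factories map a pair of element types (T1, T2) to
 * __name__##_contig_impl<T1, T2> and to the type id of the result type, and
 * populate_dispatch_tables() fills output_typeid_vector and
 * contig_dispatch_vector via init_dispatch_table.
 */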
#define MACRO_POPULATE_DISPATCH_TABLES(__name__) \
    template <typename fnT, typename T1, typename T2> \
    struct ContigFactory \
    { \
        fnT get() \
        { \
            if constexpr (std::is_same_v< \
                              typename OutputType<T1, T2>::value_type, void>) \
            { \
                return nullptr; \
            } \
            else { \
                return __name__##_contig_impl<T1, T2>; \
            } \
        } \
    }; \
\
    template <typename fnT, typename T1, typename T2> \
    struct TypeMapFactory \
    { \
        std::enable_if_t<std::is_same<fnT, int>::value, int> get() \
        { \
            using rT = typename OutputType<T1, T2>::value_type; \
            return td_ns::GetTypeid<rT>{}.get(); \
        } \
    }; \
\
    static void populate_dispatch_tables(void) \
    { \
        ext_ns::init_dispatch_table<int, TypeMapFactory>( \
            output_typeid_vector); \
        ext_ns::init_dispatch_table<binary_contig_impl_fn_ptr_t, \
                                    ContigFactory>(contig_dispatch_vector); \
    };
} // namespace dpnp::extensions::vm::py_internal
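
// Illustrative usage (a sketch, not part of this header): an extension entry
// point would typically call need_to_call_unary_ufunc() to decide whether the
// OneMKL VM path applies and otherwise fall back to a generic element-wise
// implementation. The helpers mkl_vm_impl and generic_impl below are
// hypothetical names used only for illustration.
//
//     if (py_internal::need_to_call_unary_ufunc(
//             exec_q, src, dst, output_typeid_vector, contig_dispatch_vector))
//     {
//         return mkl_vm_impl(exec_q, src, dst, depends);
//     }
//     return generic_impl(exec_q, src, dst, depends);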