DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <stdexcept>
32
33#include <oneapi/mkl.hpp>
34#include <pybind11/pybind11.h>
35#include <sycl/sycl.hpp>
36
37namespace dpnp::extensions::fft
38{
39namespace mkl_dft = oneapi::mkl::dft;
40namespace py = pybind11;
41
42template <mkl_dft::precision prec, mkl_dft::domain dom>
44{
45public:
46 using descr_type = mkl_dft::descriptor<prec, dom>;
47
48 DescriptorWrapper(std::int64_t n) : descr_(n), queue_ptr_{} {}
49 DescriptorWrapper(std::vector<std::int64_t> dimensions)
50 : descr_(dimensions), queue_ptr_{}
51 {
52 }
54
55 void commit(sycl::queue &q)
56 {
57 mkl_dft::precision fft_prec = get_precision();
58 if (fft_prec == mkl_dft::precision::DOUBLE &&
59 !q.get_device().has(sycl::aspect::fp64))
60 {
61 throw py::value_error("Descriptor is double precision but the "
62 "device does not support double precision.");
63 }
64
65 descr_.commit(q);
66 queue_ptr_ = std::make_unique<sycl::queue>(q);
67 }
68
69 descr_type &get_descriptor()
70 {
71 return descr_;
72 }
73
74 const sycl::queue &get_queue() const
75 {
76 if (queue_ptr_) {
77 return *queue_ptr_;
78 }
79 else {
80 throw std::runtime_error(
81 "Attempt to get queue when it is not yet set");
82 }
83 }
84
85 // config_param::DIMENSION
86 template <typename valT = std::int64_t>
87 const valT get_dim()
88 {
89 valT dim = -1;
90 descr_.get_value(mkl_dft::config_param::DIMENSION, &dim);
91
92 return dim;
93 }
94
95 // config_param::NUMBER_OF_TRANSFORMS
96 template <typename valT = std::int64_t>
97 const valT get_number_of_transforms()
98 {
99 valT transforms_count{};
100
101 descr_.get_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS,
102 &transforms_count);
103 return transforms_count;
104 }
105
106 template <typename valT = std::int64_t>
107 void set_number_of_transforms(const valT &num)
108 {
109 descr_.set_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS, num);
110 }
111
112 // config_param::FWD_STRIDES
113 template <typename valT = std::vector<std::int64_t>>
114 const valT get_fwd_strides()
115 {
116 const typename valT::value_type dim = get_dim();
117
118 valT fwd_strides(dim + 1);
119#if INTEL_MKL_VERSION >= 20250000
120 descr_.get_value(mkl_dft::config_param::FWD_STRIDES, &fwd_strides);
121#else
122 descr_.get_value(mkl_dft::config_param::FWD_STRIDES,
123 fwd_strides.data());
124#endif // INTEL_MKL_VERSION
125 return fwd_strides;
126 }
127
128 template <typename valT = std::vector<std::int64_t>>
129 void set_fwd_strides(const valT &strides)
130 {
131 const typename valT::value_type dim = get_dim();
132
133 if (static_cast<size_t>(dim + 1) != strides.size()) {
134 throw py::value_error(
135 "Strides length does not match descriptor's dimension");
136 }
137#if INTEL_MKL_VERSION >= 20250000
138 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides);
139#else
140 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides.data());
141#endif // INTEL_MKL_VERSION
142 }
143
144 // config_param::BWD_STRIDES
145 template <typename valT = std::vector<std::int64_t>>
146 const valT get_bwd_strides()
147 {
148 const typename valT::value_type dim = get_dim();
149
150 valT bwd_strides(dim + 1);
151#if INTEL_MKL_VERSION >= 20250000
152 descr_.get_value(mkl_dft::config_param::BWD_STRIDES, &bwd_strides);
153#else
154 descr_.get_value(mkl_dft::config_param::BWD_STRIDES,
155 bwd_strides.data());
156#endif // INTEL_MKL_VERSION
157 return bwd_strides;
158 }
159
160 template <typename valT = std::vector<std::int64_t>>
161 void set_bwd_strides(const valT &strides)
162 {
163 const typename valT::value_type dim = get_dim();
164
165 if (static_cast<size_t>(dim + 1) != strides.size()) {
166 throw py::value_error(
167 "Strides length does not match descriptor's dimension");
168 }
169#if INTEL_MKL_VERSION >= 20250000
170 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides);
171#else
172 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides.data());
173#endif // INTEL_MKL_VERSION
174 }
175
176 // config_param::FWD_DISTANCE
177 template <typename valT = std::int64_t>
178 const valT get_fwd_distance()
179 {
180 valT dist = 0;
181
182 descr_.get_value(mkl_dft::config_param::FWD_DISTANCE, &dist);
183 return dist;
184 }
185
186 template <typename valT = std::int64_t>
187 void set_fwd_distance(const valT &dist)
188 {
189 descr_.set_value(mkl_dft::config_param::FWD_DISTANCE, dist);
190 }
191
192 // config_param::BWD_DISTANCE
193 template <typename valT = std::int64_t>
194 const valT get_bwd_distance()
195 {
196 valT dist = 0;
197
198 descr_.get_value(mkl_dft::config_param::BWD_DISTANCE, &dist);
199 return dist;
200 }
201
202 template <typename valT = std::int64_t>
203 void set_bwd_distance(const valT &dist)
204 {
205 descr_.set_value(mkl_dft::config_param::BWD_DISTANCE, dist);
206 }
207
208 // config_param::PLACEMENT
209 bool get_in_place()
210 {
211#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
212 mkl_dft::config_value placement;
213 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
214 return (placement == mkl_dft::config_value::INPLACE);
215#else
216 // TODO: remove branch when MKLD-10506 is implemented
217 DFTI_CONFIG_VALUE placement;
218 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
219 return (placement == DFTI_CONFIG_VALUE::DFTI_INPLACE);
220#endif // USE_ONEMATH or INTEL_MKL_VERSION
221 }
222
223 void set_in_place(const bool &in_place_request)
224 {
225#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
226 descr_.set_value(mkl_dft::config_param::PLACEMENT,
227 (in_place_request)
228 ? mkl_dft::config_value::INPLACE
229 : mkl_dft::config_value::NOT_INPLACE);
230#else
231 // TODO: remove branch when MKLD-10506 is implemented
232 descr_.set_value(mkl_dft::config_param::PLACEMENT,
233 (in_place_request)
234 ? DFTI_CONFIG_VALUE::DFTI_INPLACE
235 : DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE);
236#endif // USE_ONEMATH or INTEL_MKL_VERSION
237 }
238
239 // config_param::PRECISION
240 mkl_dft::precision get_precision()
241 {
242 mkl_dft::precision fft_prec;
243
244 descr_.get_value(mkl_dft::config_param::PRECISION, &fft_prec);
245 return fft_prec;
246 }
247
248 // config_param::COMMIT_STATUS
249 bool is_committed()
250 {
251#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
252 mkl_dft::config_value committed;
253 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
254 return (committed == mkl_dft::config_value::COMMITTED);
255#else
256 // TODO: remove branch when MKLD-10506 is implemented
257 DFTI_CONFIG_VALUE committed;
258 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
259 return (committed == DFTI_CONFIG_VALUE::DFTI_COMMITTED);
260#endif // USE_ONEMATH or INTEL_MKL_VERSION
261 }
262
263private:
264 mkl_dft::descriptor<prec, dom> descr_;
265 std::unique_ptr<sycl::queue> queue_ptr_;
266};
267
268} // namespace dpnp::extensions::fft