DPNP C++ backend kernel library 0.20.0dev4
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <stdexcept>
32
33#include <oneapi/mkl.hpp>
34#include <pybind11/pybind11.h>
35#include <sycl/sycl.hpp>
36
37namespace dpnp::extensions::fft
38{
39namespace mkl_dft = oneapi::mkl::dft;
40namespace py = pybind11;
41
42template <mkl_dft::precision prec, mkl_dft::domain dom>
44{
45public:
46 using descr_type = mkl_dft::descriptor<prec, dom>;
47
48 DescriptorWrapper(std::int64_t n) : descr_(n), queue_ptr_{} {}
49 DescriptorWrapper(std::vector<std::int64_t> dimensions)
50 : descr_(dimensions), queue_ptr_{}
51 {
52 }
54
55 void commit(sycl::queue &q)
56 {
57 mkl_dft::precision fft_prec = get_precision();
58 if (fft_prec == mkl_dft::precision::DOUBLE &&
59 !q.get_device().has(sycl::aspect::fp64)) {
60 throw py::value_error("Descriptor is double precision but the "
61 "device does not support double precision.");
62 }
63
64 descr_.commit(q);
65 queue_ptr_ = std::make_unique<sycl::queue>(q);
66 }
67
68 descr_type &get_descriptor() { return descr_; }
69
70 const sycl::queue &get_queue() const
71 {
72 if (queue_ptr_) {
73 return *queue_ptr_;
74 }
75 else {
76 throw std::runtime_error(
77 "Attempt to get queue when it is not yet set");
78 }
79 }
80
81 // config_param::DIMENSION
82 template <typename valT = std::int64_t>
83 const valT get_dim()
84 {
85 valT dim = -1;
86 descr_.get_value(mkl_dft::config_param::DIMENSION, &dim);
87
88 return dim;
89 }
90
91 // config_param::NUMBER_OF_TRANSFORMS
92 template <typename valT = std::int64_t>
93 const valT get_number_of_transforms()
94 {
95 valT transforms_count{};
96
97 descr_.get_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS,
98 &transforms_count);
99 return transforms_count;
100 }
101
102 template <typename valT = std::int64_t>
103 void set_number_of_transforms(const valT &num)
104 {
105 descr_.set_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS, num);
106 }
107
108 // config_param::FWD_STRIDES
109 template <typename valT = std::vector<std::int64_t>>
110 const valT get_fwd_strides()
111 {
112 const typename valT::value_type dim = get_dim();
113
114 valT fwd_strides(dim + 1);
115#if INTEL_MKL_VERSION >= 20250000
116 descr_.get_value(mkl_dft::config_param::FWD_STRIDES, &fwd_strides);
117#else
118 descr_.get_value(mkl_dft::config_param::FWD_STRIDES,
119 fwd_strides.data());
120#endif // INTEL_MKL_VERSION
121 return fwd_strides;
122 }
123
124 template <typename valT = std::vector<std::int64_t>>
125 void set_fwd_strides(const valT &strides)
126 {
127 const typename valT::value_type dim = get_dim();
128
129 if (static_cast<size_t>(dim + 1) != strides.size()) {
130 throw py::value_error(
131 "Strides length does not match descriptor's dimension");
132 }
133#if INTEL_MKL_VERSION >= 20250000
134 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides);
135#else
136 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides.data());
137#endif // INTEL_MKL_VERSION
138 }
139
140 // config_param::BWD_STRIDES
141 template <typename valT = std::vector<std::int64_t>>
142 const valT get_bwd_strides()
143 {
144 const typename valT::value_type dim = get_dim();
145
146 valT bwd_strides(dim + 1);
147#if INTEL_MKL_VERSION >= 20250000
148 descr_.get_value(mkl_dft::config_param::BWD_STRIDES, &bwd_strides);
149#else
150 descr_.get_value(mkl_dft::config_param::BWD_STRIDES,
151 bwd_strides.data());
152#endif // INTEL_MKL_VERSION
153 return bwd_strides;
154 }
155
156 template <typename valT = std::vector<std::int64_t>>
157 void set_bwd_strides(const valT &strides)
158 {
159 const typename valT::value_type dim = get_dim();
160
161 if (static_cast<size_t>(dim + 1) != strides.size()) {
162 throw py::value_error(
163 "Strides length does not match descriptor's dimension");
164 }
165#if INTEL_MKL_VERSION >= 20250000
166 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides);
167#else
168 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides.data());
169#endif // INTEL_MKL_VERSION
170 }
171
172 // config_param::FWD_DISTANCE
173 template <typename valT = std::int64_t>
174 const valT get_fwd_distance()
175 {
176 valT dist = 0;
177
178 descr_.get_value(mkl_dft::config_param::FWD_DISTANCE, &dist);
179 return dist;
180 }
181
182 template <typename valT = std::int64_t>
183 void set_fwd_distance(const valT &dist)
184 {
185 descr_.set_value(mkl_dft::config_param::FWD_DISTANCE, dist);
186 }
187
188 // config_param::BWD_DISTANCE
189 template <typename valT = std::int64_t>
190 const valT get_bwd_distance()
191 {
192 valT dist = 0;
193
194 descr_.get_value(mkl_dft::config_param::BWD_DISTANCE, &dist);
195 return dist;
196 }
197
198 template <typename valT = std::int64_t>
199 void set_bwd_distance(const valT &dist)
200 {
201 descr_.set_value(mkl_dft::config_param::BWD_DISTANCE, dist);
202 }
203
204 // config_param::PLACEMENT
205 bool get_in_place()
206 {
207#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
208 mkl_dft::config_value placement;
209 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
210 return (placement == mkl_dft::config_value::INPLACE);
211#else
212 // TODO: remove branch when MKLD-10506 is implemented
213 DFTI_CONFIG_VALUE placement;
214 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
215 return (placement == DFTI_CONFIG_VALUE::DFTI_INPLACE);
216#endif // USE_ONEMATH or INTEL_MKL_VERSION
217 }
218
219 void set_in_place(const bool &in_place_request)
220 {
221#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
222 descr_.set_value(mkl_dft::config_param::PLACEMENT,
223 (in_place_request)
224 ? mkl_dft::config_value::INPLACE
225 : mkl_dft::config_value::NOT_INPLACE);
226#else
227 // TODO: remove branch when MKLD-10506 is implemented
228 descr_.set_value(mkl_dft::config_param::PLACEMENT,
229 (in_place_request)
230 ? DFTI_CONFIG_VALUE::DFTI_INPLACE
231 : DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE);
232#endif // USE_ONEMATH or INTEL_MKL_VERSION
233 }
234
235 // config_param::PRECISION
236 mkl_dft::precision get_precision()
237 {
238 mkl_dft::precision fft_prec;
239
240 descr_.get_value(mkl_dft::config_param::PRECISION, &fft_prec);
241 return fft_prec;
242 }
243
244 // config_param::COMMIT_STATUS
245 bool is_committed()
246 {
247#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
248 mkl_dft::config_value committed;
249 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
250 return (committed == mkl_dft::config_value::COMMITTED);
251#else
252 // TODO: remove branch when MKLD-10506 is implemented
253 DFTI_CONFIG_VALUE committed;
254 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
255 return (committed == DFTI_CONFIG_VALUE::DFTI_COMMITTED);
256#endif // USE_ONEMATH or INTEL_MKL_VERSION
257 }
258
259private:
260 mkl_dft::descriptor<prec, dom> descr_;
261 std::unique_ptr<sycl::queue> queue_ptr_;
262};
263
264} // namespace dpnp::extensions::fft