DPNP C++ backend kernel library 0.18.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28#include <stdexcept>
29
30#include <oneapi/mkl.hpp>
31#include <pybind11/pybind11.h>
32#include <sycl/sycl.hpp>
33
34namespace dpnp::extensions::fft
35{
36namespace mkl_dft = oneapi::mkl::dft;
37namespace py = pybind11;
38
39template <mkl_dft::precision prec, mkl_dft::domain dom>
41{
42public:
43 using descr_type = mkl_dft::descriptor<prec, dom>;
44
45 DescriptorWrapper(std::int64_t n) : descr_(n), queue_ptr_{} {}
46 DescriptorWrapper(std::vector<std::int64_t> dimensions)
47 : descr_(dimensions), queue_ptr_{}
48 {
49 }
51
52 void commit(sycl::queue &q)
53 {
54 mkl_dft::precision fft_prec = get_precision();
55 if (fft_prec == mkl_dft::precision::DOUBLE &&
56 !q.get_device().has(sycl::aspect::fp64))
57 {
58 throw py::value_error("Descriptor is double precision but the "
59 "device does not support double precision.");
60 }
61
62 descr_.commit(q);
63 queue_ptr_ = std::make_unique<sycl::queue>(q);
64 }
65
66 descr_type &get_descriptor()
67 {
68 return descr_;
69 }
70
71 const sycl::queue &get_queue() const
72 {
73 if (queue_ptr_) {
74 return *queue_ptr_;
75 }
76 else {
77 throw std::runtime_error(
78 "Attempt to get queue when it is not yet set");
79 }
80 }
81
82 // config_param::DIMENSION
83 template <typename valT = std::int64_t>
84 const valT get_dim()
85 {
86 valT dim = -1;
87 descr_.get_value(mkl_dft::config_param::DIMENSION, &dim);
88
89 return dim;
90 }
91
92 // config_param::NUMBER_OF_TRANSFORMS
93 template <typename valT = std::int64_t>
94 const valT get_number_of_transforms()
95 {
96 valT transforms_count{};
97
98 descr_.get_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS,
99 &transforms_count);
100 return transforms_count;
101 }
102
103 template <typename valT = std::int64_t>
104 void set_number_of_transforms(const valT &num)
105 {
106 descr_.set_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS, num);
107 }
108
109 // config_param::FWD_STRIDES
110 template <typename valT = std::vector<std::int64_t>>
111 const valT get_fwd_strides()
112 {
113 const typename valT::value_type dim = get_dim();
114
115 valT fwd_strides(dim + 1);
116#if INTEL_MKL_VERSION >= 20250000
117 descr_.get_value(mkl_dft::config_param::FWD_STRIDES, &fwd_strides);
118#else
119 descr_.get_value(mkl_dft::config_param::FWD_STRIDES,
120 fwd_strides.data());
121#endif // INTEL_MKL_VERSION
122 return fwd_strides;
123 }
124
125 template <typename valT = std::vector<std::int64_t>>
126 void set_fwd_strides(const valT &strides)
127 {
128 const typename valT::value_type dim = get_dim();
129
130 if (static_cast<size_t>(dim + 1) != strides.size()) {
131 throw py::value_error(
132 "Strides length does not match descriptor's dimension");
133 }
134#if INTEL_MKL_VERSION >= 20250000
135 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides);
136#else
137 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides.data());
138#endif // INTEL_MKL_VERSION
139 }
140
141 // config_param::BWD_STRIDES
142 template <typename valT = std::vector<std::int64_t>>
143 const valT get_bwd_strides()
144 {
145 const typename valT::value_type dim = get_dim();
146
147 valT bwd_strides(dim + 1);
148#if INTEL_MKL_VERSION >= 20250000
149 descr_.get_value(mkl_dft::config_param::BWD_STRIDES, &bwd_strides);
150#else
151 descr_.get_value(mkl_dft::config_param::BWD_STRIDES,
152 bwd_strides.data());
153#endif // INTEL_MKL_VERSION
154 return bwd_strides;
155 }
156
157 template <typename valT = std::vector<std::int64_t>>
158 void set_bwd_strides(const valT &strides)
159 {
160 const typename valT::value_type dim = get_dim();
161
162 if (static_cast<size_t>(dim + 1) != strides.size()) {
163 throw py::value_error(
164 "Strides length does not match descriptor's dimension");
165 }
166#if INTEL_MKL_VERSION >= 20250000
167 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides);
168#else
169 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides.data());
170#endif // INTEL_MKL_VERSION
171 }
172
173 // config_param::FWD_DISTANCE
174 template <typename valT = std::int64_t>
175 const valT get_fwd_distance()
176 {
177 valT dist = 0;
178
179 descr_.get_value(mkl_dft::config_param::FWD_DISTANCE, &dist);
180 return dist;
181 }
182
183 template <typename valT = std::int64_t>
184 void set_fwd_distance(const valT &dist)
185 {
186 descr_.set_value(mkl_dft::config_param::FWD_DISTANCE, dist);
187 }
188
189 // config_param::BWD_DISTANCE
190 template <typename valT = std::int64_t>
191 const valT get_bwd_distance()
192 {
193 valT dist = 0;
194
195 descr_.get_value(mkl_dft::config_param::BWD_DISTANCE, &dist);
196 return dist;
197 }
198
199 template <typename valT = std::int64_t>
200 void set_bwd_distance(const valT &dist)
201 {
202 descr_.set_value(mkl_dft::config_param::BWD_DISTANCE, dist);
203 }
204
205 // config_param::PLACEMENT
206 bool get_in_place()
207 {
208#if defined(USE_ONEMKL_INTERFACES) || INTEL_MKL_VERSION >= 20250000
209 mkl_dft::config_value placement;
210 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
211 return (placement == mkl_dft::config_value::INPLACE);
212#else
213 // TODO: remove branch when MKLD-10506 is implemented
214 DFTI_CONFIG_VALUE placement;
215 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
216 return (placement == DFTI_CONFIG_VALUE::DFTI_INPLACE);
217#endif // USE_ONEMKL_INTERFACES or INTEL_MKL_VERSION
218 }
219
220 void set_in_place(const bool &in_place_request)
221 {
222#if defined(USE_ONEMKL_INTERFACES) || INTEL_MKL_VERSION >= 20250000
223 descr_.set_value(mkl_dft::config_param::PLACEMENT,
224 (in_place_request)
225 ? mkl_dft::config_value::INPLACE
226 : mkl_dft::config_value::NOT_INPLACE);
227#else
228 // TODO: remove branch when MKLD-10506 is implemented
229 descr_.set_value(mkl_dft::config_param::PLACEMENT,
230 (in_place_request)
231 ? DFTI_CONFIG_VALUE::DFTI_INPLACE
232 : DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE);
233#endif // USE_ONEMKL_INTERFACES or INTEL_MKL_VERSION
234 }
235
236 // config_param::PRECISION
237 mkl_dft::precision get_precision()
238 {
239 mkl_dft::precision fft_prec;
240
241 descr_.get_value(mkl_dft::config_param::PRECISION, &fft_prec);
242 return fft_prec;
243 }
244
245 // config_param::COMMIT_STATUS
246 bool is_committed()
247 {
248#if defined(USE_ONEMKL_INTERFACES) || INTEL_MKL_VERSION >= 20250000
249 mkl_dft::config_value committed;
250 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
251 return (committed == mkl_dft::config_value::COMMITTED);
252#else
253 // TODO: remove branch when MKLD-10506 is implemented
254 DFTI_CONFIG_VALUE committed;
255 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
256 return (committed == DFTI_CONFIG_VALUE::DFTI_COMMITTED);
257#endif // USE_ONEMKL_INTERFACES or INTEL_MKL_VERSION
258 }
259
260private:
261 mkl_dft::descriptor<prec, dom> descr_;
262 std::unique_ptr<sycl::queue> queue_ptr_;
263};
264
265} // namespace dpnp::extensions::fft