43 using descr_type = mkl_dft::descriptor<prec, dom>;
47 : descr_(dimensions), queue_ptr_{}
52 void commit(sycl::queue &q)
54 mkl_dft::precision fft_prec = get_precision();
55 if (fft_prec == mkl_dft::precision::DOUBLE &&
56 !q.get_device().has(sycl::aspect::fp64))
58 throw py::value_error(
"Descriptor is double precision but the "
59 "device does not support double precision.");
63 queue_ptr_ = std::make_unique<sycl::queue>(q);
66 descr_type &get_descriptor()
71 const sycl::queue &get_queue()
const
77 throw std::runtime_error(
78 "Attempt to get queue when it is not yet set");
83 template <
typename valT = std::
int64_t>
87 descr_.get_value(mkl_dft::config_param::DIMENSION, &dim);
93 template <
typename valT = std::
int64_t>
94 const valT get_number_of_transforms()
96 valT transforms_count{};
98 descr_.get_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS,
100 return transforms_count;
103 template <
typename valT = std::
int64_t>
104 void set_number_of_transforms(
const valT &num)
106 descr_.set_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS, num);
110 template <
typename valT = std::vector<std::
int64_t>>
111 const valT get_fwd_strides()
113 const typename valT::value_type dim = get_dim();
115 valT fwd_strides(dim + 1);
116#if INTEL_MKL_VERSION >= 20250000
117 descr_.get_value(mkl_dft::config_param::FWD_STRIDES, &fwd_strides);
119 descr_.get_value(mkl_dft::config_param::FWD_STRIDES,
125 template <
typename valT = std::vector<std::
int64_t>>
126 void set_fwd_strides(
const valT &strides)
128 const typename valT::value_type dim = get_dim();
130 if (
static_cast<size_t>(dim + 1) != strides.size()) {
131 throw py::value_error(
132 "Strides length does not match descriptor's dimension");
134#if INTEL_MKL_VERSION >= 20250000
135 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides);
137 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides.data());
142 template <
typename valT = std::vector<std::
int64_t>>
143 const valT get_bwd_strides()
145 const typename valT::value_type dim = get_dim();
147 valT bwd_strides(dim + 1);
148#if INTEL_MKL_VERSION >= 20250000
149 descr_.get_value(mkl_dft::config_param::BWD_STRIDES, &bwd_strides);
151 descr_.get_value(mkl_dft::config_param::BWD_STRIDES,
157 template <
typename valT = std::vector<std::
int64_t>>
158 void set_bwd_strides(
const valT &strides)
160 const typename valT::value_type dim = get_dim();
162 if (
static_cast<size_t>(dim + 1) != strides.size()) {
163 throw py::value_error(
164 "Strides length does not match descriptor's dimension");
166#if INTEL_MKL_VERSION >= 20250000
167 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides);
169 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides.data());
174 template <
typename valT = std::
int64_t>
175 const valT get_fwd_distance()
179 descr_.get_value(mkl_dft::config_param::FWD_DISTANCE, &dist);
183 template <
typename valT = std::
int64_t>
184 void set_fwd_distance(
const valT &dist)
186 descr_.set_value(mkl_dft::config_param::FWD_DISTANCE, dist);
190 template <
typename valT = std::
int64_t>
191 const valT get_bwd_distance()
195 descr_.get_value(mkl_dft::config_param::BWD_DISTANCE, &dist);
199 template <
typename valT = std::
int64_t>
200 void set_bwd_distance(
const valT &dist)
202 descr_.set_value(mkl_dft::config_param::BWD_DISTANCE, dist);
208#if defined(USE_ONEMKL_INTERFACES) || INTEL_MKL_VERSION >= 20250000
209 mkl_dft::config_value placement;
210 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
211 return (placement == mkl_dft::config_value::INPLACE);
214 DFTI_CONFIG_VALUE placement;
215 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
216 return (placement == DFTI_CONFIG_VALUE::DFTI_INPLACE);
220 void set_in_place(
const bool &in_place_request)
222#if defined(USE_ONEMKL_INTERFACES) || INTEL_MKL_VERSION >= 20250000
223 descr_.set_value(mkl_dft::config_param::PLACEMENT,
225 ? mkl_dft::config_value::INPLACE
226 : mkl_dft::config_value::NOT_INPLACE);
229 descr_.set_value(mkl_dft::config_param::PLACEMENT,
231 ? DFTI_CONFIG_VALUE::DFTI_INPLACE
232 : DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE);
237 mkl_dft::precision get_precision()
239 mkl_dft::precision fft_prec;
241 descr_.get_value(mkl_dft::config_param::PRECISION, &fft_prec);
248#if defined(USE_ONEMKL_INTERFACES) || INTEL_MKL_VERSION >= 20250000
249 mkl_dft::config_value committed;
250 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
251 return (committed == mkl_dft::config_value::COMMITTED);
254 DFTI_CONFIG_VALUE committed;
255 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
256 return (committed == DFTI_CONFIG_VALUE::DFTI_COMMITTED);
261 mkl_dft::descriptor<prec, dom> descr_;
262 std::unique_ptr<sycl::queue> queue_ptr_;