46 using descr_type = mkl_dft::descriptor<prec, dom>;
50 : descr_(dimensions), queue_ptr_{}
55 void commit(sycl::queue &q)
57 mkl_dft::precision fft_prec = get_precision();
58 if (fft_prec == mkl_dft::precision::DOUBLE &&
59 !q.get_device().has(sycl::aspect::fp64)) {
60 throw py::value_error(
"Descriptor is double precision but the "
61 "device does not support double precision.");
65 queue_ptr_ = std::make_unique<sycl::queue>(q);
68 descr_type &get_descriptor() {
return descr_; }
70 const sycl::queue &get_queue()
const
76 throw std::runtime_error(
77 "Attempt to get queue when it is not yet set");
82 template <
typename valT = std::
int64_t>
86 descr_.get_value(mkl_dft::config_param::DIMENSION, &dim);
92 template <
typename valT = std::
int64_t>
93 const valT get_number_of_transforms()
95 valT transforms_count{};
97 descr_.get_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS,
99 return transforms_count;
102 template <
typename valT = std::
int64_t>
103 void set_number_of_transforms(
const valT &num)
105 descr_.set_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS, num);
109 template <
typename valT = std::vector<std::
int64_t>>
110 const valT get_fwd_strides()
112 const typename valT::value_type dim = get_dim();
114 valT fwd_strides(dim + 1);
115#if INTEL_MKL_VERSION >= 20250000
116 descr_.get_value(mkl_dft::config_param::FWD_STRIDES, &fwd_strides);
118 descr_.get_value(mkl_dft::config_param::FWD_STRIDES,
124 template <
typename valT = std::vector<std::
int64_t>>
125 void set_fwd_strides(
const valT &strides)
127 const typename valT::value_type dim = get_dim();
129 if (
static_cast<size_t>(dim + 1) != strides.size()) {
130 throw py::value_error(
131 "Strides length does not match descriptor's dimension");
133#if INTEL_MKL_VERSION >= 20250000
134 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides);
136 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides.data());
141 template <
typename valT = std::vector<std::
int64_t>>
142 const valT get_bwd_strides()
144 const typename valT::value_type dim = get_dim();
146 valT bwd_strides(dim + 1);
147#if INTEL_MKL_VERSION >= 20250000
148 descr_.get_value(mkl_dft::config_param::BWD_STRIDES, &bwd_strides);
150 descr_.get_value(mkl_dft::config_param::BWD_STRIDES,
156 template <
typename valT = std::vector<std::
int64_t>>
157 void set_bwd_strides(
const valT &strides)
159 const typename valT::value_type dim = get_dim();
161 if (
static_cast<size_t>(dim + 1) != strides.size()) {
162 throw py::value_error(
163 "Strides length does not match descriptor's dimension");
165#if INTEL_MKL_VERSION >= 20250000
166 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides);
168 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides.data());
173 template <
typename valT = std::
int64_t>
174 const valT get_fwd_distance()
178 descr_.get_value(mkl_dft::config_param::FWD_DISTANCE, &dist);
182 template <
typename valT = std::
int64_t>
183 void set_fwd_distance(
const valT &dist)
185 descr_.set_value(mkl_dft::config_param::FWD_DISTANCE, dist);
189 template <
typename valT = std::
int64_t>
190 const valT get_bwd_distance()
194 descr_.get_value(mkl_dft::config_param::BWD_DISTANCE, &dist);
198 template <
typename valT = std::
int64_t>
199 void set_bwd_distance(
const valT &dist)
201 descr_.set_value(mkl_dft::config_param::BWD_DISTANCE, dist);
207#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
208 mkl_dft::config_value placement;
209 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
210 return (placement == mkl_dft::config_value::INPLACE);
213 DFTI_CONFIG_VALUE placement;
214 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
215 return (placement == DFTI_CONFIG_VALUE::DFTI_INPLACE);
219 void set_in_place(
const bool &in_place_request)
221#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
222 descr_.set_value(mkl_dft::config_param::PLACEMENT,
224 ? mkl_dft::config_value::INPLACE
225 : mkl_dft::config_value::NOT_INPLACE);
228 descr_.set_value(mkl_dft::config_param::PLACEMENT,
230 ? DFTI_CONFIG_VALUE::DFTI_INPLACE
231 : DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE);
236 mkl_dft::precision get_precision()
238 mkl_dft::precision fft_prec;
240 descr_.get_value(mkl_dft::config_param::PRECISION, &fft_prec);
247#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
248 mkl_dft::config_value committed;
249 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
250 return (committed == mkl_dft::config_value::COMMITTED);
253 DFTI_CONFIG_VALUE committed;
254 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
255 return (committed == DFTI_CONFIG_VALUE::DFTI_COMMITTED);
260 mkl_dft::descriptor<prec, dom> descr_;
261 std::unique_ptr<sycl::queue> queue_ptr_;