46 using descr_type = mkl_dft::descriptor<prec, dom>;
50 : descr_(dimensions), queue_ptr_{}
55 void commit(sycl::queue &q)
57 mkl_dft::precision fft_prec = get_precision();
58 if (fft_prec == mkl_dft::precision::DOUBLE &&
59 !q.get_device().has(sycl::aspect::fp64))
61 throw py::value_error(
"Descriptor is double precision but the "
62 "device does not support double precision.");
66 queue_ptr_ = std::make_unique<sycl::queue>(q);
69 descr_type &get_descriptor()
74 const sycl::queue &get_queue()
const
80 throw std::runtime_error(
81 "Attempt to get queue when it is not yet set");
86 template <
typename valT = std::
int64_t>
90 descr_.get_value(mkl_dft::config_param::DIMENSION, &dim);
96 template <
typename valT = std::
int64_t>
97 const valT get_number_of_transforms()
99 valT transforms_count{};
101 descr_.get_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS,
103 return transforms_count;
106 template <
typename valT = std::
int64_t>
107 void set_number_of_transforms(
const valT &num)
109 descr_.set_value(mkl_dft::config_param::NUMBER_OF_TRANSFORMS, num);
113 template <
typename valT = std::vector<std::
int64_t>>
114 const valT get_fwd_strides()
116 const typename valT::value_type dim = get_dim();
118 valT fwd_strides(dim + 1);
119#if INTEL_MKL_VERSION >= 20250000
120 descr_.get_value(mkl_dft::config_param::FWD_STRIDES, &fwd_strides);
122 descr_.get_value(mkl_dft::config_param::FWD_STRIDES,
128 template <
typename valT = std::vector<std::
int64_t>>
129 void set_fwd_strides(
const valT &strides)
131 const typename valT::value_type dim = get_dim();
133 if (
static_cast<size_t>(dim + 1) != strides.size()) {
134 throw py::value_error(
135 "Strides length does not match descriptor's dimension");
137#if INTEL_MKL_VERSION >= 20250000
138 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides);
140 descr_.set_value(mkl_dft::config_param::FWD_STRIDES, strides.data());
145 template <
typename valT = std::vector<std::
int64_t>>
146 const valT get_bwd_strides()
148 const typename valT::value_type dim = get_dim();
150 valT bwd_strides(dim + 1);
151#if INTEL_MKL_VERSION >= 20250000
152 descr_.get_value(mkl_dft::config_param::BWD_STRIDES, &bwd_strides);
154 descr_.get_value(mkl_dft::config_param::BWD_STRIDES,
160 template <
typename valT = std::vector<std::
int64_t>>
161 void set_bwd_strides(
const valT &strides)
163 const typename valT::value_type dim = get_dim();
165 if (
static_cast<size_t>(dim + 1) != strides.size()) {
166 throw py::value_error(
167 "Strides length does not match descriptor's dimension");
169#if INTEL_MKL_VERSION >= 20250000
170 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides);
172 descr_.set_value(mkl_dft::config_param::BWD_STRIDES, strides.data());
177 template <
typename valT = std::
int64_t>
178 const valT get_fwd_distance()
182 descr_.get_value(mkl_dft::config_param::FWD_DISTANCE, &dist);
186 template <
typename valT = std::
int64_t>
187 void set_fwd_distance(
const valT &dist)
189 descr_.set_value(mkl_dft::config_param::FWD_DISTANCE, dist);
193 template <
typename valT = std::
int64_t>
194 const valT get_bwd_distance()
198 descr_.get_value(mkl_dft::config_param::BWD_DISTANCE, &dist);
202 template <
typename valT = std::
int64_t>
203 void set_bwd_distance(
const valT &dist)
205 descr_.set_value(mkl_dft::config_param::BWD_DISTANCE, dist);
211#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
212 mkl_dft::config_value placement;
213 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
214 return (placement == mkl_dft::config_value::INPLACE);
217 DFTI_CONFIG_VALUE placement;
218 descr_.get_value(mkl_dft::config_param::PLACEMENT, &placement);
219 return (placement == DFTI_CONFIG_VALUE::DFTI_INPLACE);
223 void set_in_place(
const bool &in_place_request)
225#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
226 descr_.set_value(mkl_dft::config_param::PLACEMENT,
228 ? mkl_dft::config_value::INPLACE
229 : mkl_dft::config_value::NOT_INPLACE);
232 descr_.set_value(mkl_dft::config_param::PLACEMENT,
234 ? DFTI_CONFIG_VALUE::DFTI_INPLACE
235 : DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE);
240 mkl_dft::precision get_precision()
242 mkl_dft::precision fft_prec;
244 descr_.get_value(mkl_dft::config_param::PRECISION, &fft_prec);
251#if defined(USE_ONEMATH) || INTEL_MKL_VERSION >= 20250000
252 mkl_dft::config_value committed;
253 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
254 return (committed == mkl_dft::config_value::COMMITTED);
257 DFTI_CONFIG_VALUE committed;
258 descr_.get_value(mkl_dft::config_param::COMMIT_STATUS, &committed);
259 return (committed == DFTI_CONFIG_VALUE::DFTI_COMMITTED);
264 mkl_dft::descriptor<prec, dom> descr_;
265 std::unique_ptr<sycl::queue> queue_ptr_;