#pragma once

#include <algorithm>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

#include <sycl/sycl.hpp>

#include "kernels/alignment.hpp"
#include "kernels/dpctl_tensor_types.hpp"
#include "kernels/elementwise_functions/sycl_complex.hpp"
#include "utils/offset_utils.hpp"
#include "utils/sycl_utils.hpp"
#include "utils/type_utils.hpp"
namespace dpnp::kernels::isclose
{
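// Scalar closeness predicate for real floating-point types (including
// sycl::half), following numpy.isclose semantics: for finite inputs the
// test is |a - b| <= atol + rtol * |b| (asymmetric in b). NaNs compare
// close only when equal_nan is set; infinities are close only when they
// are equal (same sign), via the trailing a == b comparison.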
template <typename T>
inline bool isclose(const T a,
                    const T b,
                    const T rtol,
                    const T atol,
                    const bool equal_nan)
{
    static_assert(std::is_floating_point_v<T> || std::is_same_v<T, sycl::half>);

    if (sycl::isfinite(a) && sycl::isfinite(b)) {
        return sycl::fabs(a - b) <= atol + rtol * sycl::fabs(b);
    }

    if (sycl::isnan(a) && sycl::isnan(b)) {
        return equal_nan;
    }

    return a == b;
}
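// Overload for std::complex<T>: closeness is measured with the complex
// modulus via exprm_ns (SYCL's experimental complex support pulled in
// through sycl_complex.hpp). The equal_nan branch applies only when both
// the real and imaginary parts of both operands are NaN.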
template <typename T>
inline bool isclose(const std::complex<T> a,
                    const std::complex<T> b,
                    const T rtol,
                    const T atol,
                    const bool equal_nan)
{
    const bool a_finite = sycl::isfinite(a.real()) && sycl::isfinite(a.imag());
    const bool b_finite = sycl::isfinite(b.real()) && sycl::isfinite(b.imag());

    if (a_finite && b_finite) {
        return exprm_ns::abs(exprm_ns::complex<T>(a - b)) <=
               atol + rtol * exprm_ns::abs(exprm_ns::complex<T>(b));
    }

    if (sycl::isnan(a.real()) && sycl::isnan(a.imag()) &&
        sycl::isnan(b.real()) && sycl::isnan(b.imag())) {
        return equal_nan;
    }

    return a == b;
}
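// Functor for strided (non-contiguous) inputs: the ThreeOffsets indexer
// maps a flat work-item id to the element offsets of both inputs and the
// output, so one kernel handles arbitrary strided layouts.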
template <typename T,
          typename scT,
          typename resTy,
          typename ThreeOffsets_IndexerT>
struct IsCloseStridedScalarFunctor
{
private:
    const T *a_ = nullptr;
    const T *b_ = nullptr;
    resTy *out_ = nullptr;
    const ThreeOffsets_IndexerT three_offsets_indexer_;
    const scT rtol_;
    const scT atol_;
    const bool equal_nan_;

public:
    IsCloseStridedScalarFunctor(const T *a,
                                const T *b,
                                resTy *out,
                                const ThreeOffsets_IndexerT &inps_res_indexer,
                                const scT rtol,
                                const scT atol,
                                const bool equal_nan)
        : a_(a), b_(b), out_(out), three_offsets_indexer_(inps_res_indexer),
          rtol_(rtol), atol_(atol), equal_nan_(equal_nan)
    {
    }
    void operator()(sycl::id<1> wid) const
    {
        const auto &three_offsets_ = three_offsets_indexer_(wid.get(0));

        const dpctl::tensor::ssize_t &inp1_offset =
            three_offsets_.get_first_offset();
        const dpctl::tensor::ssize_t &inp2_offset =
            three_offsets_.get_second_offset();
        const dpctl::tensor::ssize_t &out_offset =
            three_offsets_.get_third_offset();

        out_[out_offset] =
            isclose(a_[inp1_offset], b_[inp2_offset], rtol_, atol_, equal_nan_);
    }
};
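// Functor for contiguous inputs. Each work-item covers n_vecs * vec_sz
// elements; when enable_sg_loadstore is true and T is not complex, full
// sub-groups move sycl::vec chunks with cooperative loads/stores,
// otherwise a scalar per-element loop is used.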
template <typename T,
          typename scT,
          typename resTy,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
struct IsCloseContigScalarFunctor
{
private:
    const T *a_ = nullptr;
    const T *b_ = nullptr;
    resTy *out_ = nullptr;
    std::size_t nelems_ = 0;
    const scT rtol_;
    const scT atol_;
    const bool equal_nan_;

public:
    IsCloseContigScalarFunctor(const T *a,
                               const T *b,
                               resTy *out,
                               const std::size_t n_elems,
                               const scT rtol,
                               const scT atol,
                               const bool equal_nan)
        : a_(a), b_(b), out_(out), nelems_(n_elems), rtol_(rtol), atol_(atol),
          equal_nan_(equal_nan)
    {
    }
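    // Each work-item processes elems_per_wi elements, so a sub-group spans
    // elems_per_wi * sgSize contiguous elements; the fast path below is
    // taken only when that whole span fits inside the array.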
    void operator()(sycl::nd_item<1> ndit) const
    {
        constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;

        using dpctl::tensor::type_utils::is_complex_v;
        if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
                using dpctl::tensor::sycl_utils::sub_group_load;
                using dpctl::tensor::sycl_utils::sub_group_store;

                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto a_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&a_[offset]);
                    auto b_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&b_[offset]);
                    auto out_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out_[offset]);

                    const sycl::vec<T, vec_sz> a_vec =
                        sub_group_load<vec_sz>(sg, a_multi_ptr);
                    const sycl::vec<T, vec_sz> b_vec =
                        sub_group_load<vec_sz>(sg, b_multi_ptr);

                    sycl::vec<resTy, vec_sz> res_vec;
                    for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
                        res_vec[vec_id] = isclose(a_vec[vec_id], b_vec[vec_id],
                                                  rtol_, atol_, equal_nan_);
                    }
                    sub_group_store<vec_sz>(sg, res_vec, out_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out_[k] = isclose(a_[k], b_[k], rtol_, atol_, equal_nan_);
                }
            }
        }
        else {
            const std::uint16_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                out_[offset] =
                    isclose(a_[offset], b_[offset], rtol_, atol_, equal_nan_);
            }
        }
    }
};
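// Type-dispatched entry point for strided arrays: reinterprets the raw
// char pointers, builds a ThreeOffsets_StridedIndexer from the packed
// shape/strides buffer, and submits the strided functor over a flat range.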
template <typename T, typename scT>
sycl::event
    isclose_strided_scalar_impl(sycl::queue &exec_q,
                                const int nd,
                                const std::size_t nelems,
                                const dpctl::tensor::ssize_t *shape_strides,
                                const scT rtol,
                                const scT atol,
                                const bool equal_nan,
                                const char *a_cp,
                                const dpctl::tensor::ssize_t a_offset,
                                const char *b_cp,
                                const dpctl::tensor::ssize_t b_offset,
                                char *out_cp,
                                const dpctl::tensor::ssize_t out_offset,
                                const std::vector<sycl::event> &depends)
{
    dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);

    const T *a_tp = reinterpret_cast<const T *>(a_cp);
    const T *b_tp = reinterpret_cast<const T *>(b_cp);

    using resTy = bool;
    resTy *out_tp = reinterpret_cast<resTy *>(out_cp);

    using IndexerT =
        typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
    const IndexerT indexer{nd, a_offset, b_offset, out_offset, shape_strides};

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using IsCloseFunc =
            IsCloseStridedScalarFunctor<T, scT, resTy, IndexerT>;
        cgh.parallel_for<IsCloseFunc>(
            sycl::range<1>(nelems),
            IsCloseFunc(a_tp, b_tp, out_tp, indexer, rtol, atol, equal_nan));
    });
    return comp_ev;
}
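// Type-dispatched entry point for contiguous arrays. The work-group size
// is chosen empirically from the number of work-items needed (each one
// covering n_vecs * vec_sz elements); the 128/256 branch values below
// follow the usual dpctl elementwise pattern.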
template <typename T,
          typename scT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u>
sycl::event
    isclose_contig_scalar_impl(sycl::queue &exec_q,
                               const std::size_t nelems,
                               const scT rtol,
                               const scT atol,
                               const bool equal_nan,
                               const char *a_cp,
                               const char *b_cp,
                               char *out_cp,
                               const std::vector<sycl::event> &depends = {})
{
    constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
    const std::size_t n_work_items_needed = nelems / elems_per_wi;
    const std::size_t empirical_threshold = std::size_t(1) << 21;
    const std::size_t lws = (n_work_items_needed <= empirical_threshold)
                                ? std::size_t(128)
                                : std::size_t(256);

    const std::size_t n_groups =
        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    const T *a_tp = reinterpret_cast<const T *>(a_cp);
    const T *b_tp = reinterpret_cast<const T *>(b_cp);

    using resTy = bool;
    resTy *out_tp = reinterpret_cast<resTy *>(out_cp);
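    // Sub-group block loads/stores assume suitably aligned pointers, so
    // dispatch on alignment and fall back to the scalar code path when any
    // of the three buffers is misaligned.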
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using dpctl::tensor::kernels::alignment_utils::is_aligned;
        using dpctl::tensor::kernels::alignment_utils::required_alignment;
        if (is_aligned<required_alignment>(a_tp) &&
            is_aligned<required_alignment>(b_tp) &&
            is_aligned<required_alignment>(out_tp)) {
            constexpr bool enable_sg_loadstore = true;
            using IsCloseFunc =
                IsCloseContigScalarFunctor<T, scT, resTy, vec_sz, n_vecs,
                                           enable_sg_loadstore>;

            cgh.parallel_for<IsCloseFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                IsCloseFunc(a_tp, b_tp, out_tp, nelems, rtol, atol, equal_nan));
        }
        else {
            constexpr bool disable_sg_loadstore = false;
            using IsCloseFunc =
                IsCloseContigScalarFunctor<T, scT, resTy, vec_sz, n_vecs,
                                           disable_sg_loadstore>;

            cgh.parallel_for<IsCloseFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                IsCloseFunc(a_tp, b_tp, out_tp, nelems, rtol, atol, equal_nan));
        }
    });
    return comp_ev;
}

} // namespace dpnp::kernels::isclose
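// A minimal host-side usage sketch (illustrative only: the queue and USM
// pointer names are hypothetical, and dpnp normally reaches these
// templates through its generated type-dispatch tables):
//
//     using dpnp::kernels::isclose::isclose_contig_scalar_impl;
//
//     sycl::queue q;
//     // a_usm, b_usm: device USM float arrays of length n;
//     // out_usm: device USM bool array of length n
//     sycl::event ev = isclose_contig_scalar_impl<float, float>(
//         q, n, 1e-05f, 1e-08f, /*equal_nan=*/false,
//         reinterpret_cast<const char *>(a_usm),
//         reinterpret_cast<const char *>(b_usm),
//         reinterpret_cast<char *>(out_usm));
//     ev.wait();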