#include <sycl/sycl.hpp>

#include "kernels/alignment.hpp"
#include "kernels/dpctl_tensor_types.hpp"
#include "kernels/elementwise_functions/sycl_complex.hpp"
#include "utils/offset_utils.hpp"
#include "utils/sycl_utils.hpp"
#include "utils/type_utils.hpp"

namespace dpnp::kernels::isclose
{
 
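// Scalar closeness predicate for real floating-point types (including
// sycl::half): when both values are finite the check is
// |a - b| <= atol + rtol * |b|; `equal_nan` controls whether NaN compares
// equal to NaN.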
template <typename T>
inline bool isclose(const T a,
                    const T b,
                    const T rtol,
                    const T atol,
                    const bool equal_nan)
{
    static_assert(std::is_floating_point_v<T> || std::is_same_v<T, sycl::half>);

    if (sycl::isfinite(a) && sycl::isfinite(b)) {
        return sycl::fabs(a - b) <= atol + rtol * sycl::fabs(b);
    }

    if (sycl::isnan(a) && sycl::isnan(b)) {
        return equal_nan;
    }

    // remaining cases involve infinities: close only when exactly equal
    return a == b;
}
 
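// Overload for std::complex<T>: both components must be finite for the
// tolerance check, which compares the magnitude |a - b| against
// atol + rtol * |b| using the exprm_ns complex helpers from sycl_complex.hpp.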
template <typename T>
inline bool isclose(const std::complex<T> a,
                    const std::complex<T> b,
                    const T rtol,
                    const T atol,
                    const bool equal_nan)
{
    const bool a_finite = sycl::isfinite(a.real()) && sycl::isfinite(a.imag());
    const bool b_finite = sycl::isfinite(b.real()) && sycl::isfinite(b.imag());

    if (a_finite && b_finite) {
        return exprm_ns::abs(exprm_ns::complex<T>(a - b)) <=
               atol + rtol * exprm_ns::abs(exprm_ns::complex<T>(b));
    }

    if (sycl::isnan(a.real()) && sycl::isnan(a.imag()) &&
        sycl::isnan(b.real()) && sycl::isnan(b.imag()))
    {
        return equal_nan;
    }

    return a == b;
}
 
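// Functor for the strided case: one work-item per output element, with the
// input/output offsets produced by a ThreeOffsets indexer over the packed
// shape/strides descriptor.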
template <typename T,
          typename scT,
          typename resTy,
          typename ThreeOffsets_IndexerT>
struct IsCloseStridedScalarFunctor
{
private:
    const T *a_ = nullptr;
    const T *b_ = nullptr;
    resTy *out_ = nullptr;
    const ThreeOffsets_IndexerT three_offsets_indexer_;
    const scT rtol_;
    const scT atol_;
    const bool equal_nan_;

public:
    IsCloseStridedScalarFunctor(const T *a,
                                const T *b,
                                resTy *out,
                                const ThreeOffsets_IndexerT &inps_res_indexer,
                                const scT rtol,
                                const scT atol,
                                const bool equal_nan)
        : a_(a), b_(b), out_(out), three_offsets_indexer_(inps_res_indexer),
          rtol_(rtol), atol_(atol), equal_nan_(equal_nan)
    {
    }
 
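    // Scalar (non-vectorized) evaluation: compute the three offsets for this
    // work-item and apply the isclose predicate to a single element pair.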
    void operator()(sycl::id<1> wid) const
    {
        const auto &three_offsets_ = three_offsets_indexer_(wid.get(0));
        const dpctl::tensor::ssize_t &inp1_offset =
            three_offsets_.get_first_offset();
        const dpctl::tensor::ssize_t &inp2_offset =
            three_offsets_.get_second_offset();
        const dpctl::tensor::ssize_t &out_offset =
            three_offsets_.get_third_offset();

        out_[out_offset] =
            isclose(a_[inp1_offset], b_[inp2_offset], rtol_, atol_, equal_nan_);
    }
};
 
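// Functor for the contiguous case: each work-item handles n_vecs * vec_sz
// consecutive elements. When the data are suitably aligned and T is not
// complex, sub-group block loads/stores are used; otherwise a plain strided
// loop over the sub-group is taken.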
template <typename T,
          typename scT,
          typename resTy,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
struct IsCloseContigScalarFunctor
{
private:
    const T *a_ = nullptr;
    const T *b_ = nullptr;
    resTy *out_ = nullptr;
    const std::size_t nelems_;
    const scT rtol_;
    const scT atol_;
    const bool equal_nan_;

public:
    IsCloseContigScalarFunctor(const T *a,
                               const T *b,
                               resTy *out,
                               const std::size_t n_elems,
                               const scT rtol,
                               const scT atol,
                               const bool equal_nan)
        : a_(a), b_(b), out_(out), nelems_(n_elems), rtol_(rtol), atol_(atol),
          equal_nan_(equal_nan)
    {
    }
 
    void operator()(sycl::nd_item<1> ndit) const
    {
        constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;

        using dpctl::tensor::type_utils::is_complex_v;
        if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            // first element of the contiguous tile owned by this sub-group
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
                using dpctl::tensor::sycl_utils::sub_group_load;
                using dpctl::tensor::sycl_utils::sub_group_store;

                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto a_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&a_[offset]);
                    auto b_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&b_[offset]);
                    auto out_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out_[offset]);

                    const sycl::vec<T, vec_sz> a_vec =
                        sub_group_load<vec_sz>(sg, a_multi_ptr);
                    const sycl::vec<T, vec_sz> b_vec =
                        sub_group_load<vec_sz>(sg, b_multi_ptr);

                    sycl::vec<resTy, vec_sz> res_vec;
                    for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
                        res_vec[vec_id] = isclose(a_vec[vec_id], b_vec[vec_id],
                                                  rtol_, atol_, equal_nan_);
                    }
                    sub_group_store<vec_sz>(sg, res_vec, out_multi_ptr);
                }
            }
            else {
                // last, partially filled tile: one element per lane
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out_[k] = isclose(a_[k], b_[k], rtol_, atol_, equal_nan_);
                }
            }
        }
        else {
            const std::uint16_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                out_[offset] =
                    isclose(a_[offset], b_[offset], rtol_, atol_, equal_nan_);
            }
        }
    }
};
 
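// Host entry point for the strided kernel: validates that T is supported on
// the target device, builds a ThreeOffsets_StridedIndexer from the packed
// shape/strides, and submits one work-item per element.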
template <typename T, typename scT>
sycl::event
    isclose_strided_scalar_impl(sycl::queue &exec_q,
                                const int nd,
                                const std::size_t nelems,
                                const dpctl::tensor::ssize_t *shape_strides,
                                const scT rtol,
                                const scT atol,
                                const bool equal_nan,
                                const char *a_cp,
                                const dpctl::tensor::ssize_t a_offset,
                                const char *b_cp,
                                const dpctl::tensor::ssize_t b_offset,
                                char *out_cp,
                                const dpctl::tensor::ssize_t out_offset,
                                const std::vector<sycl::event> &depends)
{
    dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);

    const T *a_tp = reinterpret_cast<const T *>(a_cp);
    const T *b_tp = reinterpret_cast<const T *>(b_cp);

    using resTy = bool;
    resTy *out_tp = reinterpret_cast<resTy *>(out_cp);

    using IndexerT =
        typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
    const IndexerT indexer{nd, a_offset, b_offset, out_offset, shape_strides};

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using IsCloseFunc =
            IsCloseStridedScalarFunctor<T, scT, resTy, IndexerT>;
        cgh.parallel_for<IsCloseFunc>(
            sycl::range<1>(nelems),
            IsCloseFunc(a_tp, b_tp, out_tp, indexer, rtol, atol, equal_nan));
    });
    return comp_ev;
}
 
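// Host entry point for the contiguous kernel: picks a work-group size from
// the problem size, then dispatches either the sub-group load/store variant
// (when all pointers meet the required alignment) or the fallback variant.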
template <typename T,
          typename scT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u>
sycl::event
    isclose_contig_scalar_impl(sycl::queue &exec_q,
                               const std::size_t nelems,
                               const scT rtol,
                               const scT atol,
                               const bool equal_nan,
                               const char *a_cp,
                               const char *b_cp,
                               char *out_cp,
                               const std::vector<sycl::event> &depends = {})
{
    constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
    const std::size_t n_work_items_needed = nelems / elems_per_wi;
    const std::size_t empirical_threshold = std::size_t(1) << 21;
    // work-group sizes below are assumed values; tune per device if needed
    const std::size_t lws = (n_work_items_needed <= empirical_threshold)
                                ? std::size_t(128)
                                : std::size_t(256);

    const std::size_t n_groups =
        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);
 
    const T *a_tp = reinterpret_cast<const T *>(a_cp);
    const T *b_tp = reinterpret_cast<const T *>(b_cp);

    using resTy = bool;
    resTy *out_tp = reinterpret_cast<resTy *>(out_cp);
 
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using dpctl::tensor::kernels::alignment_utils::is_aligned;
        using dpctl::tensor::kernels::alignment_utils::required_alignment;
        if (is_aligned<required_alignment>(a_tp) &&
            is_aligned<required_alignment>(b_tp) &&
            is_aligned<required_alignment>(out_tp))
        {
            constexpr bool enable_sg_loadstore = true;
            using IsCloseFunc =
                IsCloseContigScalarFunctor<T, scT, resTy, vec_sz, n_vecs,
                                           enable_sg_loadstore>;

            cgh.parallel_for<IsCloseFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                IsCloseFunc(a_tp, b_tp, out_tp, nelems, rtol, atol, equal_nan));
        }
        else {
            constexpr bool disable_sg_loadstore = false;
            using IsCloseFunc =
                IsCloseContigScalarFunctor<T, scT, resTy, vec_sz, n_vecs,
                                           disable_sg_loadstore>;

            cgh.parallel_for<IsCloseFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                IsCloseFunc(a_tp, b_tp, out_tp, nelems, rtol, atol, equal_nan));
        }
    });
    return comp_ev;
}

} // namespace dpnp::kernels::isclose
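// Usage sketch (illustrative only; `q`, `n`, `a_usm`, `b_usm` and `out_usm`
// are hypothetical USM allocations owned by the caller, and the argument
// order follows the declaration of isclose_contig_scalar_impl above):
//
//   sycl::queue q;
//   const std::size_t n = 1024;
//   const float *a_usm = /* device data */ nullptr;
//   const float *b_usm = /* device data */ nullptr;
//   bool *out_usm = /* device result buffer */ nullptr;
//
//   sycl::event ev =
//       dpnp::kernels::isclose::isclose_contig_scalar_impl<float, float>(
//           q, n, 1e-5f, 1e-8f, false,
//           reinterpret_cast<const char *>(a_usm),
//           reinterpret_cast<const char *>(b_usm),
//           reinterpret_cast<char *>(out_usm), {});
//   ev.wait();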