#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

#include <sycl/sycl.hpp>

// dpctl tensor headers
#include "kernels/alignment.hpp"
#include "kernels/dpctl_tensor_types.hpp"
#include "utils/offset_utils.hpp"
#include "utils/sycl_utils.hpp"
#include "utils/type_utils.hpp"

namespace dpnp::kernels::nan_to_num
{
 
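// Scalar replacement rule shared by all kernels in this file:
// NaN -> `nan`, +inf -> `posinf`, -inf -> `neginf`; finite values pass
// through unchanged.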
template <typename T>
inline T to_num(const T v, const T nan, const T posinf, const T neginf)
{
    return (sycl::isnan(v))   ? nan
           : (sycl::isinf(v)) ? ((v > 0) ? posinf : neginf)
                              : v;
}
 
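// Work-per-item functor for the strided kernel: each work-item maps its
// linear id through `InOutIndexerT` to a pair of (input, output) offsets and
// sanitizes a single element. For complex types the real and imaginary parts
// are sanitized independently, matching NumPy's nan_to_num semantics.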
template <typename T, typename scT, typename InOutIndexerT>
struct NanToNumFunctor
{
private:
    const T *inp_ = nullptr;
    T *out_ = nullptr;
    const InOutIndexerT inp_out_indexer_;
    const scT nan_;
    const scT posinf_;
    const scT neginf_;

public:
    NanToNumFunctor(const T *inp,
                    T *out,
                    const InOutIndexerT &inp_out_indexer,
                    const scT nan,
                    const scT posinf,
                    const scT neginf)
        : inp_(inp), out_(out), inp_out_indexer_(inp_out_indexer), nan_(nan),
          posinf_(posinf), neginf_(neginf)
    {
    }
 
    void operator()(sycl::id<1> wid) const
    {
        const auto &offsets_ = inp_out_indexer_(wid.get(0));
        const dpctl::tensor::ssize_t &inp_offset = offsets_.get_first_offset();
        const dpctl::tensor::ssize_t &out_offset = offsets_.get_second_offset();

        using dpctl::tensor::type_utils::is_complex_v;
        if constexpr (is_complex_v<T>) {
            using realT = typename T::value_type;
            static_assert(std::is_same_v<realT, scT>);
            T z = inp_[inp_offset];
            realT x = to_num(z.real(), nan_, posinf_, neginf_);
            realT y = to_num(z.imag(), nan_, posinf_, neginf_);
            out_[out_offset] = T{x, y};
        }
        else {
            out_[out_offset] = to_num(inp_[inp_offset], nan_, posinf_, neginf_);
        }
    }
};
 
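// Functor for the contiguous kernel. `vec_sz` is the width of each vectorized
// load/store and `n_vecs` the number of such vectors per work-item, so every
// work-item processes `n_vecs * vec_sz` consecutive elements; `scT` is the
// scalar (real-valued) type of the replacement values.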
template <typename T,
          typename scT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
struct NanToNumContigFunctor
{
private:
    const T *in_ = nullptr;
    T *out_ = nullptr;
    std::size_t nelems_;
    const scT nan_;
    const scT posinf_;
    const scT neginf_;

public:
    NanToNumContigFunctor(const T *in,
                          T *out,
                          const std::size_t n_elems,
                          const scT nan,
                          const scT posinf,
                          const scT neginf)
        : in_(in), out_(out), nelems_(n_elems), nan_(nan), posinf_(posinf),
          neginf_(neginf)
    {
    }
 
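    // Dispatch between two code paths at compile time:
    //  - sub-group path: each full sub-group loads `vec_sz` elements per lane
    //    with `sub_group_load`, sanitizes them in registers, and writes them
    //    back with `sub_group_store`;
    //  - scalar path (complex T, disabled sub-group load/store, or the array
    //    tail): each work-item strides through its window one element at a
    //    time.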
    void operator()(sycl::nd_item<1> ndit) const
    {
        constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
        /* Each work-item processes `elems_per_wi` elements, contiguous in
           memory */

        using dpctl::tensor::type_utils::is_complex_v;
        if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            if (base + elems_per_wi * sgSize < nelems_) {
                using dpctl::tensor::sycl_utils::sub_group_load;
                using dpctl::tensor::sycl_utils::sub_group_store;

#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in_[offset]);
                    auto out_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out_[offset]);

                    sycl::vec<T, vec_sz> arg_vec =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
#pragma unroll
                    for (std::uint32_t k = 0; k < vec_sz; ++k) {
                        arg_vec[k] = to_num(arg_vec[k], nan_, posinf_, neginf_);
                    }
                    sub_group_store<vec_sz>(sg, arg_vec, out_multi_ptr);
                }
            }
            else {
                // tail of the array: process the remainder one lane at a time
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out_[k] = to_num(in_[k], nan_, posinf_, neginf_);
                }
            }
        }
        else {
            const std::uint16_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                if constexpr (is_complex_v<T>) {
                    using realT = typename T::value_type;
                    static_assert(std::is_same_v<realT, scT>);

                    T z = in_[offset];
                    realT x = to_num(z.real(), nan_, posinf_, neginf_);
                    realT y = to_num(z.imag(), nan_, posinf_, neginf_);
                    out_[offset] = T{x, y};
                }
                else {
                    out_[offset] = to_num(in_[offset], nan_, posinf_, neginf_);
                }
            }
        }
    }
};
 
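// Strided implementation: submits `NanToNumFunctor` over a basic 1D range,
// using a `TwoOffsets_StridedIndexer` to translate each flat index into
// input/output memory offsets from the arrays' shape and strides.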
template <typename T, typename scT>
sycl::event nan_to_num_strided_impl(sycl::queue &q,
                                    const std::size_t nelems,
                                    const int nd,
                                    const dpctl::tensor::ssize_t *shape_strides,
                                    const scT nan,
                                    const scT posinf,
                                    const scT neginf,
                                    const char *in_cp,
                                    const dpctl::tensor::ssize_t in_offset,
                                    char *out_cp,
                                    const dpctl::tensor::ssize_t out_offset,
                                    const std::vector<sycl::event> &depends)
{
    dpctl::tensor::type_utils::validate_type_for_device<T>(q);

    const T *in_tp = reinterpret_cast<const T *>(in_cp);
    T *out_tp = reinterpret_cast<T *>(out_cp);

    using InOutIndexerT =
        typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
    const InOutIndexerT indexer{nd, in_offset, out_offset, shape_strides};

    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using NanToNumFunc = NanToNumFunctor<T, scT, InOutIndexerT>;
        cgh.parallel_for<NanToNumFunc>(
            sycl::range<1>(nelems),
            NanToNumFunc(in_tp, out_tp, indexer, nan, posinf, neginf));
    });
    return comp_ev;
}
 
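// Contiguous implementation: picks a work-group size from an empirical
// problem-size threshold, rounds the global range up to a whole number of
// groups, and enables the vectorized sub-group path only when both pointers
// meet the required alignment.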
template <typename T,
          typename scT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u>
sycl::event nan_to_num_contig_impl(sycl::queue &exec_q,
                                   const std::size_t nelems,
                                   const scT nan,
                                   const scT posinf,
                                   const scT neginf,
                                   const char *in_cp,
                                   char *out_cp,
                                   const std::vector<sycl::event> &depends = {})
{
    constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
    const std::size_t n_work_items_needed = nelems / elems_per_wi;
    const std::size_t empirical_threshold = std::size_t(1) << 21;
    // smaller work-groups for modest problem sizes, larger ones otherwise
    const std::size_t lws = (n_work_items_needed <= empirical_threshold)
                                ? std::size_t(128)
                                : std::size_t(256);

    const std::size_t n_groups =
        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    const T *in_tp = reinterpret_cast<const T *>(in_cp);
    T *out_tp = reinterpret_cast<T *>(out_cp);

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using dpctl::tensor::kernels::alignment_utils::is_aligned;
        using dpctl::tensor::kernels::alignment_utils::required_alignment;
        if (is_aligned<required_alignment>(in_tp) &&
            is_aligned<required_alignment>(out_tp))
        {
            constexpr bool enable_sg_loadstore = true;
            using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
                                                       enable_sg_loadstore>;

            cgh.parallel_for<NanToNumFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
        }
        else {
            constexpr bool disable_sg_loadstore = false;
            using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
                                                       disable_sg_loadstore>;

            cgh.parallel_for<NanToNumFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
        }
    });
    return comp_ev;
}

} // namespace dpnp::kernels::nan_to_num
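
// Illustrative host-side call (a sketch, not part of this header; the queue
// setup, allocations, and replacement values below are assumptions). With
// NumPy-style defaults, NaN maps to 0 and +/-inf map to the largest/lowest
// finite values of the type (requires <limits>):
//
//     sycl::queue q;
//     std::size_t n = 1024;
//     float *in = sycl::malloc_device<float>(n, q);
//     float *out = sycl::malloc_device<float>(n, q);
//     // ... fill `in` on the device ...
//     sycl::event ev =
//         dpnp::kernels::nan_to_num::nan_to_num_contig_impl<float, float>(
//             q, n, /*nan=*/0.0f,
//             /*posinf=*/std::numeric_limits<float>::max(),
//             /*neginf=*/std::numeric_limits<float>::lowest(),
//             reinterpret_cast<const char *>(in),
//             reinterpret_cast<char *>(out));
//     ev.wait();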