#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

#include <sycl/sycl.hpp>

#include "kernels/alignment.hpp"
#include "kernels/dpctl_tensor_types.hpp"
#include "utils/offset_utils.hpp"
#include "utils/sycl_utils.hpp"
#include "utils/type_utils.hpp"
namespace dpnp::kernels::nan_to_num
{
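// Scalar helper: replace NaN with `nan` and +/-infinity with `posinf`/`neginf`;
// finite values are returned unchanged.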
template <typename T>
inline T to_num(const T v, const T nan, const T posinf, const T neginf)
{
    return (sycl::isnan(v))   ? nan
           : (sycl::isinf(v)) ? (v > 0) ? posinf : neginf
                              : v;
}
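// Work-item functor for the strided kernel: each work-item handles a single
// element, with input/output offsets resolved by the indexer.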
template <typename T, typename scT, typename InOutIndexerT>
struct NanToNumFunctor
{
private:
    const T *inp_ = nullptr;
    T *out_ = nullptr;
    const InOutIndexerT inp_out_indexer_;
    const scT nan_;
    const scT posinf_;
    const scT neginf_;

public:
    NanToNumFunctor(const T *inp,
                    T *out,
                    const InOutIndexerT &inp_out_indexer,
                    const scT nan,
                    const scT posinf,
                    const scT neginf)
        : inp_(inp), out_(out), inp_out_indexer_(inp_out_indexer), nan_(nan),
          posinf_(posinf), neginf_(neginf)
    {
    }

    void operator()(sycl::id<1> wid) const
    {
        const auto &offsets_ = inp_out_indexer_(wid.get(0));
        const dpctl::tensor::ssize_t &inp_offset = offsets_.get_first_offset();
        const dpctl::tensor::ssize_t &out_offset = offsets_.get_second_offset();

        using dpctl::tensor::type_utils::is_complex_v;
        if constexpr (is_complex_v<T>) {
            using realT = typename T::value_type;
            static_assert(std::is_same_v<realT, scT>);
            T z = inp_[inp_offset];
            realT x = to_num(z.real(), nan_, posinf_, neginf_);
            realT y = to_num(z.imag(), nan_, posinf_, neginf_);
            out_[out_offset] = T{x, y};
        }
        else {
            out_[out_offset] = to_num(inp_[inp_offset], nan_, posinf_, neginf_);
        }
    }
};
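// Functor for the contiguous kernel: each work-item processes
// n_vecs * vec_sz consecutive elements and, when enable_sg_loadstore is set
// and T is not complex, uses sub-group block loads/stores.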
template <typename T,
          typename scT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
struct NanToNumContigFunctor
{
private:
    const T *in_ = nullptr;
    T *out_ = nullptr;
    std::size_t nelems_;
    const scT nan_;
    const scT posinf_;
    const scT neginf_;

public:
    NanToNumContigFunctor(const T *in,
                          T *out,
                          const std::size_t n_elems,
                          const scT nan,
                          const scT posinf,
                          const scT neginf)
        : in_(in), out_(out), nelems_(n_elems), nan_(nan), posinf_(posinf),
          neginf_(neginf)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
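        // Each work-item covers elems_per_wi contiguous elements. Two paths:
        // sub-group block loads/stores for non-complex T (fast path), or a
        // strided per-lane loop otherwise.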
        using dpctl::tensor::type_utils::is_complex_v;
        if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
            auto sg = ndit.get_sub_group();
            const std::uint16_t sgSize = sg.get_max_local_range()[0];
            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);
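            // Fast path: the whole sub-group block fits in bounds, so use
            // vectorized sub-group loads/stores.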
            if (base + elems_per_wi * sgSize < nelems_) {
                using dpctl::tensor::sycl_utils::sub_group_load;
                using dpctl::tensor::sycl_utils::sub_group_store;

                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;
                    auto in_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&in_[offset]);
                    auto out_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&out_[offset]);

                    sycl::vec<T, vec_sz> arg_vec =
                        sub_group_load<vec_sz>(sg, in_multi_ptr);
                    for (std::uint32_t k = 0; k < vec_sz; ++k) {
                        arg_vec[k] = to_num(arg_vec[k], nan_, posinf_, neginf_);
                    }
                    sub_group_store<vec_sz>(sg, arg_vec, out_multi_ptr);
                }
            }
            else {
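                // Tail: process the remaining elements one lane at a time.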
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out_[k] = to_num(in_[k], nan_, posinf_, neginf_);
                }
            }
        }
        else {
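            // Generic path: complex T, or sub-group load/store disabled
            // (e.g. unaligned pointers); process elements with a per-lane
            // strided loop.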
            const std::uint16_t sgSize =
                ndit.get_sub_group().get_local_range()[0];
            const std::size_t gid = ndit.get_global_linear_id();
            const std::uint16_t elems_per_sg = sgSize * elems_per_wi;

            const std::size_t start =
                (gid / sgSize) * (elems_per_sg - sgSize) + gid;
            const std::size_t end = std::min(nelems_, start + elems_per_sg);
            for (std::size_t offset = start; offset < end; offset += sgSize) {
                if constexpr (is_complex_v<T>) {
                    using realT = typename T::value_type;
                    static_assert(std::is_same_v<realT, scT>);

                    T z = in_[offset];
                    realT x = to_num(z.real(), nan_, posinf_, neginf_);
                    realT y = to_num(z.imag(), nan_, posinf_, neginf_);
                    out_[offset] = T{x, y};
                }
                else {
                    out_[offset] = to_num(in_[offset], nan_, posinf_, neginf_);
                }
            }
        }
    }
};
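// Submit a NanToNumFunctor kernel over strided input/output described by
// `shape_strides` and the given element offsets.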
template <typename T, typename scT>
sycl::event nan_to_num_strided_impl(sycl::queue &q,
                                    const std::size_t nelems,
                                    const int nd,
                                    const dpctl::tensor::ssize_t *shape_strides,
                                    const scT nan,
                                    const scT posinf,
                                    const scT neginf,
                                    const char *in_cp,
                                    const dpctl::tensor::ssize_t in_offset,
                                    char *out_cp,
                                    const dpctl::tensor::ssize_t out_offset,
                                    const std::vector<sycl::event> &depends)
{
    dpctl::tensor::type_utils::validate_type_for_device<T>(q);

    const T *in_tp = reinterpret_cast<const T *>(in_cp);
    T *out_tp = reinterpret_cast<T *>(out_cp);

    using InOutIndexerT =
        typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
    const InOutIndexerT indexer{nd, in_offset, out_offset, shape_strides};
    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using NanToNumFunc = NanToNumFunctor<T, scT, InOutIndexerT>;
        cgh.parallel_for<NanToNumFunc>(
            {nelems},
            NanToNumFunc(in_tp, out_tp, indexer, nan, posinf, neginf));
    });
    return comp_ev;
}
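// Contiguous implementation: picks a work-group size heuristically and
// dispatches either the sub-group-vectorized functor (aligned pointers) or
// the scalar fallback.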
template <typename T,
          typename scT,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u>
sycl::event nan_to_num_contig_impl(sycl::queue &exec_q,
                                   const std::size_t nelems,
                                   const scT nan,
                                   const scT posinf,
                                   const scT neginf,
                                   const char *in_cp,
                                   char *out_cp,
                                   const std::vector<sycl::event> &depends = {})
{
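    // Pick a work-group size from the problem size; every work-item processes
    // elems_per_wi contiguous elements, so n_groups covers all nelems.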
    constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
    const std::size_t n_work_items_needed = nelems / elems_per_wi;
    const std::size_t empirical_threshold = std::size_t(1) << 21;
    const std::size_t lws = (n_work_items_needed <= empirical_threshold)
                                ? std::size_t(128)
                                : std::size_t(256);

    const std::size_t n_groups =
        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);
    const T *in_tp = reinterpret_cast<const T *>(in_cp);
    T *out_tp = reinterpret_cast<T *>(out_cp);
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using dpctl::tensor::kernels::alignment_utils::is_aligned;
        using dpctl::tensor::kernels::alignment_utils::required_alignment;
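        // Sub-group block loads/stores require suitably aligned pointers;
        // otherwise fall back to the scalar functor.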
        if (is_aligned<required_alignment>(in_tp) &&
            is_aligned<required_alignment>(out_tp))
        {
            constexpr bool enable_sg_loadstore = true;
            using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
                                                       enable_sg_loadstore>;

            cgh.parallel_for<NanToNumFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
        }
        else {
            constexpr bool disable_sg_loadstore = false;
            using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
                                                       disable_sg_loadstore>;

            cgh.parallel_for<NanToNumFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
        }
    });
    return comp_ev;
}

} // namespace dpnp::kernels::nan_to_num
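// A minimal usage sketch (illustrative only, not part of this header):
// assuming `q` is a sycl::queue and `in`/`out` are USM-allocated float arrays
// of `n` elements, the contiguous entry point could be invoked as:
//
//   sycl::event ev =
//       dpnp::kernels::nan_to_num::nan_to_num_contig_impl<float, float>(
//           q, n, /*nan=*/0.0f, /*posinf=*/std::numeric_limits<float>::max(),
//           /*neginf=*/std::numeric_limits<float>::lowest(),
//           reinterpret_cast<const char *>(in), reinterpret_cast<char *>(out));
//   ev.wait();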