34#include <sycl/sycl.hpp>
36#include "kernels/dpctl_tensor_types.hpp"
37#include "utils/indexing_utils.hpp"
38#include "utils/offset_utils.hpp"
39#include "utils/strided_iters.hpp"
40#include "utils/type_utils.hpp"
42namespace dpnp::extensions::indexing::strides_detail
48 dpctl::tensor::ssize_t
const *_offsets,
49 dpctl::tensor::ssize_t
const *_shape,
50 dpctl::tensor::ssize_t
const *_strides)
51 : _ind(common_nd), nd(common_nd), offsets(_offsets), shape(_shape),
// Returns the offset (in elements) of flat work-item index `gid` inside the
// n-th choice array.
// The n-th array's strides are read from the packed `strides` buffer starting
// at `n * nd`, i.e. `nd` strides per choice array, all sharing one `shape`.
56 template <
typename nT>
57 size_t operator()(dpctl::tensor::ssize_t gid, nT n)
const
// Filled in by get_displacement (passed as its last argument); starts at 0.
59 dpctl::tensor::ssize_t relative_offset(0);
// Presumably unravels `gid` over `shape` (C order, per CIndexer_vector) and
// sums the matching strides of the n-th array -- confirm against dpctl.
60 _ind.get_displacement<
const dpctl::tensor::ssize_t *,
61 const dpctl::tensor::ssize_t *>(
62 gid, shape, strides + (n * nd), relative_offset);
// Shift by the n-th array's base offset. NOTE(review): relative_offset may be
// negative for negative strides; the signed sum is converted to size_t on
// return -- TODO confirm callers never depend on a negative total.
64 return relative_offset + offsets[n];
// Turns a flat index into a strided displacement; constructed with the
// common number of dimensions (common_nd).
68 dpctl::tensor::strides::CIndexer_vector<dpctl::tensor::ssize_t> _ind;
// Per-choice base offsets, indexed by choice number n.
71 dpctl::tensor::ssize_t
const *offsets;
// Shape shared by every choice array (presumably `nd` extents).
72 dpctl::tensor::ssize_t
const *shape;
// Packed strides: `nd` consecutive entries per choice array.
73 dpctl::tensor::ssize_t
const *strides;
// Compile-time guard: the indexer must satisfy sycl::is_device_copyable so it
// can be passed into a SYCL kernel by value (presumably as the choices
// indexer of ChooseFunctor -- confirm).
76static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
80namespace dpnp::extensions::indexing::kernels
83template <
typename ProjectorT,
84 typename IndOutIndexerT,
85 typename ChoicesIndexerT,
// Index array: for each output position, the raw number of the choice array
// to take the element from.
91 const IndT *ind =
nullptr;
// Type-erased base pointers of the choice arrays, one per choice; the kernel
// reinterprets the selected one as T*.
93 char **chcs =
nullptr;
// Number of choice arrays; passed to ProjectorT with each raw choice number.
94 dpctl::tensor::ssize_t n_chcs;
// Maps a work-item id to (offset in `ind`, offset in the destination).
95 const IndOutIndexerT ind_out_indexer;
// Maps (work-item id, choice number) to an offset inside that choice array.
96 const ChoicesIndexerT chcs_indexer;
102 dpctl::tensor::ssize_t n_chcs_,
103 const IndOutIndexerT &ind_out_indexer_,
104 const ChoicesIndexerT &chcs_indexer_)
105 : ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
106 ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
// Kernel body for one work-item `id`: copies into `dst` the element selected
// by the index array at the corresponding position.
110 void operator()(sycl::id<1>
id)
const
// Projector is default-constructed per work-item.
112 const ProjectorT proj{};
114 dpctl::tensor::ssize_t i =
id[0];
// Offsets of this work-item's element in the index array and in dst.
116 auto ind_dst_offsets = ind_out_indexer(i);
117 dpctl::tensor::ssize_t ind_offset = ind_dst_offsets.get_first_offset();
118 dpctl::tensor::ssize_t dst_offset = ind_dst_offsets.get_second_offset();
// Raw choice number read from the index array.
120 IndT chc_idx = ind[ind_offset];
// Normalise the choice number given n_chcs. NOTE(review): projected_idx is
// used below without a bounds check, so ProjectorT must keep it within
// [0, n_chcs) (presumably by wrapping or clipping, as in numpy.choose modes).
122 dpctl::tensor::ssize_t projected_idx = proj(n_chcs, chc_idx);
// Offset of element i inside the selected choice array.
124 dpctl::tensor::ssize_t chc_offset = chcs_indexer(i, projected_idx);
// chcs stores char* base pointers; view the selected array as T*.
126 T *chc =
reinterpret_cast<T *
>(chcs[projected_idx]);
128 dst[dst_offset] = chc[chc_offset];
// Function-pointer type for a type-specialised choose launcher: takes the
// queue first and the dependency events last, and returns a sycl::event.
// Presumably the signature shared by choose_impl instantiations stored in a
// dispatch table -- confirm.
132typedef sycl::event (*choose_fn_ptr_t)(sycl::queue &,
134 dpctl::tensor::ssize_t,
136 const dpctl::tensor::ssize_t *,
140 dpctl::tensor::ssize_t,
141 dpctl::tensor::ssize_t,
142 const dpctl::tensor::ssize_t *,
143 const std::vector<sycl::event> &);
145template <
typename ProjectorT,
typename indTy,
typename Ty>
146sycl::event choose_impl(sycl::queue &q,
148 dpctl::tensor::ssize_t n_chcs,
150 const dpctl::tensor::ssize_t *shape_and_strides,
154 dpctl::tensor::ssize_t ind_offset,
155 dpctl::tensor::ssize_t dst_offset,
156 const dpctl::tensor::ssize_t *chc_offsets,
157 const std::vector<sycl::event> &depends)
159 dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
161 const indTy *ind_tp =
reinterpret_cast<const indTy *
>(ind_cp);
162 Ty *dst_tp =
reinterpret_cast<Ty *
>(dst_cp);
164 sycl::event choose_ev = q.submit([&](sycl::handler &cgh) {
165 cgh.depends_on(depends);
167 using InOutIndexerT =
168 dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
169 const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
173 const NthChoiceIndexerT choices_indexer{
174 nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};
176 using ChooseFunc = ChooseFunctor<ProjectorT, InOutIndexerT,
177 NthChoiceIndexerT, indTy, Ty>;
179 cgh.parallel_for<ChooseFunc>(sycl::range<1>(nelems),
180 ChooseFunc(ind_tp, dst_tp, chcs_cp, n_chcs,