37#include <sycl/sycl.hpp>
39#include "kernels/dpctl_tensor_types.hpp"
40#include "utils/indexing_utils.hpp"
41#include "utils/offset_utils.hpp"
42#include "utils/strided_iters.hpp"
43#include "utils/type_utils.hpp"
45namespace dpnp::extensions::indexing::strides_detail
51 dpctl::tensor::ssize_t
const *_offsets,
52 dpctl::tensor::ssize_t
const *_shape,
53 dpctl::tensor::ssize_t
const *_strides)
54 : _ind(common_nd), nd(common_nd), offsets(_offsets), shape(_shape),
59 template <
typename nT>
60 size_t operator()(dpctl::tensor::ssize_t gid, nT n)
const
62 dpctl::tensor::ssize_t relative_offset(0);
63 _ind.get_displacement<
const dpctl::tensor::ssize_t *,
64 const dpctl::tensor::ssize_t *>(
65 gid, shape, strides + (n * nd), relative_offset);
67 return relative_offset + offsets[n];
71 dpctl::tensor::strides::CIndexer_vector<dpctl::tensor::ssize_t> _ind;
74 dpctl::tensor::ssize_t
const *offsets;
75 dpctl::tensor::ssize_t
const *shape;
76 dpctl::tensor::ssize_t
const *strides;
79static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
83namespace dpnp::extensions::indexing::kernels
86template <
typename ProjectorT,
87 typename IndOutIndexerT,
88 typename ChoicesIndexerT,
94 const IndT *ind =
nullptr;
96 char **chcs =
nullptr;
97 dpctl::tensor::ssize_t n_chcs;
98 const IndOutIndexerT ind_out_indexer;
99 const ChoicesIndexerT chcs_indexer;
105 dpctl::tensor::ssize_t n_chcs_,
106 const IndOutIndexerT &ind_out_indexer_,
107 const ChoicesIndexerT &chcs_indexer_)
108 : ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
109 ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
113 void operator()(sycl::id<1>
id)
const
115 const ProjectorT proj{};
117 dpctl::tensor::ssize_t i =
id[0];
119 auto ind_dst_offsets = ind_out_indexer(i);
120 dpctl::tensor::ssize_t ind_offset = ind_dst_offsets.get_first_offset();
121 dpctl::tensor::ssize_t dst_offset = ind_dst_offsets.get_second_offset();
123 IndT chc_idx = ind[ind_offset];
125 dpctl::tensor::ssize_t projected_idx = proj(n_chcs, chc_idx);
127 dpctl::tensor::ssize_t chc_offset = chcs_indexer(i, projected_idx);
129 T *chc =
reinterpret_cast<T *
>(chcs[projected_idx]);
131 dst[dst_offset] = chc[chc_offset];
135typedef sycl::event (*choose_fn_ptr_t)(sycl::queue &,
137 dpctl::tensor::ssize_t,
139 const dpctl::tensor::ssize_t *,
143 dpctl::tensor::ssize_t,
144 dpctl::tensor::ssize_t,
145 const dpctl::tensor::ssize_t *,
146 const std::vector<sycl::event> &);
148template <
typename ProjectorT,
typename indTy,
typename Ty>
149sycl::event choose_impl(sycl::queue &q,
151 dpctl::tensor::ssize_t n_chcs,
153 const dpctl::tensor::ssize_t *shape_and_strides,
157 dpctl::tensor::ssize_t ind_offset,
158 dpctl::tensor::ssize_t dst_offset,
159 const dpctl::tensor::ssize_t *chc_offsets,
160 const std::vector<sycl::event> &depends)
162 dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
164 const indTy *ind_tp =
reinterpret_cast<const indTy *
>(ind_cp);
165 Ty *dst_tp =
reinterpret_cast<Ty *
>(dst_cp);
167 sycl::event choose_ev = q.submit([&](sycl::handler &cgh) {
168 cgh.depends_on(depends);
170 using InOutIndexerT =
171 dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
172 const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
176 const NthChoiceIndexerT choices_indexer{
177 nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};
179 using ChooseFunc = ChooseFunctor<ProjectorT, InOutIndexerT,
180 NthChoiceIndexerT, indTy, Ty>;
182 cgh.parallel_for<ChooseFunc>(sycl::range<1>(nelems),
183 ChooseFunc(ind_tp, dst_tp, chcs_cp, n_chcs,