34#include <sycl/sycl.hpp> 
   36#include "kernels/dpctl_tensor_types.hpp" 
   37#include "utils/indexing_utils.hpp" 
   38#include "utils/offset_utils.hpp" 
   39#include "utils/strided_iters.hpp" 
   40#include "utils/type_utils.hpp" 
   42namespace dpnp::extensions::indexing::strides_detail
 
   48                            dpctl::tensor::ssize_t 
const *_offsets,
 
   49                            dpctl::tensor::ssize_t 
const *_shape,
 
   50                            dpctl::tensor::ssize_t 
const *_strides)
 
   51        : _ind(common_nd), nd(common_nd), offsets(_offsets), shape(_shape),
 
   56    template <
typename nT>
 
   57    size_t operator()(dpctl::tensor::ssize_t gid, nT n)
 const 
   59        dpctl::tensor::ssize_t relative_offset(0);
 
   60        _ind.get_displacement<
const dpctl::tensor::ssize_t *,
 
   61                              const dpctl::tensor::ssize_t *>(
 
   62            gid, shape, strides + (n * nd), relative_offset);
 
   64        return relative_offset + offsets[n];
 
   68    dpctl::tensor::strides::CIndexer_vector<dpctl::tensor::ssize_t> _ind;
 
   71    dpctl::tensor::ssize_t 
const *offsets;
 
   72    dpctl::tensor::ssize_t 
const *shape;
 
   73    dpctl::tensor::ssize_t 
const *strides;
 
 
   76static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
 
   80namespace dpnp::extensions::indexing::kernels
 
   83template <
typename ProjectorT,
 
   84          typename IndOutIndexerT,
 
   85          typename ChoicesIndexerT,
 
   91    const IndT *ind = 
nullptr;
 
   93    char **chcs = 
nullptr;
 
   94    dpctl::tensor::ssize_t n_chcs;
 
   95    const IndOutIndexerT ind_out_indexer;
 
   96    const ChoicesIndexerT chcs_indexer;
 
  102                  dpctl::tensor::ssize_t n_chcs_,
 
  103                  const IndOutIndexerT &ind_out_indexer_,
 
  104                  const ChoicesIndexerT &chcs_indexer_)
 
  105        : ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
 
  106          ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
 
  110    void operator()(sycl::id<1> 
id)
 const 
  112        const ProjectorT proj{};
 
  114        dpctl::tensor::ssize_t i = 
id[0];
 
  116        auto ind_dst_offsets = ind_out_indexer(i);
 
  117        dpctl::tensor::ssize_t ind_offset = ind_dst_offsets.get_first_offset();
 
  118        dpctl::tensor::ssize_t dst_offset = ind_dst_offsets.get_second_offset();
 
  120        IndT chc_idx = ind[ind_offset];
 
  122        dpctl::tensor::ssize_t projected_idx = proj(n_chcs, chc_idx);
 
  124        dpctl::tensor::ssize_t chc_offset = chcs_indexer(i, projected_idx);
 
  126        T *chc = 
reinterpret_cast<T *
>(chcs[projected_idx]);
 
  128        dst[dst_offset] = chc[chc_offset];
 
 
  // choose_fn_ptr_t: pointer-to-function type returning sycl::event. The
  // visible parameter list starts with sycl::queue & and ends with
  // const std::vector<sycl::event> &.
  // NOTE(review): the leading numbers (132, 134, ...) are listing residue fused
  // to the code, and the numbering skips 133, 135 and 137-139, so parameter
  // types on those lines are presumably missing from this text. Restore the
  // full signature from the upstream header before building; it should agree
  // with choose_impl's parameter list.
  132typedef sycl::event (*choose_fn_ptr_t)(sycl::queue &,
 
  134                                       dpctl::tensor::ssize_t,
 
  136                                       const dpctl::tensor::ssize_t *,
 
  140                                       dpctl::tensor::ssize_t,
 
  141                                       dpctl::tensor::ssize_t,
 
  142                                       const dpctl::tensor::ssize_t *,
 
  143                                       const std::vector<sycl::event> &);
 
  145template <
typename ProjectorT, 
typename indTy, 
typename Ty>
 
  146sycl::event choose_impl(sycl::queue &q,
 
  148                        dpctl::tensor::ssize_t n_chcs,
 
  150                        const dpctl::tensor::ssize_t *shape_and_strides,
 
  154                        dpctl::tensor::ssize_t ind_offset,
 
  155                        dpctl::tensor::ssize_t dst_offset,
 
  156                        const dpctl::tensor::ssize_t *chc_offsets,
 
  157                        const std::vector<sycl::event> &depends)
 
  159    dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
 
  161    const indTy *ind_tp = 
reinterpret_cast<const indTy *
>(ind_cp);
 
  162    Ty *dst_tp = 
reinterpret_cast<Ty *
>(dst_cp);
 
  164    sycl::event choose_ev = q.submit([&](sycl::handler &cgh) {
 
  165        cgh.depends_on(depends);
 
  167        using InOutIndexerT =
 
  168            dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
 
  169        const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
 
  173        const NthChoiceIndexerT choices_indexer{
 
  174            nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};
 
  176        using ChooseFunc = ChooseFunctor<ProjectorT, InOutIndexerT,
 
  177                                         NthChoiceIndexerT, indTy, Ty>;
 
  179        cgh.parallel_for<ChooseFunc>(sycl::range<1>(nelems),
 
  180                                     ChooseFunc(ind_tp, dst_tp, chcs_cp, n_chcs,