DPNP C++ backend kernel library 0.18.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
choose_kernel.hpp
1//*****************************************************************************
2// Copyright (c) 2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#pragma once
27
28#include <algorithm>
29#include <complex>
30#include <cstdint>
31#include <limits>
32#include <type_traits>
33
34#include <sycl/sycl.hpp>
35
36#include "kernels/dpctl_tensor_types.hpp"
37#include "utils/indexing_utils.hpp"
38#include "utils/offset_utils.hpp"
39#include "utils/strided_iters.hpp"
40#include "utils/type_utils.hpp"
41
42namespace dpnp::extensions::indexing::strides_detail
43{
44
46{
47 NthStrideOffsetUnpacked(int common_nd,
48 dpctl::tensor::ssize_t const *_offsets,
49 dpctl::tensor::ssize_t const *_shape,
50 dpctl::tensor::ssize_t const *_strides)
51 : _ind(common_nd), nd(common_nd), offsets(_offsets), shape(_shape),
52 strides(_strides)
53 {
54 }
55
56 template <typename nT>
57 size_t operator()(dpctl::tensor::ssize_t gid, nT n) const
58 {
59 dpctl::tensor::ssize_t relative_offset(0);
60 _ind.get_displacement<const dpctl::tensor::ssize_t *,
61 const dpctl::tensor::ssize_t *>(
62 gid, shape, strides + (n * nd), relative_offset);
63
64 return relative_offset + offsets[n];
65 }
66
67private:
68 dpctl::tensor::strides::CIndexer_vector<dpctl::tensor::ssize_t> _ind;
69
70 int nd;
71 dpctl::tensor::ssize_t const *offsets;
72 dpctl::tensor::ssize_t const *shape;
73 dpctl::tensor::ssize_t const *strides;
74};
75
76static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
77
78} // namespace dpnp::extensions::indexing::strides_detail
79
80namespace dpnp::extensions::indexing::kernels
81{
82
83template <typename ProjectorT,
84 typename IndOutIndexerT,
85 typename ChoicesIndexerT,
86 typename IndT,
87 typename T>
89{
90private:
91 const IndT *ind = nullptr;
92 T *dst = nullptr;
93 char **chcs = nullptr;
94 dpctl::tensor::ssize_t n_chcs;
95 const IndOutIndexerT ind_out_indexer;
96 const ChoicesIndexerT chcs_indexer;
97
98public:
99 ChooseFunctor(const IndT *ind_,
100 T *dst_,
101 char **chcs_,
102 dpctl::tensor::ssize_t n_chcs_,
103 const IndOutIndexerT &ind_out_indexer_,
104 const ChoicesIndexerT &chcs_indexer_)
105 : ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
106 ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
107 {
108 }
109
110 void operator()(sycl::id<1> id) const
111 {
112 const ProjectorT proj{};
113
114 dpctl::tensor::ssize_t i = id[0];
115
116 auto ind_dst_offsets = ind_out_indexer(i);
117 dpctl::tensor::ssize_t ind_offset = ind_dst_offsets.get_first_offset();
118 dpctl::tensor::ssize_t dst_offset = ind_dst_offsets.get_second_offset();
119
120 IndT chc_idx = ind[ind_offset];
121 // proj produces an index in the range of n_chcs
122 dpctl::tensor::ssize_t projected_idx = proj(n_chcs, chc_idx);
123
124 dpctl::tensor::ssize_t chc_offset = chcs_indexer(i, projected_idx);
125
126 T *chc = reinterpret_cast<T *>(chcs[projected_idx]);
127
128 dst[dst_offset] = chc[chc_offset];
129 }
130};
131
132typedef sycl::event (*choose_fn_ptr_t)(sycl::queue &,
133 size_t,
134 dpctl::tensor::ssize_t,
135 int,
136 const dpctl::tensor::ssize_t *,
137 const char *,
138 char *,
139 char **,
140 dpctl::tensor::ssize_t,
141 dpctl::tensor::ssize_t,
142 const dpctl::tensor::ssize_t *,
143 const std::vector<sycl::event> &);
144
145template <typename ProjectorT, typename indTy, typename Ty>
146sycl::event choose_impl(sycl::queue &q,
147 size_t nelems,
148 dpctl::tensor::ssize_t n_chcs,
149 int nd,
150 const dpctl::tensor::ssize_t *shape_and_strides,
151 const char *ind_cp,
152 char *dst_cp,
153 char **chcs_cp,
154 dpctl::tensor::ssize_t ind_offset,
155 dpctl::tensor::ssize_t dst_offset,
156 const dpctl::tensor::ssize_t *chc_offsets,
157 const std::vector<sycl::event> &depends)
158{
159 dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
160
161 const indTy *ind_tp = reinterpret_cast<const indTy *>(ind_cp);
162 Ty *dst_tp = reinterpret_cast<Ty *>(dst_cp);
163
164 sycl::event choose_ev = q.submit([&](sycl::handler &cgh) {
165 cgh.depends_on(depends);
166
167 using InOutIndexerT =
168 dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
169 const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
170 shape_and_strides};
171
172 using NthChoiceIndexerT = strides_detail::NthStrideOffsetUnpacked;
173 const NthChoiceIndexerT choices_indexer{
174 nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};
175
176 using ChooseFunc = ChooseFunctor<ProjectorT, InOutIndexerT,
177 NthChoiceIndexerT, indTy, Ty>;
178
179 cgh.parallel_for<ChooseFunc>(sycl::range<1>(nelems),
180 ChooseFunc(ind_tp, dst_tp, chcs_cp, n_chcs,
181 ind_out_indexer,
182 choices_indexer));
183 });
184
185 return choose_ev;
186}
187
188} // namespace dpnp::extensions::indexing::kernels