DPNP C++ backend kernel library 0.20.0dev0
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
choose_kernel.hpp
1//*****************************************************************************
2// Copyright (c) 2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <algorithm>
32#include <complex>
33#include <cstdint>
34#include <limits>
35#include <type_traits>
36
37#include <sycl/sycl.hpp>
38
39#include "kernels/dpctl_tensor_types.hpp"
40#include "utils/indexing_utils.hpp"
41#include "utils/offset_utils.hpp"
42#include "utils/strided_iters.hpp"
43#include "utils/type_utils.hpp"
44
45namespace dpnp::extensions::indexing::strides_detail
46{
47
49{
50 NthStrideOffsetUnpacked(int common_nd,
51 dpctl::tensor::ssize_t const *_offsets,
52 dpctl::tensor::ssize_t const *_shape,
53 dpctl::tensor::ssize_t const *_strides)
54 : _ind(common_nd), nd(common_nd), offsets(_offsets), shape(_shape),
55 strides(_strides)
56 {
57 }
58
59 template <typename nT>
60 size_t operator()(dpctl::tensor::ssize_t gid, nT n) const
61 {
62 dpctl::tensor::ssize_t relative_offset(0);
63 _ind.get_displacement<const dpctl::tensor::ssize_t *,
64 const dpctl::tensor::ssize_t *>(
65 gid, shape, strides + (n * nd), relative_offset);
66
67 return relative_offset + offsets[n];
68 }
69
70private:
71 dpctl::tensor::strides::CIndexer_vector<dpctl::tensor::ssize_t> _ind;
72
73 int nd;
74 dpctl::tensor::ssize_t const *offsets;
75 dpctl::tensor::ssize_t const *shape;
76 dpctl::tensor::ssize_t const *strides;
77};
78
79static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
80
81} // namespace dpnp::extensions::indexing::strides_detail
82
83namespace dpnp::extensions::indexing::kernels
84{
85
86template <typename ProjectorT,
87 typename IndOutIndexerT,
88 typename ChoicesIndexerT,
89 typename IndT,
90 typename T>
92{
93private:
94 const IndT *ind = nullptr;
95 T *dst = nullptr;
96 char **chcs = nullptr;
97 dpctl::tensor::ssize_t n_chcs;
98 const IndOutIndexerT ind_out_indexer;
99 const ChoicesIndexerT chcs_indexer;
100
101public:
102 ChooseFunctor(const IndT *ind_,
103 T *dst_,
104 char **chcs_,
105 dpctl::tensor::ssize_t n_chcs_,
106 const IndOutIndexerT &ind_out_indexer_,
107 const ChoicesIndexerT &chcs_indexer_)
108 : ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
109 ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
110 {
111 }
112
113 void operator()(sycl::id<1> id) const
114 {
115 const ProjectorT proj{};
116
117 dpctl::tensor::ssize_t i = id[0];
118
119 auto ind_dst_offsets = ind_out_indexer(i);
120 dpctl::tensor::ssize_t ind_offset = ind_dst_offsets.get_first_offset();
121 dpctl::tensor::ssize_t dst_offset = ind_dst_offsets.get_second_offset();
122
123 IndT chc_idx = ind[ind_offset];
124 // proj produces an index in the range of n_chcs
125 dpctl::tensor::ssize_t projected_idx = proj(n_chcs, chc_idx);
126
127 dpctl::tensor::ssize_t chc_offset = chcs_indexer(i, projected_idx);
128
129 T *chc = reinterpret_cast<T *>(chcs[projected_idx]);
130
131 dst[dst_offset] = chc[chc_offset];
132 }
133};
134
135typedef sycl::event (*choose_fn_ptr_t)(sycl::queue &,
136 size_t,
137 dpctl::tensor::ssize_t,
138 int,
139 const dpctl::tensor::ssize_t *,
140 const char *,
141 char *,
142 char **,
143 dpctl::tensor::ssize_t,
144 dpctl::tensor::ssize_t,
145 const dpctl::tensor::ssize_t *,
146 const std::vector<sycl::event> &);
147
148template <typename ProjectorT, typename indTy, typename Ty>
149sycl::event choose_impl(sycl::queue &q,
150 size_t nelems,
151 dpctl::tensor::ssize_t n_chcs,
152 int nd,
153 const dpctl::tensor::ssize_t *shape_and_strides,
154 const char *ind_cp,
155 char *dst_cp,
156 char **chcs_cp,
157 dpctl::tensor::ssize_t ind_offset,
158 dpctl::tensor::ssize_t dst_offset,
159 const dpctl::tensor::ssize_t *chc_offsets,
160 const std::vector<sycl::event> &depends)
161{
162 dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
163
164 const indTy *ind_tp = reinterpret_cast<const indTy *>(ind_cp);
165 Ty *dst_tp = reinterpret_cast<Ty *>(dst_cp);
166
167 sycl::event choose_ev = q.submit([&](sycl::handler &cgh) {
168 cgh.depends_on(depends);
169
170 using InOutIndexerT =
171 dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
172 const InOutIndexerT ind_out_indexer{nd, ind_offset, dst_offset,
173 shape_and_strides};
174
175 using NthChoiceIndexerT = strides_detail::NthStrideOffsetUnpacked;
176 const NthChoiceIndexerT choices_indexer{
177 nd, chc_offsets, shape_and_strides, shape_and_strides + 3 * nd};
178
179 using ChooseFunc = ChooseFunctor<ProjectorT, InOutIndexerT,
180 NthChoiceIndexerT, indTy, Ty>;
181
182 cgh.parallel_for<ChooseFunc>(sycl::range<1>(nelems),
183 ChooseFunc(ind_tp, dst_tp, chcs_cp, n_chcs,
184 ind_out_indexer,
185 choices_indexer));
186 });
187
188 return choose_ev;
189}
190
191} // namespace dpnp::extensions::indexing::kernels