DPNP C++ backend kernel library 0.20.0dev4
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
choose.hpp
1//*****************************************************************************
2// Copyright (c) 2026, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <cstddef>
32
33#include <sycl/sycl.hpp>
34
35#include "kernels/dpctl_tensor_types.hpp"
36#include "utils/strided_iters.hpp"
37
38namespace dpnp::kernels::choose
39{
40using dpctl::tensor::ssize_t;
41
42template <typename ProjectorT,
43 typename IndOutIndexerT,
44 typename ChoicesIndexerT,
45 typename IndT,
46 typename T>
48{
49private:
50 const IndT *ind = nullptr;
51 T *dst = nullptr;
52 char **chcs = nullptr;
53 ssize_t n_chcs;
54 const IndOutIndexerT ind_out_indexer;
55 const ChoicesIndexerT chcs_indexer;
56
57public:
58 ChooseFunctor(const IndT *ind_,
59 T *dst_,
60 char **chcs_,
61 ssize_t n_chcs_,
62 const IndOutIndexerT &ind_out_indexer_,
63 const ChoicesIndexerT &chcs_indexer_)
64 : ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
65 ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
66 {
67 }
68
69 void operator()(sycl::id<1> id) const
70 {
71 const ProjectorT proj{};
72
73 ssize_t i = id[0];
74
75 auto ind_dst_offsets = ind_out_indexer(i);
76 ssize_t ind_offset = ind_dst_offsets.get_first_offset();
77 ssize_t dst_offset = ind_dst_offsets.get_second_offset();
78
79 IndT chc_idx = ind[ind_offset];
80 // proj produces an index in the range of n_chcs
81 ssize_t projected_idx = proj(n_chcs, chc_idx);
82
83 ssize_t chc_offset = chcs_indexer(i, projected_idx);
84
85 T *chc = reinterpret_cast<T *>(chcs[projected_idx]);
86
87 dst[dst_offset] = chc[chc_offset];
88 }
89};
90
91namespace strides
92{
93using dpctl::tensor::strides::CIndexer_vector;
94
96{
97 NthStrideOffsetUnpacked(int common_nd,
98 ssize_t const *_offsets,
99 ssize_t const *_shape,
100 ssize_t const *_strides)
101 : _ind(common_nd), nd(common_nd), offsets(_offsets), shape(_shape),
102 strides(_strides)
103 {
104 }
105
106 template <typename nT>
107 size_t operator()(ssize_t gid, nT n) const
108 {
109 ssize_t relative_offset(0);
110 _ind.get_displacement<const ssize_t *, const ssize_t *>(
111 gid, shape, strides + (n * nd), relative_offset);
112
113 return relative_offset + offsets[n];
114 }
115
116private:
117 CIndexer_vector<ssize_t> _ind;
118
119 int nd;
120 ssize_t const *offsets;
121 ssize_t const *shape;
122 ssize_t const *strides;
123};
124
125static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
126
127} // namespace strides
128} // namespace dpnp::kernels::choose