DPNP C++ backend kernel library 0.20.0dev4
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
choose.hpp
1//*****************************************************************************
2// Copyright (c) 2024, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12// - Neither the name of the copyright holder nor the names of its contributors
13// may be used to endorse or promote products derived from this software
14// without specific prior written permission.
15//
16// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26// THE POSSIBILITY OF SUCH DAMAGE.
27//*****************************************************************************
28
29#pragma once
30
31#include <sycl/sycl.hpp>
32
33#include "kernels/dpctl_tensor_types.hpp"
34#include "utils/strided_iters.hpp"
35
36namespace dpnp::kernels::choose
37{
38using dpctl::tensor::ssize_t;
39
40template <typename ProjectorT,
41 typename IndOutIndexerT,
42 typename ChoicesIndexerT,
43 typename IndT,
44 typename T>
46{
47private:
48 const IndT *ind = nullptr;
49 T *dst = nullptr;
50 char **chcs = nullptr;
51 ssize_t n_chcs;
52 const IndOutIndexerT ind_out_indexer;
53 const ChoicesIndexerT chcs_indexer;
54
55public:
56 ChooseFunctor(const IndT *ind_,
57 T *dst_,
58 char **chcs_,
59 ssize_t n_chcs_,
60 const IndOutIndexerT &ind_out_indexer_,
61 const ChoicesIndexerT &chcs_indexer_)
62 : ind(ind_), dst(dst_), chcs(chcs_), n_chcs(n_chcs_),
63 ind_out_indexer(ind_out_indexer_), chcs_indexer(chcs_indexer_)
64 {
65 }
66
67 void operator()(sycl::id<1> id) const
68 {
69 const ProjectorT proj{};
70
71 ssize_t i = id[0];
72
73 auto ind_dst_offsets = ind_out_indexer(i);
74 ssize_t ind_offset = ind_dst_offsets.get_first_offset();
75 ssize_t dst_offset = ind_dst_offsets.get_second_offset();
76
77 IndT chc_idx = ind[ind_offset];
78 // proj produces an index in the range of n_chcs
79 ssize_t projected_idx = proj(n_chcs, chc_idx);
80
81 ssize_t chc_offset = chcs_indexer(i, projected_idx);
82
83 T *chc = reinterpret_cast<T *>(chcs[projected_idx]);
84
85 dst[dst_offset] = chc[chc_offset];
86 }
87};
88
89namespace strides
90{
91using dpctl::tensor::strides::CIndexer_vector;
92
94{
95 NthStrideOffsetUnpacked(int common_nd,
96 ssize_t const *_offsets,
97 ssize_t const *_shape,
98 ssize_t const *_strides)
99 : _ind(common_nd), nd(common_nd), offsets(_offsets), shape(_shape),
100 strides(_strides)
101 {
102 }
103
104 template <typename nT>
105 size_t operator()(ssize_t gid, nT n) const
106 {
107 ssize_t relative_offset(0);
108 _ind.get_displacement<const ssize_t *, const ssize_t *>(
109 gid, shape, strides + (n * nd), relative_offset);
110
111 return relative_offset + offsets[n];
112 }
113
114private:
115 CIndexer_vector<ssize_t> _ind;
116
117 int nd;
118 ssize_t const *offsets;
119 ssize_t const *shape;
120 ssize_t const *strides;
121};
122
123static_assert(sycl::is_device_copyable_v<NthStrideOffsetUnpacked>);
124
125} // namespace strides
126} // namespace dpnp::kernels::choose