DPNP C++ backend kernel library 0.20.0dev1
Data Parallel Extension for NumPy*
putmask_kernel.hpp
//*****************************************************************************
// Copyright (c) 2026, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
//   may be used to endorse or promote products derived from this software
//   without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <cstdint>
#include <type_traits>

#include <sycl/sycl.hpp>
// dpctl tensor headers
#include "kernels/alignment.hpp"
#include "kernels/dpctl_tensor_types.hpp"
#include "kernels/elementwise_functions/sycl_complex.hpp"
#include "utils/offset_utils.hpp"
#include "utils/sycl_utils.hpp"
#include "utils/type_utils.hpp"
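
// Contiguous putmask kernel: for every index i where mask[i] is non-zero,
// dst[i] is overwritten with values[i % val_size]; when val_size >= nelems,
// values is read element-for-element (no wrap-around).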
namespace dpnp::extensions::indexing::kernels
{
template <typename T,
          std::uint8_t vec_sz = 4u,
          std::uint8_t n_vecs = 2u,
          bool enable_sg_loadstore = true>
class PutMaskContigFunctor
{
private:
    T *dst_ = nullptr;
    const std::uint8_t *mask_u8_ = nullptr;
    const T *values_ = nullptr;
    std::size_t nelems_ = 0;
    std::size_t val_size_ = 0;

public:
    PutMaskContigFunctor(T *dst,
                         const bool *mask,
                         const T *values,
                         std::size_t nelems,
                         std::size_t val_size)
        : dst_(dst), mask_u8_(reinterpret_cast<const std::uint8_t *>(mask)),
          values_(values), nelems_(nelems), val_size_(val_size)
    {
    }

    void operator()(sycl::nd_item<1> ndit) const
    {
        if (val_size_ == 0 || nelems_ == 0) {
            return;
        }

        constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
        /* Each work-item processes vec_sz elements, contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */

        using dpctl::tensor::type_utils::is_complex_v;
        if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
            auto sg = ndit.get_sub_group();
            const std::uint32_t sgSize = sg.get_max_local_range()[0];
            const std::size_t lane_id = sg.get_local_id()[0];

            const std::size_t base =
                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
                                sg.get_group_id()[0] * sgSize);

            const bool values_no_repeat = (val_size_ >= nelems_);

            if (base + elems_per_wi * sgSize <= nelems_) {
                using dpctl::tensor::sycl_utils::sub_group_load;
                using dpctl::tensor::sycl_utils::sub_group_store;

#pragma unroll
                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
                    const std::size_t offset = base + it * sgSize;

                    auto dst_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&dst_[offset]);
                    auto mask_multi_ptr = sycl::address_space_cast<
                        sycl::access::address_space::global_space,
                        sycl::access::decorated::yes>(&mask_u8_[offset]);

                    const sycl::vec<T, vec_sz> dst_vec =
                        sub_group_load<vec_sz>(sg, dst_multi_ptr);
                    const sycl::vec<std::uint8_t, vec_sz> mask_vec =
                        sub_group_load<vec_sz>(sg, mask_multi_ptr);

                    sycl::vec<T, vec_sz> val_vec;

                    if (values_no_repeat) {
                        auto values_multi_ptr = sycl::address_space_cast<
                            sycl::access::address_space::global_space,
                            sycl::access::decorated::yes>(&values_[offset]);

                        val_vec = sub_group_load<vec_sz>(sg, values_multi_ptr);
                    }
                    else {
                        const std::size_t idx = offset + lane_id;
#pragma unroll
                        for (std::uint8_t k = 0; k < vec_sz; ++k) {
                            const std::size_t g =
                                idx + static_cast<std::size_t>(k) * sgSize;
                            val_vec[k] = values_[g % val_size_];
                        }
                    }

                    sycl::vec<T, vec_sz> out_vec;
#pragma unroll
                    for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
                        out_vec[vec_id] =
                            (mask_vec[vec_id] != static_cast<std::uint8_t>(0))
                                ? val_vec[vec_id]
                                : dst_vec[vec_id];
                    }

                    sub_group_store<vec_sz>(sg, out_vec, dst_multi_ptr);
                }
            }
            else {
                const std::size_t lane_id = sg.get_local_id()[0];
                for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    if (mask_u8_[k]) {
                        const std::size_t v =
                            values_no_repeat ? k : (k % val_size_);
                        dst_[k] = values_[v];
                    }
                }
            }
        }
        else {
            const std::size_t gid = ndit.get_global_linear_id();
            const std::size_t gws = ndit.get_global_range(0);

            const bool values_no_repeat = (val_size_ >= nelems_);
            for (std::size_t offset = gid; offset < nelems_; offset += gws) {
                if (mask_u8_[offset]) {
                    const std::size_t v =
                        values_no_repeat ? offset : (offset % val_size_);
                    dst_[offset] = values_[v];
                }
            }
        }
    }
};
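
// Worked example of the vectorized path above, using the default vec_sz = 4
// and n_vecs = 2 (so elems_per_wi = 8) and assuming a sub-group size of 32:
// each sub-group covers elems_per_wi * sgSize = 256 contiguous elements
// starting at base = 8 * (group_id * work_group_size + sg_id * 32), and each
// lane touches the eight elements base + lane_id + m * 32 for m = 0..7.
// Sub-groups whose 256-element chunk would run past nelems_ fall back to the
// per-element loop, as does the whole kernel when sub-group load/store is
// disabled or T is complex.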

template <typename T, std::uint8_t vec_sz = 4u, std::uint8_t n_vecs = 2u>
sycl::event putmask_contig_impl(sycl::queue &exec_q,
                                std::size_t nelems,
                                char *dst_cp,
                                const char *mask_cp,
                                const char *values_cp,
                                std::size_t values_size,
                                const std::vector<sycl::event> &depends = {})
{
    T *dst_tp = reinterpret_cast<T *>(dst_cp);
    const bool *mask_tp = reinterpret_cast<const bool *>(mask_cp);
    const T *values_tp = reinterpret_cast<const T *>(values_cp);

    constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
    // const std::size_t n_work_items_needed = (nelems + elems_per_wi - 1) /
    // elems_per_wi;
    const std::size_t n_work_items_needed = nelems / elems_per_wi;
    const std::size_t empirical_threshold = std::size_t(1) << 21;
    const std::size_t lws = (n_work_items_needed <= empirical_threshold)
                                ? std::size_t(128)
                                : std::size_t(256);

    const std::size_t n_groups =
        ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
    const auto gws_range = sycl::range<1>(n_groups * lws);
    const auto lws_range = sycl::range<1>(lws);

    using dpctl::tensor::kernels::alignment_utils::is_aligned;
    using dpctl::tensor::kernels::alignment_utils::required_alignment;

    const bool aligned = is_aligned<required_alignment>(dst_tp) &&
                         is_aligned<required_alignment>(mask_tp) &&
                         is_aligned<required_alignment>(values_tp);

    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        if (aligned) {
            constexpr bool enable_sg = true;
            using PutMaskFunc =
                PutMaskContigFunctor<T, vec_sz, n_vecs, enable_sg>;

            cgh.parallel_for<PutMaskFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                PutMaskFunc(dst_tp, mask_tp, values_tp, nelems, values_size));
        }
        else {
            constexpr bool enable_sg = false;
            using PutMaskFunc =
                PutMaskContigFunctor<T, vec_sz, n_vecs, enable_sg>;

            cgh.parallel_for<PutMaskFunc>(
                sycl::nd_range<1>(gws_range, lws_range),
                PutMaskFunc(dst_tp, mask_tp, values_tp, nelems, values_size));
        }
    });

    return comp_ev;
}
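
// Launch-configuration arithmetic, for concreteness: with the defaults
// (elems_per_wi = 8), a work-group of 128 work-items covers 1024 elements, so
// nelems = 1'000'000 yields n_groups = ceil(1'000'000 / 1024) = 977 and a
// global range of 977 * 128 = 125'056 work-items. The work-group size switches
// from 128 to 256 only once nelems / 8 exceeds the empirical threshold of
// 2^21, i.e. for arrays of more than about 2^24 (~16.8 million) elements.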

} // namespace dpnp::extensions::indexing::kernels
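
Below is a minimal host-side usage sketch, not part of putmask_kernel.hpp: it assumes the header is on the include path, USM-shared allocations, a float destination, and the default vec_sz/n_vecs; the queue setup, sizes, and fill values are illustrative only.

// Illustrative caller for putmask_contig_impl (assumptions: USM-shared
// allocations, float data, default template parameters).
#include <cstddef>
#include <sycl/sycl.hpp>

#include "putmask_kernel.hpp"

int main()
{
    namespace krn = dpnp::extensions::indexing::kernels;

    sycl::queue q;

    constexpr std::size_t n = 1024;   // destination length
    constexpr std::size_t n_vals = 3; // values are reused via i % n_vals

    float *dst = sycl::malloc_shared<float>(n, q);
    bool *mask = sycl::malloc_shared<bool>(n, q);
    float *vals = sycl::malloc_shared<float>(n_vals, q);

    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = 0.0f;
        mask[i] = (i % 2 == 0); // overwrite every other element
    }
    vals[0] = 1.0f;
    vals[1] = 2.0f;
    vals[2] = 3.0f;

    sycl::event ev = krn::putmask_contig_impl<float>(
        q, n, reinterpret_cast<char *>(dst),
        reinterpret_cast<const char *>(mask),
        reinterpret_cast<const char *>(vals), n_vals);
    ev.wait();

    sycl::free(dst, q);
    sycl::free(mask, q);
    sycl::free(vals, q);
    return 0;
}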