// Pointer to the input elements; set from the constructor's `inp` argument
// and read element-wise (in[k]) by operator().
72 const argT *in =
nullptr;
// Pointer to the first output; set from the constructor's `res1` argument.
// operator() stores the value returned by the op (or the first constant) here.
73 resT1 *out1 =
nullptr;
// Pointer to the second output; set from the constructor's `res2` argument.
// operator() passes out2[k] to the op as its second argument, or stores the
// second constant here.
74 resT2 *out2 =
nullptr;
81 const std::size_t n_elems)
82 : in(inp), out1(res1), out2(res2), nelems_(n_elems)
86 void operator()(sycl::nd_item<1> ndit)
const
88 static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
89 UnaryTwoOutputsOpT op{};
93 if constexpr (enable_sg_loadstore &&
94 UnaryTwoOutputsOpT::is_constant::value) {
96 constexpr resT1 const_val1 = UnaryTwoOutputsOpT::constant_value1;
97 constexpr resT2 const_val2 = UnaryTwoOutputsOpT::constant_value2;
99 auto sg = ndit.get_sub_group();
100 const std::uint16_t sgSize = sg.get_max_local_range()[0];
102 const std::size_t base =
103 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
104 sg.get_group_id()[0] * sgSize);
105 if (base + elems_per_wi * sgSize < nelems_) {
106 static constexpr sycl::vec<resT1, vec_sz> res1_vec(const_val1);
107 static constexpr sycl::vec<resT2, vec_sz> res2_vec(const_val2);
109 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
110 const std::size_t offset = base + it * sgSize;
111 auto out1_multi_ptr = sycl::address_space_cast<
112 sycl::access::address_space::global_space,
113 sycl::access::decorated::yes>(&out1[offset]);
114 auto out2_multi_ptr = sycl::address_space_cast<
115 sycl::access::address_space::global_space,
116 sycl::access::decorated::yes>(&out2[offset]);
118 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
119 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
123 const std::size_t lane_id = sg.get_local_id()[0];
124 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
125 out1[k] = const_val1;
126 out2[k] = const_val2;
130 else if constexpr (enable_sg_loadstore &&
131 UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
132 UnaryTwoOutputsOpT::supports_vec::value &&
135 auto sg = ndit.get_sub_group();
136 const std::uint16_t sgSize = sg.get_max_local_range()[0];
138 const std::size_t base =
139 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
140 sg.get_group_id()[0] * sgSize);
141 if (base + elems_per_wi * sgSize < nelems_) {
143 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
144 const std::size_t offset = base + it * sgSize;
145 auto in_multi_ptr = sycl::address_space_cast<
146 sycl::access::address_space::global_space,
147 sycl::access::decorated::yes>(&in[offset]);
148 auto out1_multi_ptr = sycl::address_space_cast<
149 sycl::access::address_space::global_space,
150 sycl::access::decorated::yes>(&out1[offset]);
151 auto out2_multi_ptr = sycl::address_space_cast<
152 sycl::access::address_space::global_space,
153 sycl::access::decorated::yes>(&out2[offset]);
155 const sycl::vec<argT, vec_sz> x =
156 sub_group_load<vec_sz>(sg, in_multi_ptr);
157 sycl::vec<resT2, vec_sz> res2_vec = {};
158 const sycl::vec<resT1, vec_sz> res1_vec = op(x, res2_vec);
159 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
160 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
164 const std::size_t lane_id = sg.get_local_id()[0];
165 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
167 out1[k] = op(in[k], out2[k]);
171 else if constexpr (enable_sg_loadstore &&
172 UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
173 std::is_same_v<resT1, argT>)
177 auto sg = ndit.get_sub_group();
178 const std::uint16_t sgSize = sg.get_max_local_range()[0];
179 const std::size_t base =
180 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
181 sg.get_group_id()[0] * sgSize);
183 if (base + elems_per_wi * sgSize < nelems_) {
185 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
186 const std::size_t offset = base + it * sgSize;
187 auto in_multi_ptr = sycl::address_space_cast<
188 sycl::access::address_space::global_space,
189 sycl::access::decorated::yes>(&in[offset]);
190 auto out1_multi_ptr = sycl::address_space_cast<
191 sycl::access::address_space::global_space,
192 sycl::access::decorated::yes>(&out1[offset]);
193 auto out2_multi_ptr = sycl::address_space_cast<
194 sycl::access::address_space::global_space,
195 sycl::access::decorated::yes>(&out2[offset]);
197 sycl::vec<argT, vec_sz> arg_vec =
198 sub_group_load<vec_sz>(sg, in_multi_ptr);
199 sycl::vec<resT2, vec_sz> res2_vec = {};
201 for (std::uint32_t k = 0; k < vec_sz; ++k) {
202 arg_vec[k] = op(arg_vec[k], res2_vec[k]);
204 sub_group_store<vec_sz>(sg, arg_vec, out1_multi_ptr);
205 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
209 const std::size_t lane_id = sg.get_local_id()[0];
210 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
211 out1[k] = op(in[k], out2[k]);
215 else if constexpr (enable_sg_loadstore &&
216 UnaryTwoOutputsOpT::supports_sg_loadstore::value)
220 auto sg = ndit.get_sub_group();
221 const std::uint16_t sgSize = sg.get_max_local_range()[0];
222 const std::size_t base =
223 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
224 sg.get_group_id()[0] * sgSize);
226 if (base + elems_per_wi * sgSize < nelems_) {
228 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
229 const std::size_t offset = base + it * sgSize;
230 auto in_multi_ptr = sycl::address_space_cast<
231 sycl::access::address_space::global_space,
232 sycl::access::decorated::yes>(&in[offset]);
233 auto out1_multi_ptr = sycl::address_space_cast<
234 sycl::access::address_space::global_space,
235 sycl::access::decorated::yes>(&out1[offset]);
236 auto out2_multi_ptr = sycl::address_space_cast<
237 sycl::access::address_space::global_space,
238 sycl::access::decorated::yes>(&out2[offset]);
240 const sycl::vec<argT, vec_sz> arg_vec =
241 sub_group_load<vec_sz>(sg, in_multi_ptr);
242 sycl::vec<resT1, vec_sz> res1_vec = {};
243 sycl::vec<resT2, vec_sz> res2_vec = {};
245 for (std::uint8_t k = 0; k < vec_sz; ++k) {
246 res1_vec[k] = op(arg_vec[k], res2_vec[k]);
248 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
249 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
253 const std::size_t lane_id = sg.get_local_id()[0];
254 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
255 out1[k] = op(in[k], out2[k]);
260 const std::uint16_t sgSize =
261 ndit.get_sub_group().get_local_range()[0];
262 const std::size_t gid = ndit.get_global_linear_id();
263 const std::uint16_t elems_per_sg = sgSize * elems_per_wi;
265 const std::size_t start =
266 (gid / sgSize) * (elems_per_sg - sgSize) + gid;
267 const std::size_t end = std::min(nelems_, start + elems_per_sg);
268 for (std::size_t offset = start; offset < end; offset += sgSize) {
269 out1[offset] = op(in[offset], out2[offset]);