// Input array. Declared pointer-to-const; the kernel body only reads from it
// (via in[...] and sub-group loads).
73 const argT *in =
nullptr;
// First output array. Receives the value returned by the operation.
74 resT1 *out1 =
nullptr;
// Second output array. Its elements are handed to the operation as the second
// argument (scalar paths) or stored from a second result vector (vector paths).
75 resT2 *out2 =
nullptr;
82 const std::size_t n_elems)
83 : in(inp), out1(res1), out2(res2), nelems_(n_elems)
// Kernel body invoked with a 1-D sycl::nd_item. Applies the two-output unary
// operation `op` to elements of `in`: the value returned by `op` goes to
// `out1`, and the second output goes to `out2`. No element at index >= nelems_
// is touched by the scalar loops; the vectorized paths are guarded by an
// explicit bounds check.
//
// One of five code paths is chosen at compile time with `if constexpr`:
//   1. the op's results are compile-time constants,
//   2. the op can be evaluated on whole sycl::vec values,
//   3. sub-group load/store with resT1 == argT (input vector reused in place),
//   4. sub-group load/store with separate result vectors,
//   5. a generic path with plain indexed loads and stores.
//
// NOTE(review): brace-only lines, `else` lines and at least one continuation
// line are not visible in this listing (the embedded line numbers skip), so
// the branch boundaries described below are inferred from the `if`/`else if`
// keywords and the repeated structure -- confirm against the full file.
87 void operator()(sycl::nd_item<1> ndit)
const
// Product of n_vecs and vec_sz; used below as the per-work-item element count
// when computing each sub-group's base offset and span.
// NOTE(review): held in std::uint8_t, so n_vecs * vec_sz is assumed to fit in
// 8 bits -- confirm at instantiation sites.
89 static constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
90 UnaryTwoOutputsOpT op{};
// ---- Path 1: both results are constants declared by the op type. ----
// `op` is never called here and `in` is never read.
94 if constexpr (enable_sg_loadstore &&
95 UnaryTwoOutputsOpT::is_constant::value) {
97 constexpr resT1 const_val1 = UnaryTwoOutputsOpT::constant_value1;
98 constexpr resT2 const_val2 = UnaryTwoOutputsOpT::constant_value2;
100 auto sg = ndit.get_sub_group();
// The sub-group's maximum local range is used as the sub-group size.
101 const std::uint16_t sgSize = sg.get_max_local_range()[0];
// Index of the first element owned by this sub-group: elems_per_wi times the
// count of work-items in preceding work-groups plus those in preceding
// sub-groups of this work-group (each sub-group counted as sgSize items).
103 const std::size_t base =
104 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
105 sg.get_group_id()[0] * sgSize);
// Vectorized stores are used only when the sub-group's whole span of
// elems_per_wi * sgSize elements ends strictly before nelems_. The comparison
// is strict, so a span ending exactly at nelems_ is handled by the scalar
// loop further down instead.
106 if (base + elems_per_wi * sgSize < nelems_) {
// Vectors whose every component equals the respective constant.
107 static constexpr sycl::vec<resT1, vec_sz> res1_vec(const_val1);
108 static constexpr sycl::vec<resT2, vec_sz> res2_vec(const_val2);
// `it` advances by vec_sz, so `offset` advances by vec_sz * sgSize elements
// per iteration.
110 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
111 const std::size_t offset = base + it * sgSize;
// Wrap the raw output addresses as decorated global-address-space pointers,
// the form passed to sub_group_store below.
112 auto out1_multi_ptr = sycl::address_space_cast<
113 sycl::access::address_space::global_space,
114 sycl::access::decorated::yes>(&out1[offset]);
115 auto out2_multi_ptr = sycl::address_space_cast<
116 sycl::access::address_space::global_space,
117 sycl::access::decorated::yes>(&out2[offset]);
// sub_group_store is defined outside this listing; presumably it writes
// vec_sz elements per work-item starting at the given pointer -- confirm.
119 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
120 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
// Scalar tail (presumably the `else` of the bounds check above): each lane
// starts at base + its lane id and steps by sgSize until nelems_ is reached.
124 const std::size_t lane_id = sg.get_local_id()[0];
125 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
126 out1[k] = const_val1;
127 out2[k] = const_val2;
// ---- Path 2: op can be applied to whole sycl::vec values. ----
// NOTE(review): the condition ends with `&&`; its final operand is on a line
// not visible in this listing.
131 else if constexpr (enable_sg_loadstore &&
132 UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
133 UnaryTwoOutputsOpT::supports_vec::value &&
136 auto sg = ndit.get_sub_group();
137 const std::uint16_t sgSize = sg.get_max_local_range()[0];
// Same base-offset computation and strict bounds check as in path 1.
139 const std::size_t base =
140 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
141 sg.get_group_id()[0] * sgSize);
142 if (base + elems_per_wi * sgSize < nelems_) {
144 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
145 const std::size_t offset = base + it * sgSize;
146 auto in_multi_ptr = sycl::address_space_cast<
147 sycl::access::address_space::global_space,
148 sycl::access::decorated::yes>(&in[offset]);
149 auto out1_multi_ptr = sycl::address_space_cast<
150 sycl::access::address_space::global_space,
151 sycl::access::decorated::yes>(&out1[offset]);
152 auto out2_multi_ptr = sycl::address_space_cast<
153 sycl::access::address_space::global_space,
154 sycl::access::decorated::yes>(&out2[offset]);
// Load a vector of inputs, then evaluate op once on the whole vector. op
// returns the first-result vector; res2_vec is value-initialized and passed
// as the second argument, presumably by reference so that op fills it in --
// confirm against the op's signature.
156 const sycl::vec<argT, vec_sz> x =
157 sub_group_load<vec_sz>(sg, in_multi_ptr);
158 sycl::vec<resT2, vec_sz> res2_vec = {};
159 const sycl::vec<resT1, vec_sz> res1_vec = op(x, res2_vec);
160 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
161 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
// Scalar tail: per-element call of op, with out2[k] passed directly as the
// second argument and the return value stored to out1[k].
165 const std::size_t lane_id = sg.get_local_id()[0];
166 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
168 out1[k] = op(in[k], out2[k]);
// ---- Path 3: sub-group load/store, first result type equals input type. ----
// Because resT1 is the same type as argT, the loaded input vector doubles as
// the buffer for the first result.
172 else if constexpr (enable_sg_loadstore &&
173 UnaryTwoOutputsOpT::supports_sg_loadstore::value &&
174 std::is_same_v<resT1, argT>)
178 auto sg = ndit.get_sub_group();
179 const std::uint16_t sgSize = sg.get_max_local_range()[0];
180 const std::size_t base =
181 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
182 sg.get_group_id()[0] * sgSize);
184 if (base + elems_per_wi * sgSize < nelems_) {
186 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
187 const std::size_t offset = base + it * sgSize;
188 auto in_multi_ptr = sycl::address_space_cast<
189 sycl::access::address_space::global_space,
190 sycl::access::decorated::yes>(&in[offset]);
191 auto out1_multi_ptr = sycl::address_space_cast<
192 sycl::access::address_space::global_space,
193 sycl::access::decorated::yes>(&out1[offset]);
194 auto out2_multi_ptr = sycl::address_space_cast<
195 sycl::access::address_space::global_space,
196 sycl::access::decorated::yes>(&out2[offset]);
198 sycl::vec<argT, vec_sz> arg_vec =
199 sub_group_load<vec_sz>(sg, in_multi_ptr);
200 sycl::vec<resT2, vec_sz> res2_vec = {};
// Apply op component by component, overwriting each input component with its
// first result; the matching component of res2_vec is passed as the second
// argument.
202 for (std::uint32_t k = 0; k < vec_sz; ++k) {
203 arg_vec[k] = op(arg_vec[k], res2_vec[k]);
205 sub_group_store<vec_sz>(sg, arg_vec, out1_multi_ptr);
206 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
// Scalar tail, identical in form to the one in path 2.
210 const std::size_t lane_id = sg.get_local_id()[0];
211 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
212 out1[k] = op(in[k], out2[k]);
// ---- Path 4: sub-group load/store with a separate first-result vector. ----
// Reached when the op supports sub-group load/store but neither earlier
// condition held; first results are collected in res1_vec rather than in
// the input vector.
216 else if constexpr (enable_sg_loadstore &&
217 UnaryTwoOutputsOpT::supports_sg_loadstore::value)
221 auto sg = ndit.get_sub_group();
222 const std::uint16_t sgSize = sg.get_max_local_range()[0];
223 const std::size_t base =
224 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
225 sg.get_group_id()[0] * sgSize);
227 if (base + elems_per_wi * sgSize < nelems_) {
229 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
230 const std::size_t offset = base + it * sgSize;
231 auto in_multi_ptr = sycl::address_space_cast<
232 sycl::access::address_space::global_space,
233 sycl::access::decorated::yes>(&in[offset]);
234 auto out1_multi_ptr = sycl::address_space_cast<
235 sycl::access::address_space::global_space,
236 sycl::access::decorated::yes>(&out1[offset]);
237 auto out2_multi_ptr = sycl::address_space_cast<
238 sycl::access::address_space::global_space,
239 sycl::access::decorated::yes>(&out2[offset]);
241 const sycl::vec<argT, vec_sz> arg_vec =
242 sub_group_load<vec_sz>(sg, in_multi_ptr);
243 sycl::vec<resT1, vec_sz> res1_vec = {};
244 sycl::vec<resT2, vec_sz> res2_vec = {};
// Component-wise evaluation into the two result vectors.
246 for (std::uint8_t k = 0; k < vec_sz; ++k) {
247 res1_vec[k] = op(arg_vec[k], res2_vec[k]);
249 sub_group_store<vec_sz>(sg, res1_vec, out1_multi_ptr);
250 sub_group_store<vec_sz>(sg, res2_vec, out2_multi_ptr);
// Scalar tail, identical in form to the one in path 2.
254 const std::size_t lane_id = sg.get_local_id()[0];
255 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
256 out1[k] = op(in[k], out2[k]);
// ---- Path 5: generic path (presumably the final `else`). ----
// No sub_group_load / sub_group_store helpers are used. Note that the
// sub-group size here comes from get_local_range(), not get_max_local_range()
// as in the paths above.
261 const std::uint16_t sgSize =
262 ndit.get_sub_group().get_local_range()[0];
263 const std::size_t gid = ndit.get_global_linear_id();
264 const std::uint16_t elems_per_sg = sgSize * elems_per_wi;
// With integer division this equals
//   (gid / sgSize) * elems_per_sg + (gid % sgSize),
// i.e. blocks of elems_per_sg elements are indexed by gid / sgSize and the
// work-item starts at position gid % sgSize within its block.
266 const std::size_t start =
267 (gid / sgSize) * (elems_per_sg - sgSize) + gid;
// Clamp to nelems_ so the final, partially filled block is not overrun.
268 const std::size_t end = std::min(nelems_, start + elems_per_sg);
// Step by sgSize: at most elems_per_wi elements per work-item.
269 for (std::size_t offset = start; offset < end; offset += sgSize) {
270 out1[offset] = op(in[offset], out2[offset]);