// NOTE(review): class-member fragment — the enclosing functor's name and its
// dst_ member are declared above this view.
// Byte view of the mask buffer; a nonzero byte selects the corresponding
// dst_ element for overwrite (see operator()).
81 const std::uint8_t *mask_u8_ =
nullptr;
// Source values written into dst_ wherever the mask byte is nonzero.
82 const T *values_ =
nullptr;
// Total number of elements spanned by dst_ / mask_u8_.
83 std::size_t nelems_ = 0;
// Element count of values_; when smaller than nelems_, values_ is indexed
// modulo val_size_ in operator() (i.e. the values repeat cyclically).
84 std::size_t val_size_ = 0;
// NOTE(review): constructor fragment — the parameter list sits above this
// view; it forwards raw device pointers and the two sizes into the functor.
// The mask pointer is reinterpreted as bytes so membership is tested via a
// plain nonzero-byte check regardless of the mask's declared element type
// (assumes a boolean-like mask layout — TODO confirm at the call site).
92 : dst_(dst), mask_u8_(
reinterpret_cast<const std::uint8_t *
>(mask)),
93 values_(values), nelems_(nelems), val_size_(val_size)
// Kernel body: for each element i with a nonzero mask byte, write
// dst_[i] = values_[i] (when values_ covers the whole range) or
// values_[i % val_size_] (cyclic repeat) — see the scalar fallback at the
// bottom, which states this contract most directly.
// NOTE(review): this extract is gapped — several original lines (else
// branches, closing braces) are not visible here; comments below hedge
// wherever the missing text matters.
97 void operator()(sycl::nd_item<1> ndit)
const
// Degenerate sizes: no values to read or no elements to write.
99 if (val_size_ == 0 || nelems_ == 0) {
// Each work-item covers n_vecs vectors of vec_sz elements.
103 constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
107 using dpctl::tensor::type_utils::is_complex_v;
// Fast path: sub-group vectorized loads/stores, compile-time gated off for
// complex T (presumably sub_group_load/store do not handle complex — the
// helpers are defined elsewhere; confirm in sycl_utils).
108 if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
109 auto sg = ndit.get_sub_group();
110 const std::uint32_t sgSize = sg.get_max_local_range()[0];
111 const std::size_t lane_id = sg.get_local_id()[0];
// First element index handled by this sub-group: groups are tiled by
// work-group, then by sub-group within the work-group.
113 const std::size_t base =
114 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
115 sg.get_group_id()[0] * sgSize);
// values_ long enough to index directly — no modulo needed per element.
117 const bool values_no_repeat = (val_size_ >= nelems_);
// Full-tile branch: the whole sub-group span fits inside nelems_, so
// vector loads/stores need no bounds checks.
119 if (base + elems_per_wi * sgSize <= nelems_) {
120 using dpctl::tensor::sycl_utils::sub_group_load;
121 using dpctl::tensor::sycl_utils::sub_group_store;
124 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
125 const std::size_t offset = base + it * sgSize;
// Decorated global-space multi_ptrs required by sub_group_load/store.
127 auto dst_multi_ptr = sycl::address_space_cast<
128 sycl::access::address_space::global_space,
129 sycl::access::decorated::yes>(&dst_[offset]);
130 auto mask_multi_ptr = sycl::address_space_cast<
131 sycl::access::address_space::global_space,
132 sycl::access::decorated::yes>(&mask_u8_[offset]);
// Load current destination and mask vectors cooperatively.
134 const sycl::vec<T, vec_sz> dst_vec =
135 sub_group_load<vec_sz>(sg, dst_multi_ptr);
136 const sycl::vec<std::uint8_t, vec_sz> mask_vec =
137 sub_group_load<vec_sz>(sg, mask_multi_ptr);
139 sycl::vec<T, vec_sz> val_vec;
// No-repeat case: values_ is laid out like dst_, so a contiguous
// sub-group load fetches the right elements.
141 if (values_no_repeat) {
142 auto values_multi_ptr = sycl::address_space_cast<
143 sycl::access::address_space::global_space,
144 sycl::access::decorated::yes>(&values_[offset]);
146 val_vec = sub_group_load<vec_sz>(sg, values_multi_ptr);
// Repeat case (the else of the branch above — the "else {" line is not
// visible in this extract): gather each lane's elements with a per-element
// modulo into the shorter values_ buffer.
149 const std::size_t idx = offset + lane_id;
151 for (std::uint8_t k = 0; k < vec_sz; ++k) {
152 const std::size_t g =
153 idx +
static_cast<std::size_t
>(k) * sgSize;
154 val_vec[k] = values_[g % val_size_];
// Per-element select: masked lanes take the new value, unmasked lanes
// keep the old dst value.  NOTE(review): the assignment's right-hand side
// (presumably `? val_vec[...] : dst_vec[...]`) falls on lines missing
// from this extract — confirm against the full file.
158 sycl::vec<T, vec_sz> out_vec;
160 for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
162 (mask_vec[vec_id] !=
static_cast<std::uint8_t
>(0))
167 sub_group_store<vec_sz>(sg, out_vec, dst_multi_ptr);
// Tail branch (the else of the full-tile check): scalar, lane-strided loop
// over the remaining elements.  NOTE(review): no mask test is visible in
// this loop in the extract — the original line between 172 and 174 likely
// carries `if (mask_u8_[k])`; verify before relying on this path.
171 const std::size_t lane_id = sg.get_local_id()[0];
172 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
174 const std::size_t v =
175 values_no_repeat ? k : (k % val_size_);
176 dst_[k] = values_[v];
// Generic fallback (sub-group load/store disabled or complex T): plain
// scalar loop striding by the global range, masked per element.
182 const std::size_t gid = ndit.get_global_linear_id();
183 const std::size_t gws = ndit.get_global_range(0);
185 const bool values_no_repeat = (val_size_ >= nelems_);
186 for (std::size_t offset = gid; offset < nelems_; offset += gws) {
187 if (mask_u8_[offset]) {
188 const std::size_t v =
189 values_no_repeat ? offset : (offset % val_size_);
190 dst_[offset] = values_[v];