// Member type aliases of the kind consumed through std::iterator_traits.
// The enclosing class header is not visible in this excerpt.
52 using value_type = _Tp;
53 using difference_type = std::ptrdiff_t;
// The iterator declares itself as random-access.
54 using iterator_category = std::random_access_iterator_tag;
55 using pointer = value_type *;
56 using reference = value_type &;
// Index/extent type; shape_elem_type is declared outside this excerpt.
57 using size_type = shape_elem_type;
// Trailing constructor parameters (the opening of the signature is not
// visible in this excerpt). Both stride pointers default to nullptr and the
// stride count defaults to 0.
61 const size_type *__shape_stride =
nullptr,
62 const size_type *__axes_stride =
nullptr,
63 size_type __shape_size = 0)
// Member-initializer list: the stride pointers are stored as given; the
// arrays they point to are not copied on these lines.
64 : base(__base_ptr), iter_id(__id), iteration_shape_size(__shape_size),
65 iteration_shape_strides(__shape_stride),
66 axes_shape_strides(__axes_stride)
// Dereference operator: const-qualified and yields `reference`.
// The body is not visible in this excerpt.
72 inline reference operator*()
const
// Member-access operator: const-qualified and yields `pointer`.
// The body is not visible in this excerpt.
77 inline pointer operator->()
const
// Comparison bodies; the operator signatures are not visible in this excerpt.
// Equality (presumably operator==): the assert checks that both iterators
// share the same base pointer, then the positions (iter_id) are compared.
101 assert(base == __rhs.base);
103 return (iter_id == __rhs.iter_id);
// Inequality is expressed as the negation of the equality comparison.
108 return !(*
this == __rhs);
// Ordering compares positions (iter_id) only.
113 return iter_id < __rhs.iter_id;
// Subscript operator: const-qualified, takes an index of size_type and
// yields `reference`. The body is not visible in this excerpt.
119 inline reference operator[](size_type __n)
const
// Distance between two iterator positions (enclosing signature not visible;
// presumably operator-). Both ids are converted to difference_type before
// subtracting, so the result is a signed value.
126 difference_type diff =
127 difference_type(iter_id) - difference_type(__rhs.iter_id);
// Fragment of a stream-output routine for the iterator (signature not
// visible in this excerpt). The two stride arrays are copied into temporary
// vectors of iteration_shape_size elements so they can be streamed whole.
// Streaming a std::vector relies on an operator<< overload that is defined
// outside this excerpt.
136 const std::vector<size_type> it_strides(__it.iteration_shape_strides,
137 __it.iteration_shape_strides +
138 __it.iteration_shape_size);
139 const std::vector<size_type> it_axes_strides(
140 __it.axes_shape_strides,
141 __it.axes_shape_strides + __it.iteration_shape_size);
// Emit each field as ", name=value" after the opening tag.
143 __out <<
"DPNP_USM_iterator(base=" << __it.base;
144 __out <<
", iter_id=" << __it.iter_id;
145 __out <<
", iteration_shape_size=" << __it.iteration_shape_size;
146 __out <<
", iteration_shape_strides=" << it_strides;
147 __out <<
", axes_shape_strides=" << it_axes_strides;
// Zero-argument pointer accessor: const-qualified and yields `pointer`.
// The body is not visible in this excerpt.
154 inline pointer ptr()
const
// Returns `base` advanced by an element offset derived from `iteration_id`.
// When iteration_shape_size > 0 the id is decomposed one axis at a time:
// it is divided by iteration_shape_strides[it], the quotient is scaled by
// axes_shape_strides[it] and added to the offset, and the remainder carries
// over to the next axis.
159 inline pointer ptr(size_type iteration_id)
const
161 size_type offset = 0;
163 if (iteration_shape_size > 0) {
// Part of the id that is still to be decomposed ("reminder" is presumably a
// misspelling of "remainder").
// NOTE(review): declared `long` while iteration_id is size_type; confirm the
// two types have the same width on every supported target.
164 long reminder = iteration_id;
165 for (
size_t it = 0; it < static_cast<size_t>(iteration_shape_size);
167 const size_type axis_val = iteration_shape_strides[it];
// Coordinate along the current axis.
168 size_type xyz_id = reminder / axis_val;
169 offset += (xyz_id * axes_shape_strides[it]);
171 reminder = reminder % axis_val;
// Presumably the fallback branch when no per-axis strides are set (the
// `else` line is not visible in this excerpt): the id is the offset itself.
175 offset = iteration_id;
178 return base + offset;
// Iterator state. `base` is a const pointer to the first element; the two
// stride members point to const size_type arrays and default to nullptr.
// Some initializer lines (and the declaration of iter_id) are not visible in
// this excerpt.
181 const pointer base =
nullptr;
184 const size_type iteration_shape_size =
187 const size_type *iteration_shape_strides =
nullptr;
188 const size_type *axes_shape_strides =
nullptr;
// Member type aliases of a second class (its header is not visible in this
// excerpt). size_type is the index/extent type used for shapes and strides.
203 using value_type = _Tp;
205 using pointer = value_type *;
206 using reference = value_type &;
207 using size_type = shape_elem_type;
// Tail of a constructor taking a raw shape array (the opening of the
// signature is not visible in this excerpt). The first __shape_size extents
// are copied into a vector and forwarded to init_container.
211 const size_type *__shape,
212 const size_type __shape_size)
215 std::vector<size_type> shape(__shape, __shape + __shape_size);
216 init_container(__ptr, shape);
// Tail of a constructor taking raw shape and stride arrays (the opening of
// the signature is not visible in this excerpt). Both arrays are read for
// __ndim elements, copied into vectors and forwarded to the three-argument
// init_container.
221 const size_type *__shape,
222 const size_type *__strides,
223 const size_type __ndim)
226 std::vector<size_type> shape(__shape, __shape + __ndim);
227 std::vector<size_type> strides(__strides, __strides + __ndim);
228 init_container(__ptr, shape, strides);
// Tail of a constructor taking the shape as a vector (the opening of the
// signature is not visible in this excerpt); forwards to init_container.
248 const std::vector<size_type> &__shape)
251 init_container(__ptr, __shape);
// Tail of a constructor taking shape and strides as vectors (the opening of
// the signature is not visible in this excerpt); forwards to the
// three-argument init_container.
270 const std::vector<size_type> &__shape,
271 const std::vector<size_type> &__strides)
273 init_container(__ptr, __shape, __strides);
// Convenience overload: copies the first __shape_size extents from the raw
// array into a vector and delegates to the vector overload.
289 inline void broadcast_to_shape(
const size_type *__shape,
290 const size_type __shape_size)
292 std::vector<size_type> shape(__shape, __shape + __shape_size);
293 broadcast_to_shape(shape);
// Body of the vector overload of broadcast_to_shape (signature not visible
// in this excerpt). Nothing below runs unless broadcastable(), a helper
// defined outside this excerpt, accepts the input shape and the target.
312 if (broadcastable(input_shape, input_shape_size, __shape)) {
// Drop any previous broadcast/output state before rebuilding it.
313 free_broadcast_axes_memory();
314 free_output_memory();
316 std::vector<size_type> valid_axes;
317 broadcast_use =
true;
319 output_shape_size = __shape.size();
320 const size_type output_shape_size_in_bytes =
321 output_shape_size *
sizeof(size_type);
// The allocation call that is cast here is not visible in this excerpt.
322 output_shape =
reinterpret_cast<size_type *
>(
// Walk both shapes from the last dimension backwards. The requested extent
// is copied to output_shape; an output axis is recorded as a broadcast axis
// when the input has no matching dimension (irit < 0) or the extents differ.
325 for (
int irit = input_shape_size - 1, orit = output_shape_size - 1;
326 orit >= 0; --irit, --orit)
328 output_shape[orit] = __shape[orit];
332 if (irit < 0 || input_shape[irit] != output_shape[orit]) {
// Inserting at the front while orit decreases leaves valid_axes ascending.
333 valid_axes.insert(valid_axes.begin(), orit);
337 broadcast_axes_size = valid_axes.size();
338 const size_type broadcast_axes_size_in_bytes =
339 broadcast_axes_size *
sizeof(size_type);
340 broadcast_axes =
reinterpret_cast<size_type *
>(
342 std::copy(valid_axes.begin(), valid_axes.end(), broadcast_axes);
// Product of all output extents; the assignment target is not visible in
// this excerpt (presumably output_size).
345 std::accumulate(output_shape, output_shape + output_shape_size,
346 size_type(1), std::multiplies<size_type>());
348 output_shape_strides =
reinterpret_cast<size_type *
>(
// Arguments of a call whose name is not visible in this excerpt; presumably
// it fills output_shape_strides from output_shape.
351 output_shape, output_shape_size, output_shape_strides);
// Overload taking a raw axes array: the first axes_ndim entries are copied
// into a vector. The remainder of the body is not visible in this excerpt
// (presumably it forwards axes_vec to the vector overload).
379 inline void set_axes(
const shape_elem_type *__axes,
const size_t axes_ndim)
381 const std::vector<shape_elem_type> axes_vec(__axes, __axes + axes_ndim);
// Selects the axes to iterate over and rebuilds the output and iteration
// descriptors: the output shape keeps the input dimensions that are not in
// `axes`, and the iteration shape is made of the dimensions that are.
402 inline void set_axes(
const std::vector<shape_elem_type> &__axes)
// Nothing is done for an empty axes list or a zero-dimensional input.
408 if (!__axes.empty() && input_shape_size) {
410 free_iteration_memory();
411 free_output_memory();
// get_validated_axes is defined outside this excerpt.
413 axes = get_validated_axes(__axes, input_shape_size);
416 output_shape_size = input_shape_size - axes.size();
417 const size_type output_shape_size_in_bytes =
418 output_shape_size *
sizeof(size_type);
420 iteration_shape_size = axes.size();
421 const size_type iteration_shape_size_in_bytes =
422 iteration_shape_size *
sizeof(size_type);
423 std::vector<size_type> iteration_shape;
// The allocation call that is cast here is not visible in this excerpt.
425 output_shape =
reinterpret_cast<size_type *
>(
// Copy the extents of every input dimension that is not a selected axis.
427 size_type *output_shape_it = output_shape;
428 for (size_type i = 0; i < input_shape_size; ++i) {
429 if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
430 *output_shape_it = input_shape[i];
// Product of all output extents; the assignment target is not visible in
// this excerpt (presumably output_size).
436 std::accumulate(output_shape, output_shape + output_shape_size,
437 size_type(1), std::multiplies<size_type>());
439 output_shape_strides =
reinterpret_cast<size_type *
>(
442 output_shape, output_shape_size, output_shape_strides);
// Collect the extents of the selected axes and multiply them into
// iteration_size.
// NOTE(review): free_iteration_memory() sets iteration_size to size_type{};
// presumably it is reset to 1 on a line not visible here -- confirm.
445 iteration_shape.reserve(iteration_shape_size);
446 for (
const auto &axis : axes) {
447 const size_type axis_dim = input_shape[axis];
448 iteration_shape.push_back(axis_dim);
449 iteration_size *= axis_dim;
452 iteration_shape_strides =
reinterpret_cast<size_type *
>(
455 iteration_shape.size(),
456 iteration_shape_strides);
458 axes_shape_strides =
reinterpret_cast<size_type *
>(
// For each selected axis, record the input stride of that axis.
460 for (
size_t i = 0; i < static_cast<size_t>(iteration_shape_size);
462 axes_shape_strides[i] = input_shape_strides[axes[i]];
// Return statement of a begin-style accessor (signature not visible in this
// excerpt): an iterator at position 0 whose base is `data` shifted by the
// input offset computed for output_global_id.
470 return iterator(data + get_input_begin_offset(output_global_id), 0,
471 iteration_shape_strides, axes_shape_strides,
472 iteration_shape_size);
// Return statement of an end-style accessor (signature not visible in this
// excerpt): same base as above, positioned at get_iteration_size(), i.e. one
// past the last iteration id.
480 return iterator(data + get_input_begin_offset(output_global_id),
481 get_iteration_size(), iteration_shape_strides,
482 axes_shape_strides, iteration_shape_size);
// Initializes the input descriptors from a data pointer and a shape.
// Several statements of the body are not visible in this excerpt.
497 void init_container(pointer __ptr,
const std::vector<size_type> &__shape)
// Special case: no data and no shape (the branch body is not visible).
500 if ((__ptr ==
nullptr) && __shape.empty()) {
504 if (__ptr !=
nullptr) {
512 if (!__shape.empty()) {
// Product of all extents; the assignment target is not visible in this
// excerpt (presumably input_size).
514 std::accumulate(__shape.begin(), __shape.end(), size_type(1),
515 std::multiplies<size_type>());
522 input_shape_size = __shape.size();
524 queue_ref, input_shape_size *
sizeof(size_type)));
// Copy the extents into the freshly obtained input_shape buffer.
525 std::copy(__shape.begin(), __shape.end(), input_shape);
527 input_shape_strides =
529 queue_ref, input_shape_size *
sizeof(size_type)));
531 input_shape_strides);
// iteration_size starts out equal to input_size.
533 iteration_size = input_size;
// Initializes the input descriptors from a data pointer, a shape and
// caller-provided strides. Several statements of the body are not visible
// in this excerpt.
536 void init_container(pointer __ptr,
537 const std::vector<size_type> &__shape,
538 const std::vector<size_type> &__strides)
// Special case: no data and no shape (the branch body is not visible).
541 if ((__ptr ==
nullptr) && __shape.empty()) {
545 if (__ptr !=
nullptr) {
553 if (!__shape.empty()) {
// Product of all extents; the assignment target is not visible in this
// excerpt (presumably input_size).
555 std::accumulate(__shape.begin(), __shape.end(), size_type(1),
556 std::multiplies<size_type>());
563 input_shape_size = __shape.size();
565 queue_ref, input_shape_size *
sizeof(size_type)));
566 std::copy(__shape.begin(), __shape.end(), input_shape);
568 input_shape_strides =
570 queue_ref, input_shape_size *
sizeof(size_type)));
// The strides are copied as provided rather than derived from the shape.
// NOTE(review): the visible size expression above uses input_shape_size
// (taken from __shape) while this copy runs over all of __strides; confirm
// callers always pass vectors of equal length.
571 std::copy(__strides.begin(), __strides.end(), input_shape_strides);
// iteration_size starts out equal to input_size.
573 iteration_size = input_size;
// Maps a linear output index to an element offset in the input. The offset
// is accumulated in input_global_id and returned. Some statements, including
// the first branch condition, are not visible in this excerpt.
577 size_type get_input_begin_offset(size_type output_global_id)
const
579 size_type input_global_id = 0;
581 assert(output_global_id < output_size);
// First branch (condition not visible; presumably the axes case): input
// dimensions that are not in `axes` contribute coordinate * input stride.
583 for (
size_t iit = 0, oit = 0;
584 iit < static_cast<size_t>(input_shape_size); ++iit)
586 if (std::find(axes.begin(), axes.end(), iit) == axes.end()) {
588 output_global_id, output_shape_strides,
589 output_shape_size, oit);
591 (output_xyz_id * input_shape_strides[iit]);
// Broadcast case: walk input and output dimensions from the last one
// backwards; output axes listed in broadcast_axes are skipped.
596 else if (broadcast_use) {
597 assert(output_global_id < output_size);
598 assert(input_shape_size <= output_shape_size);
600 for (
int irit = input_shape_size - 1, orit = output_shape_size - 1;
601 irit >= 0; --irit, --orit)
603 size_type *broadcast_axes_end =
604 broadcast_axes + broadcast_axes_size;
605 if (std::find(broadcast_axes, broadcast_axes_end, orit) ==
606 broadcast_axes_end) {
608 output_global_id, output_shape_strides,
609 output_shape_size, orit);
611 (output_xyz_id * input_shape_strides[irit]);
616 return input_global_id;
// Accessor for the iteration_size member.
620 size_type get_iteration_size()
const
622 return iteration_size;
// Releases axes_shape_strides through dpnp_memory_free_c and nulls the
// pointer. Other statements of the body are not visible in this excerpt.
625 void free_axes_memory()
628 dpnp_memory_free_c(queue_ref, axes_shape_strides);
629 axes_shape_strides =
nullptr;
// Resets the broadcast-axes count, releases the broadcast_axes buffer
// through dpnp_memory_free_c and nulls the pointer.
632 void free_broadcast_axes_memory()
634 broadcast_axes_size = size_type{};
635 dpnp_memory_free_c(queue_ref, broadcast_axes);
636 broadcast_axes =
nullptr;
// Resets the input size and rank, releases the input shape and stride
// buffers through dpnp_memory_free_c and nulls both pointers.
639 void free_input_memory()
641 input_size = size_type{};
642 input_shape_size = size_type{};
643 dpnp_memory_free_c(queue_ref, input_shape);
644 dpnp_memory_free_c(queue_ref, input_shape_strides);
645 input_shape =
nullptr;
646 input_shape_strides =
nullptr;
// Resets iteration_size and iteration_shape_size to size_type{}, releases
// iteration_shape_strides through dpnp_memory_free_c and nulls the pointer.
649 void free_iteration_memory()
651 iteration_size = size_type{};
652 iteration_shape_size = size_type{};
653 dpnp_memory_free_c(queue_ref, iteration_shape_strides);
654 iteration_shape_strides =
nullptr;
// Resets the output size and rank, releases the output shape and stride
// buffers through dpnp_memory_free_c and nulls both pointers.
657 void free_output_memory()
659 output_size = size_type{};
660 output_shape_size = size_type{};
661 dpnp_memory_free_c(queue_ref, output_shape);
662 dpnp_memory_free_c(queue_ref, output_shape_strides);
663 output_shape =
nullptr;
664 output_shape_strides =
nullptr;
// Cleanup sequence calling the individual free_* helpers. The enclosing
// signature and some of the calls are not visible in this excerpt.
670 free_broadcast_axes_memory();
672 free_iteration_memory();
673 free_output_memory();
// Container state, with in-class default initializers (several initializer
// lines are not visible in this excerpt):
//   queue_ref                     - handle passed to dpnp_memory_free_c
//   data                          - pointer the begin/end iterators are based on
//   input_* / output_*            - element count, shape buffer, rank and
//                                   stride buffer of the input and output
//   axes, axis_use                - selected axes and a flag (presumably set
//                                   when axes are in use)
//   broadcast_axes*, broadcast_use - broadcast axis list, its length and the
//                                   flag set by broadcast_to_shape
//   iteration_*, axes_shape_strides - per-iteration size, rank and stride
//                                   buffers handed to the iterator
676 DPCTLSyclQueueRef queue_ref =
nullptr;
678 pointer data =
nullptr;
679 size_type input_size = size_type{};
680 size_type *input_shape =
nullptr;
681 size_type input_shape_size = size_type{};
682 size_type *input_shape_strides =
685 std::vector<size_type> axes;
686 bool axis_use =
false;
688 size_type *broadcast_axes =
nullptr;
689 size_type broadcast_axes_size =
691 bool broadcast_use =
false;
693 size_type output_size =
695 size_type *output_shape =
nullptr;
696 size_type output_shape_size = size_type{};
697 size_type *output_shape_strides =
700 size_type iteration_size =
702 size_type iteration_shape_size = size_type{};
703 size_type *iteration_shape_strides =
nullptr;
704 size_type *axes_shape_strides =
nullptr;