52    using value_type = _Tp;
 
   53    using difference_type = std::ptrdiff_t;
 
   54    using iterator_category = std::random_access_iterator_tag;
 
   55    using pointer = value_type *;
 
   56    using reference = value_type &;
 
   57    using size_type = shape_elem_type;
 
   61                      const size_type *__shape_stride = 
nullptr,
 
   62                      const size_type *__axes_stride = 
nullptr,
 
   63                      size_type __shape_size = 0)
 
   64        : base(__base_ptr), iter_id(__id), iteration_shape_size(__shape_size),
 
   65          iteration_shape_strides(__shape_stride),
 
   66          axes_shape_strides(__axes_stride)
 
   72    inline reference operator*()
 const 
   77    inline pointer operator->()
 const 
  101        assert(base == __rhs.base); 
 
  103        return (iter_id == __rhs.iter_id);
 
  108        return !(*
this == __rhs);
 
  113        return iter_id < __rhs.iter_id;
 
  119    inline reference operator[](size_type __n)
 const 
  126        difference_type diff =
 
  127            difference_type(iter_id) - difference_type(__rhs.iter_id);
 
  136        const std::vector<size_type> it_strides(__it.iteration_shape_strides,
 
  137                                                __it.iteration_shape_strides +
 
  138                                                    __it.iteration_shape_size);
 
  139        const std::vector<size_type> it_axes_strides(
 
  140            __it.axes_shape_strides,
 
  141            __it.axes_shape_strides + __it.iteration_shape_size);
 
  143        __out << 
"DPNP_USM_iterator(base=" << __it.base;
 
  144        __out << 
", iter_id=" << __it.iter_id;
 
  145        __out << 
", iteration_shape_size=" << __it.iteration_shape_size;
 
  146        __out << 
", iteration_shape_strides=" << it_strides;
 
  147        __out << 
", axes_shape_strides=" << it_axes_strides;
 
 
  154    inline pointer ptr()
 const 
  159    inline pointer ptr(size_type iteration_id)
 const 
  161        size_type offset = 0;
 
  163        if (iteration_shape_size > 0) {
 
  164            long reminder = iteration_id;
 
  165            for (
size_t it = 0; it < static_cast<size_t>(iteration_shape_size);
 
  167                const size_type axis_val = iteration_shape_strides[it];
 
  168                size_type xyz_id = reminder / axis_val;
 
  169                offset += (xyz_id * axes_shape_strides[it]);
 
  171                reminder = reminder % axis_val;
 
  175            offset = iteration_id;
 
  178        return base + offset;
 
  181    const pointer base = 
nullptr;
 
  184    const size_type iteration_shape_size =
 
  187    const size_type *iteration_shape_strides = 
nullptr;
 
  188    const size_type *axes_shape_strides = 
nullptr;
 
 
  203    using value_type = _Tp;
 
  205    using pointer = value_type *;
 
  206    using reference = value_type &;
 
  207    using size_type = shape_elem_type;
 
  211             const size_type *__shape,
 
  212             const size_type __shape_size)
 
  215        std::vector<size_type> shape(__shape, __shape + __shape_size);
 
  216        init_container(__ptr, shape);
 
  221             const size_type *__shape,
 
  222             const size_type *__strides,
 
  223             const size_type __ndim)
 
  226        std::vector<size_type> shape(__shape, __shape + __ndim);
 
  227        std::vector<size_type> strides(__strides, __strides + __ndim);
 
  228        init_container(__ptr, shape, strides);
 
  248             const std::vector<size_type> &__shape)
 
  251        init_container(__ptr, __shape);
 
 
  270             const std::vector<size_type> &__shape,
 
  271             const std::vector<size_type> &__strides)
 
  273        init_container(__ptr, __shape, __strides);
 
 
  289    inline void broadcast_to_shape(
const size_type *__shape,
 
  290                                   const size_type __shape_size)
 
  292        std::vector<size_type> shape(__shape, __shape + __shape_size);
 
  293        broadcast_to_shape(shape);
 
  312        if (broadcastable(input_shape, input_shape_size, __shape)) {
 
  313            free_broadcast_axes_memory();
 
  314            free_output_memory();
 
  316            std::vector<size_type> valid_axes;
 
  317            broadcast_use = 
true;
 
  319            output_shape_size = __shape.size();
 
  320            const size_type output_shape_size_in_bytes =
 
  321                output_shape_size * 
sizeof(size_type);
 
  322            output_shape = 
reinterpret_cast<size_type *
>(
 
  325            for (
int irit = input_shape_size - 1, orit = output_shape_size - 1;
 
  326                 orit >= 0; --irit, --orit)
 
  328                output_shape[orit] = __shape[orit];
 
  332                if (irit < 0 || input_shape[irit] != output_shape[orit]) {
 
  333                    valid_axes.insert(valid_axes.begin(), orit);
 
  337            broadcast_axes_size = valid_axes.size();
 
  338            const size_type broadcast_axes_size_in_bytes =
 
  339                broadcast_axes_size * 
sizeof(size_type);
 
  340            broadcast_axes = 
reinterpret_cast<size_type *
>(
 
  342            std::copy(valid_axes.begin(), valid_axes.end(), broadcast_axes);
 
  345                std::accumulate(output_shape, output_shape + output_shape_size,
 
  346                                size_type(1), std::multiplies<size_type>());
 
  348            output_shape_strides = 
reinterpret_cast<size_type *
>(
 
  351                output_shape, output_shape_size, output_shape_strides);
 
 
  379    inline void set_axes(
const shape_elem_type *__axes, 
const size_t axes_ndim)
 
  381        const std::vector<shape_elem_type> axes_vec(__axes, __axes + axes_ndim);
 
  402    inline void set_axes(
const std::vector<shape_elem_type> &__axes)
 
  408        if (!__axes.empty() && input_shape_size) {
 
  410            free_iteration_memory();
 
  411            free_output_memory();
 
  413            axes = get_validated_axes(__axes, input_shape_size);
 
  416            output_shape_size = input_shape_size - axes.size();
 
  417            const size_type output_shape_size_in_bytes =
 
  418                output_shape_size * 
sizeof(size_type);
 
  420            iteration_shape_size = axes.size();
 
  421            const size_type iteration_shape_size_in_bytes =
 
  422                iteration_shape_size * 
sizeof(size_type);
 
  423            std::vector<size_type> iteration_shape;
 
  425            output_shape = 
reinterpret_cast<size_type *
>(
 
  427            size_type *output_shape_it = output_shape;
 
  428            for (size_type i = 0; i < input_shape_size; ++i) {
 
  429                if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
 
  430                    *output_shape_it = input_shape[i];
 
  436                std::accumulate(output_shape, output_shape + output_shape_size,
 
  437                                size_type(1), std::multiplies<size_type>());
 
  439            output_shape_strides = 
reinterpret_cast<size_type *
>(
 
  442                output_shape, output_shape_size, output_shape_strides);
 
  445            iteration_shape.reserve(iteration_shape_size);
 
  446            for (
const auto &axis : axes) {
 
  447                const size_type axis_dim = input_shape[axis];
 
  448                iteration_shape.push_back(axis_dim);
 
  449                iteration_size *= axis_dim;
 
  452            iteration_shape_strides = 
reinterpret_cast<size_type *
>(
 
  455                                                  iteration_shape.size(),
 
  456                                                  iteration_shape_strides);
 
  458            axes_shape_strides = 
reinterpret_cast<size_type *
>(
 
  460            for (
size_t i = 0; i < static_cast<size_t>(iteration_shape_size);
 
  462                axes_shape_strides[i] = input_shape_strides[axes[i]];
 
 
  470        return iterator(data + get_input_begin_offset(output_global_id), 0,
 
  471                        iteration_shape_strides, axes_shape_strides,
 
  472                        iteration_shape_size);
 
 
  480        return iterator(data + get_input_begin_offset(output_global_id),
 
  481                        get_iteration_size(), iteration_shape_strides,
 
  482                        axes_shape_strides, iteration_shape_size);
 
 
  497    void init_container(pointer __ptr, 
const std::vector<size_type> &__shape)
 
  500        if ((__ptr == 
nullptr) && __shape.empty()) {
 
  504        if (__ptr != 
nullptr) {
 
  512        if (!__shape.empty()) {
 
  514                std::accumulate(__shape.begin(), __shape.end(), size_type(1),
 
  515                                std::multiplies<size_type>());
 
  522            input_shape_size = __shape.size();
 
  524                queue_ref, input_shape_size * 
sizeof(size_type)));
 
  525            std::copy(__shape.begin(), __shape.end(), input_shape);
 
  527            input_shape_strides =
 
  529                    queue_ref, input_shape_size * 
sizeof(size_type)));
 
  531                                                  input_shape_strides);
 
  533        iteration_size = input_size;
 
  536    void init_container(pointer __ptr,
 
  537                        const std::vector<size_type> &__shape,
 
  538                        const std::vector<size_type> &__strides)
 
  541        if ((__ptr == 
nullptr) && __shape.empty()) {
 
  545        if (__ptr != 
nullptr) {
 
  553        if (!__shape.empty()) {
 
  555                std::accumulate(__shape.begin(), __shape.end(), size_type(1),
 
  556                                std::multiplies<size_type>());
 
  563            input_shape_size = __shape.size();
 
  565                queue_ref, input_shape_size * 
sizeof(size_type)));
 
  566            std::copy(__shape.begin(), __shape.end(), input_shape);
 
  568            input_shape_strides =
 
  570                    queue_ref, input_shape_size * 
sizeof(size_type)));
 
  571            std::copy(__strides.begin(), __strides.end(), input_shape_strides);
 
  573        iteration_size = input_size;
 
  577    size_type get_input_begin_offset(size_type output_global_id)
 const 
  579        size_type input_global_id = 0;
 
  581            assert(output_global_id < output_size);
 
  583            for (
size_t iit = 0, oit = 0;
 
  584                 iit < static_cast<size_t>(input_shape_size); ++iit)
 
  586                if (std::find(axes.begin(), axes.end(), iit) == axes.end()) {
 
  588                        output_global_id, output_shape_strides,
 
  589                        output_shape_size, oit);
 
  591                        (output_xyz_id * input_shape_strides[iit]);
 
  596        else if (broadcast_use) {
 
  597            assert(output_global_id < output_size);
 
  598            assert(input_shape_size <= output_shape_size);
 
  600            for (
int irit = input_shape_size - 1, orit = output_shape_size - 1;
 
  601                 irit >= 0; --irit, --orit)
 
  603                size_type *broadcast_axes_end =
 
  604                    broadcast_axes + broadcast_axes_size;
 
  605                if (std::find(broadcast_axes, broadcast_axes_end, orit) ==
 
  606                    broadcast_axes_end) {
 
  608                        output_global_id, output_shape_strides,
 
  609                        output_shape_size, orit);
 
  611                        (output_xyz_id * input_shape_strides[irit]);
 
  616        return input_global_id;
 
  620    size_type get_iteration_size()
 const 
  622        return iteration_size;
 
  625    void free_axes_memory()
 
  628        dpnp_memory_free_c(queue_ref, axes_shape_strides);
 
  629        axes_shape_strides = 
nullptr;
 
  632    void free_broadcast_axes_memory()
 
  634        broadcast_axes_size = size_type{};
 
  635        dpnp_memory_free_c(queue_ref, broadcast_axes);
 
  636        broadcast_axes = 
nullptr;
 
  639    void free_input_memory()
 
  641        input_size = size_type{};
 
  642        input_shape_size = size_type{};
 
  643        dpnp_memory_free_c(queue_ref, input_shape);
 
  644        dpnp_memory_free_c(queue_ref, input_shape_strides);
 
  645        input_shape = 
nullptr;
 
  646        input_shape_strides = 
nullptr;
 
  649    void free_iteration_memory()
 
  651        iteration_size = size_type{};
 
  652        iteration_shape_size = size_type{};
 
  653        dpnp_memory_free_c(queue_ref, iteration_shape_strides);
 
  654        iteration_shape_strides = 
nullptr;
 
  657    void free_output_memory()
 
  659        output_size = size_type{};
 
  660        output_shape_size = size_type{};
 
  661        dpnp_memory_free_c(queue_ref, output_shape);
 
  662        dpnp_memory_free_c(queue_ref, output_shape_strides);
 
  663        output_shape = 
nullptr;
 
  664        output_shape_strides = 
nullptr;
 
  670        free_broadcast_axes_memory();
 
  672        free_iteration_memory();
 
  673        free_output_memory();
 
  676    DPCTLSyclQueueRef queue_ref = 
nullptr; 
 
  678    pointer data = 
nullptr;                   
 
  679    size_type input_size = size_type{};       
 
  680    size_type *input_shape = 
nullptr;         
 
  681    size_type input_shape_size = size_type{}; 
 
  682    size_type *input_shape_strides =
 
  685    std::vector<size_type> axes; 
 
  686    bool axis_use = 
false;
 
  688    size_type *broadcast_axes = 
nullptr; 
 
  689    size_type broadcast_axes_size =
 
  691    bool broadcast_use = 
false;
 
  693    size_type output_size =
 
  695    size_type *output_shape = 
nullptr;         
 
  696    size_type output_shape_size = size_type{}; 
 
  697    size_type *output_shape_strides =
 
  700    size_type iteration_size =
 
  702    size_type iteration_shape_size = size_type{};
 
  703    size_type *iteration_shape_strides = 
nullptr;
 
  704    size_type *axes_shape_strides = 
nullptr;