Source code for dpctl.tensor._manipulation_functions

#                       Data Parallel Control (dpctl)
#
#  Copyright 2020-2024 Intel Corporation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


import itertools
import operator

import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dputils

from ._copy_utils import _broadcast_strides
from ._type_utils import _supported_dtype, _to_device_supported_dtype

__doc__ = (
    "Implementation module for array manipulation "
    "functions in :module:`dpctl.tensor`"
)


def _broadcast_shape_impl(shapes):
    if len(set(shapes)) == 1:
        return shapes[0]
    mutable_shapes = False
    nds = [len(s) for s in shapes]
    biggest = max(nds)
    sh_len = len(shapes)
    for i in range(sh_len):
        diff = biggest - nds[i]
        if diff > 0:
            ty = type(shapes[i])
            shapes[i] = ty(
                itertools.chain(itertools.repeat(1, diff), shapes[i])
            )
    common_shape = []
    for axis in range(biggest):
        lengths = [s[axis] for s in shapes]
        unique = set(lengths + [1])
        if len(unique) > 2:
            raise ValueError(
                "Shape mismatch: two or more arrays have "
                f"incompatible dimensions on axis ({axis},)"
            )
        elif len(unique) == 2:
            unique.remove(1)
            new_length = unique.pop()
            common_shape.append(new_length)
            for i in range(sh_len):
                if shapes[i][axis] == 1:
                    if not mutable_shapes:
                        shapes = [list(s) for s in shapes]
                        mutable_shapes = True
                    shapes[i][axis] = new_length
        else:
            common_shape.append(1)

    return tuple(common_shape)


def _broadcast_shapes(*args):
    """
    Broadcast the input shapes into a single shape;
    returns tuple broadcasted shape.
    """
    array_shapes = [array.shape for array in args]
    return _broadcast_shape_impl(array_shapes)


[docs]def permute_dims(X, /, axes): """permute_dims(x, axes) Permute the axes (dimensions) of an array; returns the permuted array as a view. Args: x (usm_ndarray): input array. axes (Tuple[int, ...]): tuple containing permutation of `(0,1,...,N-1)` where `N` is the number of axes (dimensions) of `x`. Returns: usm_ndarray: An array with permuted axes. The returned array must has the same data type as `x`, is created on the same device as `x` and has the same USM allocation type as `x`. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") axes = normalize_axis_tuple(axes, X.ndim, "axes") if not X.ndim == len(axes): raise ValueError( "The length of the passed axes does not match " "to the number of usm_ndarray dimensions." ) newstrides = tuple(X.strides[i] for i in axes) newshape = tuple(X.shape[i] for i in axes) return dpt.usm_ndarray( shape=newshape, dtype=X.dtype, buffer=X, strides=newstrides, offset=X.__sycl_usm_array_interface__.get("offset", 0), )
[docs]def expand_dims(X, /, *, axis=0): """expand_dims(x, axis) Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis. Args: x (usm_ndarray): input array axis (Union[int, Tuple[int]]): axis position in the expanded axes (zero-based). If `x` has rank (i.e, number of dimensions) `N`, a valid `axis` must reside in the closed-interval `[-N-1, N]`. If provided a negative `axis`, the `axis` position at which to insert a singleton dimension is computed as `N + axis + 1`. Hence, if provided `-1`, the resolved axis position is `N` (i.e., a singleton dimension must be appended to the input array `x`). If provided `-N-1`, the resolved axis position is `0` (i.e., a singleton dimension is prepended to the input array `x`). Returns: usm_ndarray: Returns a view, if possible, and a copy otherwise with the number of dimensions increased. The expanded array has the same data type as the input array `x`. The expanded array is located on the same device as the input array, and has the same USM allocation type. Raises: IndexError: if `axis` value is invalid. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") if type(axis) not in (tuple, list): axis = (axis,) out_ndim = len(axis) + X.ndim axis = normalize_axis_tuple(axis, out_ndim) shape_it = iter(X.shape) shape = tuple(1 if ax in axis else next(shape_it) for ax in range(out_ndim)) return dpt.reshape(X, shape)
[docs]def squeeze(X, /, axis=None): """squeeze(x, axis) Removes singleton dimensions (axes) from array `x`. Args: x (usm_ndarray): input array axis (Union[int, Tuple[int,...]]): axis (or axes) to squeeze. Returns: usm_ndarray: Output array is a view, if possible, and a copy otherwise, but with all or a subset of the dimensions of length 1 removed. Output has the same data type as the input, is allocated on the same device as the input and has the same USM allocation type as the input array `x`. Raises: ValueError: if the specified axis has a size greater than one. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") X_shape = X.shape if axis is not None: axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1) new_shape = [] for i, x in enumerate(X_shape): if i not in axis: new_shape.append(x) else: if x != 1: raise ValueError( "Cannot select an axis to squeeze out " "which has size not equal to one." ) new_shape = tuple(new_shape) else: new_shape = tuple(axis for axis in X_shape if axis != 1) if new_shape == X.shape: return X else: return dpt.reshape(X, new_shape)
[docs]def broadcast_to(X, /, shape): """broadcast_to(x, shape) Broadcast an array to a new `shape`; returns the broadcasted :class:`dpctl.tensor.usm_ndarray` as a view. Args: x (usm_ndarray): input array shape (Tuple[int,...]): array shape. The `shape` must be compatible with `x` according to broadcasting rules. Returns: usm_ndarray: An array with the specified `shape`. The output array is a view of the input array, and hence has the same data type, USM allocation type and device attributes. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") # Use numpy.broadcast_to to check the validity of the input # parameter 'shape'. Raise ValueError if 'X' is not compatible # with 'shape' according to NumPy's broadcasting rules. new_array = np.broadcast_to( np.broadcast_to(np.empty(tuple(), dtype="u1"), X.shape), shape ) new_sts = _broadcast_strides(X.shape, X.strides, new_array.ndim) return dpt.usm_ndarray( shape=new_array.shape, dtype=X.dtype, buffer=X, strides=new_sts, offset=X.__sycl_usm_array_interface__.get("offset", 0), )
[docs]def broadcast_arrays(*args): """broadcast_arrays(*arrays) Broadcasts one or more :class:`dpctl.tensor.usm_ndarrays` against one another. Args: arrays (usm_ndarray): an arbitrary number of arrays to be broadcasted. Returns: List[usm_ndarray]: A list of broadcasted arrays. Each array must have the same shape. Each array must have the same `dtype`, `device` and `usm_type` attributes as its corresponding input array. """ for X in args: if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") shape = _broadcast_shapes(*args) if all(X.shape == shape for X in args): return args return [broadcast_to(X, shape) for X in args]
[docs]def flip(X, /, *, axis=None): """flip(x, axis) Reverses the order of elements in an array `x` along the given `axis`. The shape of the array is preserved, but the elements are reordered. Args: x (usm_ndarray): input array. axis (Optional[Union[int, Tuple[int,...]]]): axis (or axes) along which to flip. If `axis` is `None`, all input array axes are flipped. If `axis` is negative, the flipped axis is counted from the last dimension. If provided more than one axis, only the specified axes are flipped. Default: `None`. Returns: usm_ndarray: A view of `x` with the entries of `axis` reversed. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") X_ndim = X.ndim if axis is None: indexer = (np.s_[::-1],) * X_ndim else: axis = normalize_axis_tuple(axis, X_ndim) indexer = tuple( np.s_[::-1] if i in axis else np.s_[:] for i in range(X.ndim) ) return X[indexer]
[docs]def roll(X, /, shift, *, axis=None): """ roll(x, shift, axis) Rolls array elements along a specified axis. Array elements that roll beyond the last position are re-introduced at the first position. Array elements that roll beyond the first position are re-introduced at the last position. Args: x (usm_ndarray): input array shift (Union[int, Tuple[int,...]]): number of places by which the elements are shifted. If `shift` is a tuple, then `axis` must be a tuple of the same size, and each of the given axes must be shifted by the corresponding element in `shift`. If `shift` is an `int` and `axis` a tuple, then the same `shift` must be used for all specified axes. If a `shift` is positive, then array elements is shifted positively (toward larger indices) along the dimension of `axis`. If a `shift` is negative, then array elements must be shifted negatively (toward smaller indices) along the dimension of `axis`. axis (Optional[Union[int, Tuple[int,...]]]): axis (or axes) along which elements to shift. If `axis` is `None`, the array is flattened, shifted, and then restored to its original shape. Default: `None`. Returns: usm_ndarray: An array having the same `dtype`, `usm_type` and `device` attributes as `x` and whose elements are shifted relative to `x`. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") if axis is None: shift = operator.index(shift) res = dpt.empty( X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue ) hev, _ = ti._copy_usm_ndarray_for_roll_1d( src=X, dst=res, shift=shift, sycl_queue=X.sycl_queue ) hev.wait() return res axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True) broadcasted = np.broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError("'shift' and 'axis' should be scalars or 1D sequences") shifts = [ 0, ] * X.ndim for sh, ax in broadcasted: shifts[ax] += sh exec_q = X.sycl_queue res = dpt.empty( X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q ) ht_e, _ = ti._copy_usm_ndarray_for_roll_nd( src=X, dst=res, shifts=shifts, sycl_queue=exec_q ) ht_e.wait() return res
def _arrays_validation(arrays, check_ndim=True): n = len(arrays) if n == 0: raise TypeError("Missing 1 required positional argument: 'arrays'.") if not isinstance(arrays, (list, tuple)): raise TypeError(f"Expected tuple or list type, got {type(arrays)}.") for X in arrays: if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays]) if exec_q is None: raise ValueError("All the input arrays must have same sycl queue.") res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays]) if res_usm_type is None: raise ValueError("All the input arrays must have usm_type.") X0 = arrays[0] _supported_dtype(Xi.dtype for Xi in arrays) res_dtype = X0.dtype dev = exec_q.sycl_device for i in range(1, n): res_dtype = np.promote_types(res_dtype, arrays[i]) res_dtype = _to_device_supported_dtype(res_dtype, dev) if check_ndim: for i in range(1, n): if X0.ndim != arrays[i].ndim: raise ValueError( "All the input arrays must have same number of dimensions, " f"but the array at index 0 has {X0.ndim} dimension(s) and " f"the array at index {i} has {arrays[i].ndim} dimension(s)." ) return res_dtype, res_usm_type, exec_q def _check_same_shapes(X0_shape, axis, n, arrays): for i in range(1, n): Xi_shape = arrays[i].shape for j, X0j in enumerate(X0_shape): if X0j != Xi_shape[j] and j != axis: raise ValueError( "All the input array dimensions for the concatenation " f"axis must match exactly, but along dimension {j}, the " f"array at index 0 has size {X0j} and the array " f"at index {i} has size {Xi_shape[j]}." ) def _concat_axis_None(arrays): "Implementation of concat(arrays, axis=None)." res_dtype, res_usm_type, exec_q = _arrays_validation( arrays, check_ndim=False ) res_shape = 0 for array in arrays: res_shape += array.size res = dpt.empty( res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q ) hev_list = [] fill_start = 0 for array in arrays: fill_end = fill_start + array.size if array.flags.c_contiguous: hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( src=dpt.reshape(array, -1), dst=res[fill_start:fill_end], sycl_queue=exec_q, ) else: src_ = array # _copy_usm_ndarray_for_reshape requires src and dst to have # the same data type if not array.dtype == res_dtype: src_ = dpt.astype(src_, res_dtype) hev, _ = ti._copy_usm_ndarray_for_reshape( src=src_, dst=res[fill_start:fill_end], sycl_queue=exec_q, ) fill_start = fill_end hev_list.append(hev) dpctl.SyclEvent.wait_for(hev_list) return res
[docs]def concat(arrays, /, *, axis=0): """concat(arrays, axis) Joins a sequence of arrays along an existing axis. Args: arrays (Union[List[usm_ndarray, Tuple[usm_ndarray,...]]]): input arrays to join. The arrays must have the same shape, except in the dimension specified by `axis`. axis (Optional[int]): axis along which the arrays will be joined. If `axis` is `None`, arrays must be flattened before concatenation. If `axis` is negative, it is understood as being counted from the last dimension. Default: `0`. Returns: usm_ndarray: An output array containing the concatenated values. The output array data type is determined by Type Promotion Rules of array API. All input arrays must have the same device attribute. The output array is allocated on that same device, and data movement operations are scheduled on a queue underlying the device. The USM allocation type of the output array is determined by USM allocation type promotion rules. """ if axis is None: return _concat_axis_None(arrays) res_dtype, res_usm_type, exec_q = _arrays_validation(arrays) n = len(arrays) X0 = arrays[0] axis = normalize_axis_index(axis, X0.ndim) X0_shape = X0.shape _check_same_shapes(X0_shape, axis, n, arrays) res_shape_axis = 0 for X in arrays: res_shape_axis = res_shape_axis + X.shape[axis] res_shape = tuple( X0_shape[i] if i != axis else res_shape_axis for i in range(X0.ndim) ) res = dpt.empty( res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q ) hev_list = [] fill_start = 0 for i in range(n): fill_end = fill_start + arrays[i].shape[axis] c_shapes_copy = tuple( np.s_[fill_start:fill_end] if j == axis else np.s_[:] for j in range(X0.ndim) ) hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q ) fill_start = fill_end hev_list.append(hev) dpctl.SyclEvent.wait_for(hev_list) return res
[docs]def stack(arrays, /, *, axis=0): """ stack(arrays, axis) Joins a sequence of arrays along a new axis. Args: arrays (Union[List[usm_ndarray], Tuple[usm_ndarray,...]]): input arrays to join. Each array must have the same shape. axis (int): axis along which the arrays will be joined. Providing an `axis` specified the index of the new axis in the dimensions of the output array. A valid axis must be on the interval `[-N, N)`, where `N` is the rank (number of dimensions) of `x`. Default: `0`. Returns: usm_ndarray: An output array having rank `N+1`, where `N` is the rank (number of dimensions) of `x`. If the input arrays have different data types, array API Type Promotion Rules apply. Raises: ValueError: if not all input arrays have the same shape IndexError: if provided an `axis` outside of the required interval. """ res_dtype, res_usm_type, exec_q = _arrays_validation(arrays) n = len(arrays) X0 = arrays[0] res_ndim = X0.ndim + 1 axis = normalize_axis_index(axis, res_ndim) X0_shape = X0.shape for i in range(1, n): if X0_shape != arrays[i].shape: raise ValueError("All input arrays must have the same shape") res_shape = tuple( X0_shape[i - 1 * (i >= axis)] if i != axis else n for i in range(res_ndim) ) res = dpt.empty( res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q ) hev_list = [] for i in range(n): c_shapes_copy = tuple( i if j == axis else np.s_[:] for j in range(res_ndim) ) hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q ) hev_list.append(hev) dpctl.SyclEvent.wait_for(hev_list) return res
[docs]def unstack(X, /, *, axis=0): """unstack(x, axis=0) Splits an array in a sequence of arrays along the given axis. Args: x (usm_ndarray): input array axis (int, optional): axis along which `x` is unstacked. If `x` has rank (i.e, number of dimensions) `N`, a valid `axis` must reside in the half-open interval `[-N, N)`. Default: `0`. Returns: Tuple[usm_ndarray,...]: Output sequence of arrays which are views into the input array. Raises: AxisError: if the `axis` value is invalid. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") axis = normalize_axis_index(axis, X.ndim) Y = dpt.moveaxis(X, axis, 0) return tuple(Y[i] for i in range(Y.shape[0]))
[docs]def moveaxis(X, source, destination, /): """moveaxis(x, source, destination) Moves axes of an array to new positions. Args: x (usm_ndarray): input array source (int or a sequence of int): Original positions of the axes to move. These must be unique. If `x` has rank (i.e., number of dimensions) `N`, a valid `axis` must be in the half-open interval `[-N, N)`. destination (int or a sequence of int): Destination positions for each of the original axes. These must also be unique. If `x` has rank (i.e., number of dimensions) `N`, a valid `axis` must be in the half-open interval `[-N, N)`. Returns: usm_ndarray: Array with moved axes. The returned array must has the same data type as `x`, is created on the same device as `x` and has the same USM allocation type as `x`. Raises: AxisError: if `axis` value is invalid. ValueError: if `src` and `dst` have not equal number of elements. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") source = normalize_axis_tuple(source, X.ndim, "source") destination = normalize_axis_tuple(destination, X.ndim, "destination") if len(source) != len(destination): raise ValueError( "`source` and `destination` arguments must have " "the same number of elements" ) ind = [n for n in range(X.ndim) if n not in source] for src, dst in sorted(zip(destination, source)): ind.insert(src, dst) return dpt.permute_dims(X, tuple(ind))
def swapaxes(X, axis1, axis2): """swapaxes(x, axis1, axis2) Interchanges two axes of an array. Args: x (usm_ndarray): input array axis1 (int): First axis. If `x` has rank (i.e., number of dimensions) `N`, a valid `axis` must be in the half-open interval `[-N, N)`. axis2 (int): Second axis. If `x` has rank (i.e., number of dimensions) `N`, a valid `axis` must be in the half-open interval `[-N, N)`. Returns: usm_ndarray: Array with swapped axes. The returned array must has the same data type as `x`, is created on the same device as `x` and has the same USM allocation type as `x`. Raises: AxisError: if `axis` value is invalid. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") axis1 = normalize_axis_index(axis1, X.ndim, "axis1") axis2 = normalize_axis_index(axis2, X.ndim, "axis2") ind = list(range(0, X.ndim)) ind[axis1] = axis2 ind[axis2] = axis1 return dpt.permute_dims(X, tuple(ind))
[docs]def repeat(x, repeats, /, *, axis=None): """repeat(x, repeats, axis=None) Repeat elements of an array on a per-element basis. Args: x (usm_ndarray): input array repeats (Union[int, Sequence[int, ...], usm_ndarray]): The number of repetitions for each element. `repeats` must be broadcast-compatible with `N` where `N` is `prod(x.shape)` if `axis` is `None` and `x.shape[axis]` otherwise. If `repeats` is an array, it must have an integer data type. Otherwise, `repeats` must be a Python integer or sequence of Python integers (i.e., a tuple, list, or range). axis (Optional[int]): The axis along which to repeat values. If `axis` is `None`, the function repeats elements of the flattened array. Default: `None`. Returns: usm_ndarray: output array with repeated elements. If `axis` is `None`, the returned array is one-dimensional, otherwise, it has the same shape as `x`, except for the axis along which elements were repeated. The returned array will have the same data type as `x`. The returned array will be located on the same device as `x` and have the same USM allocation type as `x`. Raises: AxisError: if `axis` value is invalid. """ if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(x)}.") x_ndim = x.ndim x_shape = x.shape if axis is not None: axis = normalize_axis_index(operator.index(axis), x_ndim) axis_size = x_shape[axis] else: axis_size = x.size scalar = False if isinstance(repeats, int): if repeats < 0: raise ValueError("`repeats` must be a positive integer") usm_type = x.usm_type exec_q = x.sycl_queue scalar = True elif isinstance(repeats, dpt.usm_ndarray): if repeats.ndim > 1: raise ValueError( "`repeats` array must be 0- or 1-dimensional, got " f"{repeats.ndim}" ) exec_q = dpctl.utils.get_execution_queue( (x.sycl_queue, repeats.sycl_queue) ) if exec_q is None: raise dputils.ExecutionPlacementError( "Execution placement can not be unambiguously inferred " "from input arguments." ) usm_type = dpctl.utils.get_coerced_usm_type( ( x.usm_type, repeats.usm_type, ) ) dpctl.utils.validate_usm_type(usm_type, allow_none=False) if not dpt.can_cast(repeats.dtype, dpt.int64, casting="same_kind"): raise TypeError( f"`repeats` data type `{repeats.dtype}` cannot be cast to " "`int64` according to the casting rule ''safe.''" ) if repeats.size == 1: scalar = True # bring the single element to the host repeats = int(repeats) if repeats < 0: raise ValueError("`repeats` elements must be positive") else: if repeats.size != axis_size: raise ValueError( "`repeats` array must be broadcastable to the size of " "the repeated axis" ) if not dpt.all(repeats >= 0): raise ValueError("`repeats` elements must be positive") elif isinstance(repeats, (tuple, list, range)): usm_type = x.usm_type exec_q = x.sycl_queue len_reps = len(repeats) if len_reps == 1: repeats = repeats[0] if repeats < 0: raise ValueError("`repeats` elements must be positive") scalar = True else: if len_reps != axis_size: raise ValueError( "`repeats` sequence must have the same length as the " "repeated axis" ) repeats = dpt.asarray( repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q ) if not dpt.all(repeats >= 0): raise ValueError("`repeats` elements must be positive") else: raise TypeError( "Expected int, sequence, or `usm_ndarray` for second argument," f"got {type(repeats)}" ) if scalar: res_axis_size = repeats * axis_size if axis is not None: res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] else: res_shape = (res_axis_size,) res = dpt.empty( res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q ) if res_axis_size > 0: ht_rep_ev, _ = ti._repeat_by_scalar( src=x, dst=res, reps=repeats, axis=axis, sycl_queue=exec_q, ) ht_rep_ev.wait() else: if repeats.dtype != dpt.int64: rep_buf = dpt.empty( repeats.shape, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q, ) ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=repeats, dst=rep_buf, sycl_queue=exec_q ) cumsum = dpt.empty( (axis_size,), dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q, ) # _cumsum_1d synchronizes so `depends` ends here safely res_axis_size = ti._cumsum_1d( rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev] ) if axis is not None: res_shape = ( x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] ) else: res_shape = (res_axis_size,) res = dpt.empty( res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q, ) if res_axis_size > 0: ht_rep_ev, _ = ti._repeat_by_sequence( src=x, dst=res, reps=rep_buf, cumsum=cumsum, axis=axis, sycl_queue=exec_q, ) ht_rep_ev.wait() ht_copy_ev.wait() else: cumsum = dpt.empty( (axis_size,), dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q, ) res_axis_size = ti._cumsum_1d(repeats, cumsum, sycl_queue=exec_q) if axis is not None: res_shape = ( x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] ) else: res_shape = (res_axis_size,) res = dpt.empty( res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q, ) if res_axis_size > 0: ht_rep_ev, _ = ti._repeat_by_sequence( src=x, dst=res, reps=repeats, cumsum=cumsum, axis=axis, sycl_queue=exec_q, ) ht_rep_ev.wait() return res
[docs]def tile(x, repetitions, /): """tile(x, repetitions) Repeat an input array `x` along each axis a number of times given by `repetitions`. For `N` = len(`repetitions`) and `M` = len(`x.shape`): * If `M < N`, `x` will have `N - M` new axes prepended to its shape * If `M > N`, `repetitions` will have `M - N` ones prepended to it Args: x (usm_ndarray): input array repetitions (Union[int, Tuple[int, ...]]): The number of repetitions along each dimension of `x`. Returns: usm_ndarray: tiled output array. The returned array will have rank `max(M, N)`. If `S` is the shape of `x` after prepending dimensions and `R` is `repetitions` after prepending ones, then the shape of the result will be `S[i] * R[i]` for each dimension `i`. The returned array will have the same data type as `x`. The returned array will be located on the same device as `x` and have the same USM allocation type as `x`. """ if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(x)}.") if not isinstance(repetitions, tuple): if isinstance(repetitions, int): repetitions = (repetitions,) else: raise TypeError( f"Expected tuple or integer type, got {type(repetitions)}." ) rep_dims = len(repetitions) x_dims = x.ndim if rep_dims < x_dims: repetitions = (x_dims - rep_dims) * (1,) + repetitions elif x_dims < rep_dims: x = dpt.reshape(x, (rep_dims - x_dims) * (1,) + x.shape) res_shape = tuple(map(lambda sh, rep: sh * rep, x.shape, repetitions)) # case of empty input if x.size == 0: return dpt.empty( res_shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=x.sycl_queue, ) in_sh = x.shape if res_shape == in_sh: return dpt.copy(x) expanded_sh = [] broadcast_sh = [] out_sz = 1 for i in range(len(res_shape)): out_sz *= res_shape[i] reps, sh = repetitions[i], in_sh[i] if reps == 1: # dimension will be unchanged broadcast_sh.append(sh) expanded_sh.append(sh) elif sh == 1: # dimension will be broadcast broadcast_sh.append(reps) expanded_sh.append(sh) else: broadcast_sh.extend([reps, sh]) expanded_sh.extend([1, sh]) exec_q = x.sycl_queue xdt = x.dtype xut = x.usm_type res = dpt.empty((out_sz,), dtype=xdt, usm_type=xut, sycl_queue=exec_q) # no need to copy data for empty output if out_sz > 0: x = dpt.broadcast_to( # this reshape should never copy dpt.reshape(x, expanded_sh), broadcast_sh, ) # copy broadcast input into flat array hev, _ = ti._copy_usm_ndarray_for_reshape( src=x, dst=res, sycl_queue=exec_q ) hev.wait() return dpt.reshape(res, res_shape)