# Source code for dpctl.tensor._reshape

#                       Data Parallel Control (dpctl)
#
#  Copyright 2020-2024 Intel Corporation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
import math
import operator

import numpy as np

import dpctl.tensor as dpt
from dpctl.tensor._tensor_impl import (
    _copy_usm_ndarray_for_reshape,
    _ravel_multi_index,
    _unravel_index,
)

# Module docstring, set by explicit assignment instead of a leading string
# literal.
__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."


def _make_unit_indexes(shape):
    """
    Build a square ``uint32`` matrix with one row per axis of ``shape``.

    Entry ``[k, k]`` is one when ``shape[k] > 1`` and zero otherwise;
    every off-diagonal entry is zero.
    """
    diagonal = np.array([int(extent > 1) for extent in shape], dtype="u4")
    return np.diag(diagonal)


def ti_unravel_index(flat_index, shape, order="C"):
    """
    Forward to ``_unravel_index`` from ``dpctl.tensor._tensor_impl``.

    The three arguments are passed through positionally and the result
    of that call is returned unchanged.
    """
    multi_index = _unravel_index(flat_index, shape, order)
    return multi_index


def ti_ravel_multi_index(multi_index, shape, order="C"):
    """
    Forward to ``_ravel_multi_index`` from ``dpctl.tensor._tensor_impl``.

    The three arguments are passed through positionally and the result
    of that call is returned unchanged.
    """
    flat_index = _ravel_multi_index(multi_index, shape, order)
    return flat_index


def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
    """
    Compute strides that let an array of shape `old_sh` and strides
    `old_sts` be viewed with shape `new_sh`.

    Candidate strides for `new_sh` are derived from `old_sts`, then
    mapped back onto `old_sh`. If the round trip reproduces `old_sts`
    on every axis of `old_sh` whose extent is not 1, the candidate list
    is returned; otherwise `None` is returned.
    """

    def _strides_through(src_sts, src_sh, dst_sh):
        # Flat index of each row of the unit-index matrix built for
        # `dst_sh`, taken with respect to `dst_sh` itself.
        step_offsets = [
            ti_ravel_multi_index(unit_mi, dst_sh, order=order)
            for unit_mi in _make_unit_indexes(dst_sh)
        ]
        derived = []
        for offset in step_offsets:
            # Re-express the flat index in `src_sh` and weigh it by the
            # source strides.
            src_mi = ti_unravel_index(offset, src_sh, order=order)
            derived.append(sum(st * ix for st, ix in zip(src_sts, src_mi)))
        return derived

    candidate = _strides_through(old_sts, old_sh, new_sh)
    round_trip = _strides_through(candidate, new_sh, old_sh)
    for back_st, orig_st, extent in zip(round_trip, old_sts, old_sh):
        # Axes of extent 1 are exempt from the stride comparison.
        if back_st != orig_st and extent != 1:
            return None
    return candidate


def reshape(X, /, shape, *, order="C", copy=None):
    """reshape(x, shape, order="C", copy=None)

    Reshapes array ``x`` into new shape.

    Args:
        x (usm_ndarray):
            input array
        shape (Tuple[int]):
            the desired shape of the resulting array. A value that is
            neither a list nor a tuple is treated as a one-element
            shape. At most one entry may be ``-1``; it is replaced by
            the extent inferred from the size of ``x``.
        order ("C", "F", optional):
            memory layout of the resulting array
            if a copy is found to be necessary. Supported
            choices are ``"C"`` for C-contiguous, or row-major layout;
            and ``"F"`` for F-contiguous, or column-major layout.
            Lower-case spellings are accepted.
        copy (bool, optional):
            ``True`` always copies the data; ``False`` never copies and
            raises if the reshape can not be expressed as a view;
            ``None`` (default) copies only when a view is not possible.

    Returns:
        usm_ndarray:
            Reshaped array is a view, if possible,
            and a copy otherwise with memory layout as indicated
            by ``order`` keyword.

    Raises:
        TypeError:
            if ``x`` is not a ``usm_ndarray``, if ``order`` is not a
            string, or if an entry of ``shape`` is not an integer.
        ValueError:
            if ``order`` or ``copy`` is not recognized, if ``shape``
            has an entry below ``-1`` or more than one ``-1``, if the
            requested shape is incompatible with the size of ``x``, or
            if a copy is required while ``copy=False``.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(
            f"Expected array of type dpctl.tensor.usm_ndarray, got {type(X)}"
        )
    if not isinstance(shape, (list, tuple)):
        shape = (shape,)
    if not isinstance(order, str):
        # Non-string values raised TypeError from the former substring
        # test; keep that exception type, now with a message.
        raise TypeError(
            f"Keyword 'order' must be a string, got {type(order)}"
        )
    # Compare against the individual spellings. Testing membership in the
    # string "cfCF" would also accept "", "cf", "fC" and "CF".
    if order in ("C", "F", "c", "f"):
        order = order.upper()
    else:
        raise ValueError(
            f"Keyword 'order' not recognized. Expecting 'C' or 'F', got {order}"
        )
    if copy not in (True, False, None):
        raise ValueError(
            f"Keyword 'copy' not recognized. Expecting True, False, "
            f"or None, got {copy}"
        )
    if copy is not None:
        # The check above admits values merely equal to True/False (such
        # as 1 or 0); normalize them so the identity tests below honor
        # the request instead of silently treating it as None.
        copy = bool(copy)
    shape = [operator.index(d) for d in shape]
    if any(d < -1 for d in shape) or shape.count(-1) > 1:
        raise ValueError(
            "Target shape should have at most 1 negative "
            "value which can only be -1"
        )
    if -1 in shape:
        # With exactly one -1 present, negating the product yields the
        # product of the explicitly given extents. math.prod works on
        # Python integers, so it can not overflow and keeps the entries
        # of ``shape`` plain ints.
        known_sz = -math.prod(shape)
        if known_sz == 0:
            raise ValueError(
                f"Can not reshape array of size {X.size} into "
                f"shape {tuple(i for i in shape if i >= 0)}"
            )
        inferred = X.size // known_sz
        shape = [inferred if d == -1 else d for d in shape]
    if X.size != math.prod(shape):
        raise ValueError(f"Can not reshape into {shape}")
    if X.size:
        newsts = reshaped_strides(X.shape, X.strides, shape, order=order)
    else:
        newsts = (1,) * len(shape)
    # reshaped_strides returns None when no view can express the reshape
    copy_required = newsts is None
    if copy_required and (copy is False):
        raise ValueError(
            "Reshaping the array requires a copy, but no copying was "
            "requested by using copy=False"
        )
    copy_q = X.sycl_queue
    if copy_required or (copy is True):
        # must perform a copy
        flat_res = dpt.usm_ndarray(
            (X.size,),
            dtype=X.dtype,
            buffer=X.usm_type,
            buffer_ctor_kwargs={"queue": copy_q},
        )
        if order == "C":
            hev, _ = _copy_usm_ndarray_for_reshape(
                src=X, dst=flat_res, sycl_queue=copy_q
            )
        else:
            # For "F" order the source is copied with its axes reversed
            X_t = dpt.permute_dims(X, range(X.ndim - 1, -1, -1))
            hev, _ = _copy_usm_ndarray_for_reshape(
                src=X_t, dst=flat_res, sycl_queue=copy_q
            )
        hev.wait()
        return dpt.usm_ndarray(
            tuple(shape), dtype=X.dtype, buffer=flat_res, order=order
        )
    # can form a view
    if (len(shape) == X.ndim) and all(
        s1 == s2 for s1, s2 in zip(shape, X.shape)
    ):
        # Shape is unchanged: hand back the input array itself
        return X
    return dpt.usm_ndarray(
        shape,
        dtype=X.dtype,
        buffer=X,
        strides=tuple(newsts),
        offset=X._element_offset,
    )