#                       Data Parallel Control (dpctl)
#
#  Copyright 2020-2023 Intel Corporation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from numpy.core.numeric import normalize_axis_tuple
import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from ._type_utils import _to_device_supported_dtype
def _default_reduction_dtype(inp_dt, q):
    """Gives default output data type for given input data
    type `inp_dt` when reduction is performed on queue `q`
    """
    inp_kind = inp_dt.kind
    if inp_kind in "bi":
        res_dt = dpt.dtype(ti.default_device_int_type(q))
        if inp_dt.itemsize > res_dt.itemsize:
            res_dt = inp_dt
    elif inp_kind in "u":
        res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
        res_ii = dpt.iinfo(res_dt)
        inp_ii = dpt.iinfo(inp_dt)
        if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
            pass
        else:
            res_dt = inp_dt
    elif inp_kind in "f":
        res_dt = dpt.dtype(ti.default_device_fp_type(q))
        if res_dt.itemsize < inp_dt.itemsize:
            res_dt = inp_dt
    elif inp_kind in "c":
        res_dt = dpt.dtype(ti.default_device_complex_type(q))
        if res_dt.itemsize < inp_dt.itemsize:
            res_dt = inp_dt
    return res_dt
def _reduction_over_axis(
    x,
    axis,
    dtype,
    keepdims,
    _reduction_fn,
    _dtype_supported,
    _default_reduction_type_fn,
    _identity=None,
):
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    nd = x.ndim
    if axis is None:
        axis = tuple(range(nd))
    if not isinstance(axis, (tuple, list)):
        axis = (axis,)
    axis = normalize_axis_tuple(axis, nd, "axis")
    red_nd = len(axis)
    perm = [i for i in range(nd) if i not in axis] + list(axis)
    arr2 = dpt.permute_dims(x, perm)
    res_shape = arr2.shape[: nd - red_nd]
    q = x.sycl_queue
    inp_dt = x.dtype
    if dtype is None:
        res_dt = _default_reduction_type_fn(inp_dt, q)
    else:
        res_dt = dpt.dtype(dtype)
        res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
    res_usm_type = x.usm_type
    if x.size == 0:
        if _identity is None:
            raise ValueError("reduction does not support zero-size arrays")
        else:
            if keepdims:
                res_shape = res_shape + (1,) * red_nd
                inv_perm = sorted(range(nd), key=lambda d: perm[d])
                res_shape = tuple(res_shape[i] for i in inv_perm)
            return dpt.full(
                res_shape,
                _identity,
                dtype=res_dt,
                usm_type=res_usm_type,
                sycl_queue=q,
            )
    if red_nd == 0:
        return dpt.astype(x, res_dt, copy=False)
    host_tasks_list = []
    if _dtype_supported(inp_dt, res_dt, res_usm_type, q):
        res = dpt.empty(
            res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
        )
        ht_e, _ = _reduction_fn(
            src=arr2, trailing_dims_to_reduce=red_nd, dst=res, sycl_queue=q
        )
        host_tasks_list.append(ht_e)
    else:
        if dtype is None:
            raise RuntimeError(
                "Automatically determined reduction data type does not "
                "have direct implementation"
            )
        tmp_dt = _default_reduction_dtype(inp_dt, q)
        tmp = dpt.empty(
            res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
        )
        ht_e_tmp, r_e = _reduction_fn(
            src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q
        )
        host_tasks_list.append(ht_e_tmp)
        res = dpt.empty(
            res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
        )
        ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray(
            src=tmp, dst=res, sycl_queue=q, depends=[r_e]
        )
        host_tasks_list.append(ht_e)
    if keepdims:
        res_shape = res_shape + (1,) * red_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
    dpctl.SyclEvent.wait_for(host_tasks_list)
    return res
[docs]def sum(x, axis=None, dtype=None, keepdims=False):
    """sum(x, axis=None, dtype=None, keepdims=False)
    Calculates the sum of elements in the input array `x`.
    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes along which sums must be computed. If a tuple
            of unique integers, sums are computed over multiple axes.
            If `None`, the sum is computed over the entire array.
            Default: `None`.
        dtype (Optional[dtype]):
            data type of the returned array. If `None`, the default data
            type is inferred from the "kind" of the input array data type.
                * If `x` has a real-valued floating-point data type,
                  the returned array will have the default real-valued
                  floating-point data type for the device where input
                  array `x` is allocated.
                * If x` has signed integral data type, the returned array
                  will have the default signed integral type for the device
                  where input array `x` is allocated.
                * If `x` has unsigned integral data type, the returned array
                  will have the default unsigned integral type for the device
                  where input array `x` is allocated.
                * If `x` has a complex-valued floating-point data typee,
                  the returned array will have the default complex-valued
                  floating-pointer data type for the device where input
                  array `x` is allocated.
                * If `x` has a boolean data type, the returned array will
                  have the default signed integral type for the device
                  where input array `x` is allocated.
            If the data type (either specified or resolved) differs from the
            data type of `x`, the input array elements are cast to the
            specified data type before computing the sum. Default: `None`.
        keepdims (Optional[bool]):
            if `True`, the reduced axes (dimensions) are included in the result
            as singleton dimensions, so that the returned array remains
            compatible with the input arrays according to Array Broadcasting
            rules. Otherwise, if `False`, the reduced axes are not included in
            the returned array. Default: `False`.
    Returns:
        usm_ndarray:
            an array containing the sums. If the sum was computed over the
            entire array, a zero-dimensional array is returned. The returned
            array has the data type as described in the `dtype` parameter
            description above.
    """
    return _reduction_over_axis(
        x,
        axis,
        dtype,
        keepdims,
        ti._sum_over_axis,
        ti._sum_over_axis_dtype_supported,
        _default_reduction_dtype,
        _identity=0,
    ) 
[docs]def prod(x, axis=None, dtype=None, keepdims=False):
    """prod(x, axis=None, dtype=None, keepdims=False)
    Calculates the product of elements in the input array `x`.
    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes along which products must be computed. If a tuple
            of unique integers, products are computed over multiple axes.
            If `None`, the product is computed over the entire array.
            Default: `None`.
        dtype (Optional[dtype]):
            data type of the returned array. If `None`, the default data
            type is inferred from the "kind" of the input array data type.
                * If `x` has a real-valued floating-point data type,
                  the returned array will have the default real-valued
                  floating-point data type for the device where input
                  array `x` is allocated.
                * If x` has signed integral data type, the returned array
                  will have the default signed integral type for the device
                  where input array `x` is allocated.
                * If `x` has unsigned integral data type, the returned array
                  will have the default unsigned integral type for the device
                  where input array `x` is allocated.
                * If `x` has a complex-valued floating-point data typee,
                  the returned array will have the default complex-valued
                  floating-pointer data type for the device where input
                  array `x` is allocated.
                * If `x` has a boolean data type, the returned array will
                  have the default signed integral type for the device
                  where input array `x` is allocated.
            If the data type (either specified or resolved) differs from the
            data type of `x`, the input array elements are cast to the
            specified data type before computing the product. Default: `None`.
        keepdims (Optional[bool]):
            if `True`, the reduced axes (dimensions) are included in the result
            as singleton dimensions, so that the returned array remains
            compatible with the input arrays according to Array Broadcasting
            rules. Otherwise, if `False`, the reduced axes are not included in
            the returned array. Default: `False`.
    Returns:
        usm_ndarray:
            an array containing the products. If the product was computed over
            the entire array, a zero-dimensional array is returned. The returned
            array has the data type as described in the `dtype` parameter
            description above.
    """
    return _reduction_over_axis(
        x,
        axis,
        dtype,
        keepdims,
        ti._prod_over_axis,
        ti._prod_over_axis_dtype_supported,
        _default_reduction_dtype,
        _identity=1,
    ) 
def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    nd = x.ndim
    if axis is None:
        axis = tuple(range(nd))
    if not isinstance(axis, (tuple, list)):
        axis = (axis,)
    axis = normalize_axis_tuple(axis, nd, "axis")
    red_nd = len(axis)
    perm = [i for i in range(nd) if i not in axis] + list(axis)
    x_tmp = dpt.permute_dims(x, perm)
    res_shape = x_tmp.shape[: nd - red_nd]
    exec_q = x.sycl_queue
    res_dt = x.dtype
    res_usm_type = x.usm_type
    if x.size == 0:
        raise ValueError("reduction does not support zero-size arrays")
    if red_nd == 0:
        return x
    res = dpt.empty(
        res_shape,
        dtype=res_dt,
        usm_type=res_usm_type,
        sycl_queue=exec_q,
    )
    hev, _ = _reduction_fn(
        src=x_tmp,
        trailing_dims_to_reduce=red_nd,
        dst=res,
        sycl_queue=exec_q,
    )
    if keepdims:
        res_shape = res_shape + (1,) * red_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
    hev.wait()
    return res
[docs]def max(x, axis=None, keepdims=False):
    """max(x, axis=None, dtype=None, keepdims=False)
    Calculates the maximum value of the input array `x`.
    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes along which maxima must be computed. If a tuple
            of unique integers, the maxima are computed over multiple axes.
            If `None`, the max is computed over the entire array.
            Default: `None`.
        keepdims (Optional[bool]):
            if `True`, the reduced axes (dimensions) are included in the result
            as singleton dimensions, so that the returned array remains
            compatible with the input arrays according to Array Broadcasting
            rules. Otherwise, if `False`, the reduced axes are not included in
            the returned array. Default: `False`.
    Returns:
        usm_ndarray:
            an array containing the maxima. If the max was computed over the
            entire array, a zero-dimensional array is returned. The returned
            array has the same data type as `x`.
    """
    return _comparison_over_axis(x, axis, keepdims, ti._max_over_axis) 
[docs]def min(x, axis=None, keepdims=False):
    """min(x, axis=None, dtype=None, keepdims=False)
    Calculates the minimum value of the input array `x`.
    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes along which minima must be computed. If a tuple
            of unique integers, the minima are computed over multiple axes.
            If `None`, the min is computed over the entire array.
            Default: `None`.
        keepdims (Optional[bool]):
            if `True`, the reduced axes (dimensions) are included in the result
            as singleton dimensions, so that the returned array remains
            compatible with the input arrays according to Array Broadcasting
            rules. Otherwise, if `False`, the reduced axes are not included in
            the returned array. Default: `False`.
    Returns:
        usm_ndarray:
            an array containing the minima. If the min was computed over the
            entire array, a zero-dimensional array is returned. The returned
            array has the same data type as `x`.
    """
    return _comparison_over_axis(x, axis, keepdims, ti._min_over_axis) 
def _search_over_axis(x, axis, keepdims, _reduction_fn):
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    nd = x.ndim
    if axis is None:
        axis = tuple(range(nd))
    elif isinstance(axis, int):
        axis = (axis,)
    else:
        raise TypeError(
            f"`axis` argument expected `int` or `None`, got {type(axis)}"
        )
    axis = normalize_axis_tuple(axis, nd, "axis")
    red_nd = len(axis)
    perm = [i for i in range(nd) if i not in axis] + list(axis)
    x_tmp = dpt.permute_dims(x, perm)
    res_shape = x_tmp.shape[: nd - red_nd]
    exec_q = x.sycl_queue
    res_dt = ti.default_device_index_type(exec_q.sycl_device)
    res_usm_type = x.usm_type
    if x.size == 0:
        raise ValueError("reduction does not support zero-size arrays")
    if red_nd == 0:
        return dpt.zeros(
            res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
        )
    res = dpt.empty(
        res_shape,
        dtype=res_dt,
        usm_type=res_usm_type,
        sycl_queue=exec_q,
    )
    hev, _ = _reduction_fn(
        src=x_tmp,
        trailing_dims_to_reduce=red_nd,
        dst=res,
        sycl_queue=exec_q,
    )
    if keepdims:
        res_shape = res_shape + (1,) * red_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
    hev.wait()
    return res
[docs]def argmax(x, axis=None, keepdims=False):
    """argmax(x, axis=None, dtype=None, keepdims=False)
    Returns the indices of the maximum values of the input array `x` along a
    specified axis.
    When the maximum value occurs multiple times, the indices corresponding to
    the first occurrence are returned.
    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to search. If `None`, returns the index of the
            maximum value of the flattened array.
            Default: `None`.
        keepdims (Optional[bool]):
            if `True`, the reduced axes (dimensions) are included in the result
            as singleton dimensions, so that the returned array remains
            compatible with the input arrays according to Array Broadcasting
            rules. Otherwise, if `False`, the reduced axes are not included in
            the returned array. Default: `False`.
    Returns:
        usm_ndarray:
            an array containing the indices of the first occurrence of the
            maximum values. If the entire array was searched, a
            zero-dimensional array is returned. The returned array has the
            default array index data type for the device of `x`.
    """
    return _search_over_axis(x, axis, keepdims, ti._argmax_over_axis) 
[docs]def argmin(x, axis=None, keepdims=False):
    """argmin(x, axis=None, dtype=None, keepdims=False)
    Returns the indices of the minimum values of the input array `x` along a
    specified axis.
    When the minimum value occurs multiple times, the indices corresponding to
    the first occurrence are returned.
    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to search. If `None`, returns the index of the
            minimum value of the flattened array.
            Default: `None`.
        keepdims (Optional[bool]):
            if `True`, the reduced axes (dimensions) are included in the result
            as singleton dimensions, so that the returned array remains
            compatible with the input arrays according to Array Broadcasting
            rules. Otherwise, if `False`, the reduced axes are not included in
            the returned array. Default: `False`.
    Returns:
        usm_ndarray:
            an array containing the indices of the first occurrence of the
            minimum values. If the entire array was searched, a
            zero-dimensional array is returned. The returned array has the
            default array index data type for the device of `x`.
    """
    return _search_over_axis(x, axis, keepdims, ti._argmin_over_axis)