Source code for dpctl.tensor._statistical_functions

#                       Data Parallel Control (dpctl)
#
#  Copyright 2020-2024 Intel Corporation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from numpy.core.numeric import normalize_axis_tuple

import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_elementwise_impl as tei
import dpctl.tensor._tensor_impl as ti
import dpctl.tensor._tensor_reductions_impl as tri


def _var_impl(x, axis, correction, keepdims):
    nd = x.ndim
    if axis is None:
        axis = tuple(range(nd))
    if not isinstance(axis, (tuple, list)):
        axis = (axis,)
    axis = normalize_axis_tuple(axis, nd, "axis")
    perm = []
    nelems = 1
    for i in range(nd):
        if i not in axis:
            perm.append(i)
        else:
            nelems *= x.shape[i]
    red_nd = len(axis)
    perm = perm + list(axis)
    q = x.sycl_queue
    inp_dt = x.dtype
    res_dt = (
        inp_dt
        if inp_dt.kind == "f"
        else dpt.dtype(ti.default_device_fp_type(q))
    )
    res_usm_type = x.usm_type

    deps = []
    host_tasks_list = []
    if inp_dt != res_dt:
        buf = dpt.empty_like(x, dtype=res_dt)
        ht_e_buf, c_e1 = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=buf, sycl_queue=q
        )
        deps.append(c_e1)
        host_tasks_list.append(ht_e_buf)
    else:
        buf = x
    # calculate mean
    buf2 = dpt.permute_dims(buf, perm)
    res_shape = buf2.shape[: nd - red_nd]
    # use keepdims=True path for later broadcasting
    if red_nd == 0:
        mean_ary = dpt.empty_like(buf)
        ht_e1, c_e2 = ti._copy_usm_ndarray_into_usm_ndarray(
            src=buf, dst=mean_ary, sycl_queue=q
        )
        deps.append(c_e2)
        host_tasks_list.append(ht_e1)
    else:
        mean_ary = dpt.empty(
            res_shape,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=q,
        )
        ht_e1, r_e1 = tri._sum_over_axis(
            src=buf2,
            trailing_dims_to_reduce=red_nd,
            dst=mean_ary,
            sycl_queue=q,
            depends=deps,
        )
        host_tasks_list.append(ht_e1)
        deps.append(r_e1)

        mean_ary_shape = res_shape + (1,) * red_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        mean_ary = dpt.permute_dims(
            dpt.reshape(mean_ary, mean_ary_shape), inv_perm
        )
    # divide in-place to get mean
    mean_ary_shape = mean_ary.shape
    nelems_ary = dpt.asarray(
        nelems, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
    )
    if nelems_ary.shape != mean_ary_shape:
        nelems_ary = dpt.broadcast_to(nelems_ary, mean_ary_shape)
    ht_e2, d_e1 = tei._divide_inplace(
        lhs=mean_ary, rhs=nelems_ary, sycl_queue=q, depends=deps
    )
    host_tasks_list.append(ht_e2)
    # subtract mean from original array to get deviations
    dev_ary = dpt.empty_like(buf)
    if mean_ary_shape != buf.shape:
        mean_ary = dpt.broadcast_to(mean_ary, buf.shape)
    ht_e4, su_e = tei._subtract(
        src1=buf, src2=mean_ary, dst=dev_ary, sycl_queue=q, depends=[d_e1]
    )
    host_tasks_list.append(ht_e4)
    # square deviations
    ht_e5, sq_e = tei._square(
        src=dev_ary, dst=dev_ary, sycl_queue=q, depends=[su_e]
    )
    host_tasks_list.append(ht_e5)
    deps2 = []
    # take sum of squared deviations
    dev_ary2 = dpt.permute_dims(dev_ary, perm)
    if red_nd == 0:
        res = dev_ary
        deps2.append(sq_e)
    else:
        res = dpt.empty(
            res_shape,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=q,
        )
        ht_e6, r_e2 = tri._sum_over_axis(
            src=dev_ary2,
            trailing_dims_to_reduce=red_nd,
            dst=res,
            sycl_queue=q,
            depends=[sq_e],
        )
        host_tasks_list.append(ht_e6)
        deps2.append(r_e2)

        if keepdims:
            res_shape = res_shape + (1,) * red_nd
            inv_perm = sorted(range(nd), key=lambda d: perm[d])
            res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
    res_shape = res.shape
    # when nelems - correction <= 0, yield nans
    div = max(nelems - correction, 0)
    if not div:
        div = dpt.nan
    div_ary = dpt.asarray(
        div, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
    )
    # divide in-place again
    if div_ary.shape != res_shape:
        div_ary = dpt.broadcast_to(div_ary, res.shape)
    ht_e7, d_e2 = tei._divide_inplace(
        lhs=res, rhs=div_ary, sycl_queue=q, depends=deps2
    )
    host_tasks_list.append(ht_e7)
    return res, [d_e2], host_tasks_list


[docs]def mean(x, axis=None, keepdims=False): """mean(x, axis=None, keepdims=False) Calculates the arithmetic mean of elements in the input array `x`. Args: x (usm_ndarray): input array. axis (Optional[int, Tuple[int, ...]]): axis or axes along which the arithmetic means must be computed. If a tuple of unique integers, the means are computed over multiple axes. If `None`, the mean is computed over the entire array. Default: `None`. keepdims (Optional[bool]): if `True`, the reduced axes (dimensions) are included in the result as singleton dimensions, so that the returned array remains compatible with the input array according to Array Broadcasting rules. Otherwise, if `False`, the reduced axes are not included in the returned array. Default: `False`. Returns: usm_ndarray: an array containing the arithmetic means. If the mean was computed over the entire array, a zero-dimensional array is returned. If `x` has a floating-point data type, the returned array will have the same data type as `x`. If `x` has a boolean or integral data type, the returned array will have the default floating point data type for the device where input array `x` is allocated. """ if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") nd = x.ndim if axis is None: axis = tuple(range(nd)) if not isinstance(axis, (tuple, list)): axis = (axis,) axis = normalize_axis_tuple(axis, nd, "axis") perm = [] nelems = 1 for i in range(nd): if i not in axis: perm.append(i) else: nelems *= x.shape[i] sum_nd = len(axis) perm = perm + list(axis) arr2 = dpt.permute_dims(x, perm) res_shape = arr2.shape[: nd - sum_nd] q = x.sycl_queue inp_dt = x.dtype res_dt = ( x.dtype if x.dtype.kind in "fc" else dpt.dtype(ti.default_device_fp_type(q)) ) res_usm_type = x.usm_type if sum_nd == 0: return dpt.astype(x, res_dt, copy=True) s_e = [] host_tasks_list = [] if tri._sum_over_axis_dtype_supported(inp_dt, res_dt, res_usm_type, q): res = dpt.empty( res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) ht_e1, r_e = tri._sum_over_axis( src=arr2, trailing_dims_to_reduce=sum_nd, dst=res, sycl_queue=q ) host_tasks_list.append(ht_e1) s_e.append(r_e) else: tmp = dpt.empty( arr2.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( src=arr2, dst=tmp, sycl_queue=q ) host_tasks_list.append(ht_e_cpy) res = dpt.empty( res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) ht_e_red, r_e = tri._sum_over_axis( src=tmp, trailing_dims_to_reduce=sum_nd, dst=res, sycl_queue=q, depends=[cpy_e], ) host_tasks_list.append(ht_e_red) s_e.append(r_e) if keepdims: res_shape = res_shape + (1,) * sum_nd inv_perm = sorted(range(nd), key=lambda d: perm[d]) res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm) res_shape = res.shape # in-place divide den_dt = dpt.finfo(res_dt).dtype if res_dt.kind == "c" else res_dt nelems_arr = dpt.asarray( nelems, dtype=den_dt, usm_type=res_usm_type, sycl_queue=q ) if nelems_arr.shape != res_shape: nelems_arr = dpt.broadcast_to(nelems_arr, res_shape) ht_e2, _ = tei._divide_inplace( lhs=res, rhs=nelems_arr, sycl_queue=q, depends=s_e ) host_tasks_list.append(ht_e2) dpctl.SyclEvent.wait_for(host_tasks_list) return res
[docs]def var(x, axis=None, correction=0.0, keepdims=False): """var(x, axis=None, correction=0.0, keepdims=False) Calculates the variance of elements in the input array `x`. Args: x (usm_ndarray): input array. axis (Optional[int, Tuple[int, ...]]): axis or axes along which the variances must be computed. If a tuple of unique integers, the variances are computed over multiple axes. If `None`, the variance is computed over the entire array. Default: `None`. correction (Optional[float, int]): degrees of freedom adjustment. The divisor used in calculating the variance is `N - correction`, where `N` corresponds to the total number of elements over which the variance is calculated. Default: `0.0`. keepdims (Optional[bool]): if `True`, the reduced axes (dimensions) are included in the result as singleton dimensions, so that the returned array remains compatible with the input array according to Array Broadcasting rules. Otherwise, if `False`, the reduced axes are not included in the returned array. Default: `False`. Returns: usm_ndarray: an array containing the variances. If the variance was computed over the entire array, a zero-dimensional array is returned. If `x` has a real-valued floating-point data type, the returned array will have the same data type as `x`. If `x` has a boolean or integral data type, the returned array will have the default floating point data type for the device where input array `x` is allocated. """ if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") if not isinstance(correction, (int, float)): raise TypeError( "Expected a Python integer or float for `correction`, got" f"{type(x)}" ) if x.dtype.kind == "c": raise ValueError("`var` does not support complex types") res, _, host_tasks_list = _var_impl(x, axis, correction, keepdims) dpctl.SyclEvent.wait_for(host_tasks_list) return res
[docs]def std(x, axis=None, correction=0.0, keepdims=False): """std(x, axis=None, correction=0.0, keepdims=False) Calculates the standard deviation of elements in the input array `x`. Args: x (usm_ndarray): input array. axis (Optional[int, Tuple[int, ...]]): axis or axes along which the standard deviations must be computed. If a tuple of unique integers, the standard deviations are computed over multiple axes. If `None`, the standard deviation is computed over the entire array. Default: `None`. correction (Optional[float, int]): degrees of freedom adjustment. The divisor used in calculating the standard deviation is `N - correction`, where `N` corresponds to the total number of elements over which the standard deviation is calculated. Default: `0.0`. keepdims (Optional[bool]): if `True`, the reduced axes (dimensions) are included in the result as singleton dimensions, so that the returned array remains compatible with the input array according to Array Broadcasting rules. Otherwise, if `False`, the reduced axes are not included in the returned array. Default: `False`. Returns: usm_ndarray: an array containing the standard deviations. If the standard deviation was computed over the entire array, a zero-dimensional array is returned. If `x` has a real-valued floating-point data type, the returned array will have the same data type as `x`. If `x` has a boolean or integral data type, the returned array will have the default floating point data type for the device where input array `x` is allocated. """ if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") if not isinstance(correction, (int, float)): raise TypeError( "Expected a Python integer or float for `correction`," f"got {type(x)}" ) if x.dtype.kind == "c": raise ValueError("`std` does not support complex types") res, deps, host_tasks_list = _var_impl(x, axis, correction, keepdims) ht_ev, _ = tei._sqrt( src=res, dst=res, sycl_queue=res.sycl_queue, depends=deps ) host_tasks_list.append(ht_ev) dpctl.SyclEvent.wait_for(host_tasks_list) return res