# Source code for dpctl.tensor._sorting

#                       Data Parallel Control (dpctl)
#
#  Copyright 2020-2024 Intel Corporation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as du

from ._numpy_helper import normalize_axis_index
from ._tensor_sorting_impl import (
    _argsort_ascending,
    _argsort_descending,
    _sort_ascending,
    _sort_descending,
)
from ._tensor_sorting_radix_impl import (
    _radix_argsort_ascending,
    _radix_argsort_descending,
    _radix_sort_ascending,
    _radix_sort_descending,
    _radix_sort_dtype_supported,
)

# Public names exported by ``from ... import *``: the two sorting entry points.
__all__ = ["sort", "argsort"]


def _get_mergesort_impl_fn(descending):
    """Pick the merge-sort kernel matching the requested sort order.

    Returns ``_sort_descending`` when ``descending`` is truthy and
    ``_sort_ascending`` otherwise.
    """
    if descending:
        return _sort_descending
    return _sort_ascending


def _get_radixsort_impl_fn(descending):
    """Pick the radix-sort kernel matching the requested sort order.

    Returns ``_radix_sort_descending`` when ``descending`` is truthy and
    ``_radix_sort_ascending`` otherwise.
    """
    if descending:
        return _radix_sort_descending
    return _radix_sort_ascending


def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
    """sort(x, axis=-1, descending=False, stable=True, kind=None)

    Returns a sorted copy of an input array `x`.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            must sort along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array must be sorted in descending
            order (by value). If `False`, the array must be sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain the relative order of
            `x` values which compare as equal. Default: `True`.
            Note: this function accepts `stable` but does not consult it;
            the algorithm is selected by `kind`.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            Sorting algorithm. The default is `"stable"`, which uses parallel
            merge-sort or parallel radix-sort algorithms depending on the
            array data type.

    Returns:
        usm_ndarray:
            a sorted array. The returned array has the same data type and
            the same shape as the input array `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
        ValueError: if `kind` is not one of `"stable"`, `"mergesort"`,
            `"radixsort"`, or if `kind="radixsort"` is requested for a data
            type radix sort does not support.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    nd = x.ndim
    # A 0-d array is validated as if it had a single axis.
    axis = normalize_axis_index(axis, ndim=max(nd, 1), msg_prefix="axis")

    if kind is None:
        kind = "stable"
    # Validate ``kind`` before the 0-d shortcut so an unsupported value is
    # rejected regardless of the input's dimensionality.
    if not isinstance(kind, str) or kind not in (
        "stable",
        "radixsort",
        "mergesort",
    ):
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{kind}'"
        )

    if nd == 0:
        return dpt.copy(x, order="C")

    # The kernels sort along the trailing dimension (trailing_dims_to_sort=1),
    # so move the requested axis to the end when it is not already there.
    move_axis = axis != nd - 1
    if move_axis:
        perm = [i for i in range(nd) if i != axis] + [axis]
        arr = dpt.permute_dims(x, perm)
    else:
        arr = x

    if kind == "mergesort":
        impl_fn = _get_mergesort_impl_fn(descending)
    elif kind == "radixsort":
        if _radix_sort_dtype_supported(x.dtype.num):
            impl_fn = _get_radixsort_impl_fn(descending)
        else:
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
    else:
        # "stable": radix sort for bool and 8/16-bit integer types,
        # merge sort for everything else.
        dt = x.dtype
        if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
            impl_fn = _get_radixsort_impl_fn(descending)
        else:
            impl_fn = _get_mergesort_impl_fn(descending)

    exec_q = x.sycl_queue
    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    if arr.flags.c_contiguous:
        src, impl_deps = arr, dep_evs
    else:
        # Non-C-contiguous input is first copied into a C-contiguous
        # temporary; the sort kernel then reads from that temporary.
        src = dpt.empty_like(arr, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        impl_deps = [copy_ev]

    res = dpt.empty_like(arr, order="C")
    ht_ev, impl_ev = impl_fn(
        src=src,
        trailing_dims_to_sort=1,
        dst=res,
        sycl_queue=exec_q,
        depends=impl_deps,
    )
    _manager.add_event_pair(ht_ev, impl_ev)

    if move_axis:
        # Undo the axis permutation so the result matches x's layout.
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(res, inv_perm)
    return res
def _get_mergeargsort_impl_fn(descending):
    """Pick the merge-sort based argsort kernel for the requested order.

    Returns ``_argsort_descending`` when ``descending`` is truthy and
    ``_argsort_ascending`` otherwise.
    """
    if descending:
        return _argsort_descending
    return _argsort_ascending


def _get_radixargsort_impl_fn(descending):
    """Pick the radix-sort based argsort kernel for the requested order.

    Returns ``_radix_argsort_descending`` when ``descending`` is truthy and
    ``_radix_argsort_ascending`` otherwise.
    """
    if descending:
        return _radix_argsort_descending
    return _radix_argsort_ascending
def argsort(x, axis=-1, descending=False, stable=True, kind=None):
    """argsort(x, axis=-1, descending=False, stable=True, kind=None)

    Returns the indices that sort an array `x` along a specified axis.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            must sort along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array must be sorted in descending
            order (by value). If `False`, the array must be sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain the relative order of
            `x` values which compare as equal. Default: `True`.
            Note: this function accepts `stable` but does not consult it;
            the algorithm is selected by `kind`.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            Sorting algorithm. The default is `"stable"`, which uses parallel
            merge-sort or parallel radix-sort algorithms depending on the
            array data type.

    Returns:
        usm_ndarray:
            an array of indices. The returned array has the  same shape as
            the input array `x`. The return array has default array index
            data type.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
        ValueError: if `kind` is not one of `"stable"`, `"mergesort"`,
            `"radixsort"`, or if `kind="radixsort"` is requested for a data
            type radix sort does not support.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    nd = x.ndim
    # A 0-d array is validated as if it had a single axis.
    axis = normalize_axis_index(axis, ndim=max(nd, 1), msg_prefix="axis")

    if kind is None:
        kind = "stable"
    # Validate ``kind`` before the 0-d shortcut so an unsupported value is
    # rejected regardless of the input's dimensionality.
    if not isinstance(kind, str) or kind not in (
        "stable",
        "radixsort",
        "mergesort",
    ):
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{kind}'"
        )

    if nd == 0:
        return dpt.zeros_like(
            x, dtype=ti.default_device_index_type(x.sycl_queue), order="C"
        )

    # The kernels sort along the trailing dimension (trailing_dims_to_sort=1),
    # so move the requested axis to the end when it is not already there.
    move_axis = axis != nd - 1
    if move_axis:
        perm = [i for i in range(nd) if i != axis] + [axis]
        arr = dpt.permute_dims(x, perm)
    else:
        arr = x

    if kind == "mergesort":
        impl_fn = _get_mergeargsort_impl_fn(descending)
    elif kind == "radixsort":
        if _radix_sort_dtype_supported(x.dtype.num):
            impl_fn = _get_radixargsort_impl_fn(descending)
        else:
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
    else:
        # "stable": radix sort for bool and 8/16-bit integer types,
        # merge sort for everything else.
        dt = x.dtype
        if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
            impl_fn = _get_radixargsort_impl_fn(descending)
        else:
            impl_fn = _get_mergeargsort_impl_fn(descending)

    exec_q = x.sycl_queue
    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    index_dt = ti.default_device_index_type(exec_q)

    if arr.flags.c_contiguous:
        src, impl_deps = arr, dep_evs
    else:
        # Non-C-contiguous input is first copied into a C-contiguous
        # temporary; the argsort kernel then reads from that temporary.
        src = dpt.empty_like(arr, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        impl_deps = [copy_ev]

    res = dpt.empty_like(arr, dtype=index_dt, order="C")
    ht_ev, impl_ev = impl_fn(
        src=src,
        trailing_dims_to_sort=1,
        dst=res,
        sycl_queue=exec_q,
        depends=impl_deps,
    )
    _manager.add_event_pair(ht_ev, impl_ev)

    if move_axis:
        # Undo the axis permutation so the result matches x's layout.
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(res, inv_perm)
    return res