# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from typing import NamedTuple
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as du
from ._numpy_helper import normalize_axis_index
from ._tensor_sorting_impl import (
_argsort_ascending,
_argsort_descending,
_radix_argsort_ascending,
_radix_argsort_descending,
_radix_sort_ascending,
_radix_sort_descending,
_radix_sort_dtype_supported,
_sort_ascending,
_sort_descending,
_topk,
)
# Names exported by ``from <module> import *``. ``top_k`` is a public,
# documented function defined in this module, so it is listed alongside
# ``sort`` and ``argsort``.
__all__ = ["sort", "argsort", "top_k"]
def _get_mergesort_impl_fn(descending):
    """Pick the merge-sort implementation matching the requested order.

    Returns `_sort_descending` when `descending` is truthy and
    `_sort_ascending` otherwise.
    """
    if descending:
        return _sort_descending
    return _sort_ascending
def _get_radixsort_impl_fn(descending):
    """Pick the radix-sort implementation matching the requested order.

    Returns `_radix_sort_descending` when `descending` is truthy and
    `_radix_sort_ascending` otherwise.
    """
    if descending:
        return _radix_sort_descending
    return _radix_sort_ascending
def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
    """sort(x, axis=-1, descending=False, stable=True, kind=None)

    Returns a sorted copy of an input array `x`.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            sorts along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array is sorted in descending
            order (by value). If `False`, the array is sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain that relative order.
            This function does not consult the value when choosing an
            implementation. Default: `True`.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            sorting algorithm. `None` is treated as `"stable"`, which
            selects radix sort for the `bool`, `uint8`, `int8`, `int16`
            and `uint16` data types and merge sort for every other data
            type.

    Returns:
        usm_ndarray:
            a sorted array. The returned array has the same data type and
            the same shape as the input array `x`. For a zero-dimensional
            input a C-contiguous copy of `x` is returned once `axis` has
            been validated against a one-dimensional shape.

    Raises:
        TypeError:
            if `x` is not a `usm_ndarray`.
        ValueError:
            if `kind` is not one of the supported values, or if
            `"radixsort"` is requested for a data type that
            `_radix_sort_dtype_supported` rejects.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )

    ndim = x.ndim
    if ndim == 0:
        # Nothing to reorder; the axis is still validated as if the array
        # had a single dimension.
        normalize_axis_index(axis, ndim=1, msg_prefix="axis")
        return dpt.copy(x, order="C")

    axis = normalize_axis_index(axis, ndim=ndim, msg_prefix="axis")
    axis_is_last = axis == ndim - 1

    # The implementation functions sort the trailing dimension, so move the
    # requested axis to the end when it is not already there.
    if axis_is_last:
        to_last = None
        view = x
    else:
        to_last = [d for d in range(ndim) if d != axis]
        to_last.append(axis)
        view = dpt.permute_dims(x, to_last)

    if kind is None:
        kind = "stable"
    known_kinds = ("stable", "radixsort", "mergesort")
    if not (isinstance(kind, str) and kind in known_kinds):
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{kind}'"
        )

    if kind == "mergesort":
        use_radix = False
    elif kind == "radixsort":
        if not _radix_sort_dtype_supported(x.dtype.num):
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
        use_radix = True
    else:
        # "stable": radix sort for the narrow integral types only.
        narrow_types = [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]
        use_radix = x.dtype in narrow_types

    if use_radix:
        impl_fn = _get_radixsort_impl_fn(descending)
    else:
        impl_fn = _get_mergesort_impl_fn(descending)

    exec_q = x.sycl_queue
    order_manager = du.SequentialOrderManager[exec_q]
    pending_events = order_manager.submitted_events

    if view.flags.c_contiguous:
        kernel_src = view
        kernel_deps = pending_events
    else:
        # Stage a C-contiguous copy and make the sort wait for that copy.
        kernel_src = dpt.empty_like(view, order="C")
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=view, dst=kernel_src, sycl_queue=exec_q, depends=pending_events
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        kernel_deps = [copy_ev]

    out = dpt.empty_like(view, order="C")
    host_ev, sort_ev = impl_fn(
        src=kernel_src,
        trailing_dims_to_sort=1,
        dst=out,
        sycl_queue=exec_q,
        depends=kernel_deps,
    )
    order_manager.add_event_pair(host_ev, sort_ev)

    if axis_is_last:
        return out

    # Undo the permutation so the result has the layout of `x`.
    restore = [0] * ndim
    for position, dim in enumerate(to_last):
        restore[dim] = position
    return dpt.permute_dims(out, restore)
def _get_mergeargsort_impl_fn(descending):
    """Pick the merge-sort based argsort implementation for the given order.

    Returns `_argsort_descending` when `descending` is truthy and
    `_argsort_ascending` otherwise.
    """
    if descending:
        return _argsort_descending
    return _argsort_ascending
def _get_radixargsort_impl_fn(descending):
    """Pick the radix-sort based argsort implementation for the given order.

    Returns `_radix_argsort_descending` when `descending` is truthy and
    `_radix_argsort_ascending` otherwise.
    """
    if descending:
        return _radix_argsort_descending
    return _radix_argsort_ascending
def argsort(x, axis=-1, descending=False, stable=True, kind=None):
    """argsort(x, axis=-1, descending=False, stable=True, kind=None)

    Returns the indices that sort an array `x` along a specified axis.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            sorts along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array is sorted in descending
            order (by value). If `False`, the array is sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain that relative order.
            This function does not consult the value when choosing an
            implementation. Default: `True`.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            sorting algorithm. `None` is treated as `"stable"`, which
            selects radix sort for the `bool`, `uint8`, `int8`, `int16`
            and `uint16` data types and merge sort for every other data
            type.

    Returns:
        usm_ndarray:
            an array of indices. The returned array has the same shape as
            the input array `x` and the default array index data type of
            the device. For a zero-dimensional input an array of zeros is
            returned once `axis` has been validated against a
            one-dimensional shape.

    Raises:
        TypeError:
            if `x` is not a `usm_ndarray`.
        ValueError:
            if `kind` is not one of the supported values, or if
            `"radixsort"` is requested for a data type that
            `_radix_sort_dtype_supported` rejects.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )

    ndim = x.ndim
    if ndim == 0:
        # The only element sits at index 0; the axis is still validated as
        # if the array had a single dimension.
        normalize_axis_index(axis, ndim=1, msg_prefix="axis")
        return dpt.zeros_like(
            x, dtype=ti.default_device_index_type(x.sycl_queue), order="C"
        )

    axis = normalize_axis_index(axis, ndim=ndim, msg_prefix="axis")
    axis_is_last = axis == ndim - 1

    # The implementation functions sort the trailing dimension, so move the
    # requested axis to the end when it is not already there.
    if axis_is_last:
        to_last = None
        view = x
    else:
        to_last = [d for d in range(ndim) if d != axis]
        to_last.append(axis)
        view = dpt.permute_dims(x, to_last)

    if kind is None:
        kind = "stable"
    known_kinds = ("stable", "radixsort", "mergesort")
    if not (isinstance(kind, str) and kind in known_kinds):
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{kind}'"
        )

    if kind == "mergesort":
        use_radix = False
    elif kind == "radixsort":
        if not _radix_sort_dtype_supported(x.dtype.num):
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
        use_radix = True
    else:
        # "stable": radix sort for the narrow integral types only.
        narrow_types = [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]
        use_radix = x.dtype in narrow_types

    if use_radix:
        impl_fn = _get_radixargsort_impl_fn(descending)
    else:
        impl_fn = _get_mergeargsort_impl_fn(descending)

    exec_q = x.sycl_queue
    order_manager = du.SequentialOrderManager[exec_q]
    pending_events = order_manager.submitted_events
    index_dt = ti.default_device_index_type(exec_q)

    if view.flags.c_contiguous:
        kernel_src = view
        kernel_deps = pending_events
    else:
        # Stage a C-contiguous copy and make the sort wait for that copy.
        kernel_src = dpt.empty_like(view, order="C")
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=view, dst=kernel_src, sycl_queue=exec_q, depends=pending_events
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        kernel_deps = [copy_ev]

    out = dpt.empty_like(view, dtype=index_dt, order="C")
    host_ev, sort_ev = impl_fn(
        src=kernel_src,
        trailing_dims_to_sort=1,
        dst=out,
        sycl_queue=exec_q,
        depends=kernel_deps,
    )
    order_manager.add_event_pair(host_ev, sort_ev)

    if axis_is_last:
        return out

    # Undo the permutation so the result has the layout of `x`.
    restore = [0] * ndim
    for position, dim in enumerate(to_last):
        restore[dim] = position
    return dpt.permute_dims(out, restore)
def _get_top_k_largest(mode):
modes = {"largest": True, "smallest": False}
try:
return modes[mode]
except KeyError:
raise ValueError(
f"`mode` must be `largest` or `smallest`. Got `{mode}`."
)
class TopKResult(NamedTuple):
    """Pair of arrays returned by `top_k`."""

    # The selected elements.
    values: dpt.usm_ndarray
    # Indices associated with the selected elements.
    indices: dpt.usm_ndarray
def top_k(x, k, /, *, axis=None, mode="largest"):
    """top_k(x, k, axis=None, mode="largest")

    Returns the `k` largest or smallest values and their indices in the input
    array `x` along the specified axis `axis`.

    Args:
        x (usm_ndarray):
            input array.
        k (int):
            number of elements to find. Must be a positive integer value
            that does not exceed the number of elements searched.
        axis (Optional[int]):
            axis along which to search. If `None`, the search is performed
            over the flattened array. Default: ``None``.
        mode (Literal["largest", "smallest"]):
            search mode. Must be one of the following modes:

            - `"largest"`: return the `k` largest elements.
            - `"smallest"`: return the `k` smallest elements.

            Default: `"largest"`.

    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, indices)` whose

            * first element `values` is an array containing the `k`
              largest or smallest elements of `x`. The array has the same
              data type as `x`. If `axis` was `None`, `values` is a
              one-dimensional array with shape `(k,)`; otherwise `values`
              has shape `x.shape[:axis] + (k,) + x.shape[axis+1:]`
            * second element `indices` is an array containing indices of
              `x` that result in `values`. The array has the same shape
              as `values` and the default array index data type.

    Raises:
        TypeError:
            if `x` is not a `usm_ndarray`.
        ValueError:
            if `mode` is not recognized, if `k` is negative, or if `k`
            exceeds the number of elements searched.
    """
    largest = _get_top_k_largest(mode)
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )

    k = operator.index(k)
    # NOTE(review): the guard lets `k == 0` through although the message
    # says "positive" -- confirm which of the two is intended.
    if k < 0:
        raise ValueError("`k` must be a positive integer value")

    ndim = x.ndim
    if axis is None:
        if ndim == 0:
            # A zero-dimensional array holds exactly one element.
            if k > 1:
                raise ValueError(f"`k`={k} is out of bounds 1")
            return TopKResult(
                dpt.copy(x, order="C"),
                dpt.zeros_like(
                    x, dtype=ti.default_device_index_type(x.sycl_queue)
                ),
            )
        search_len = x.size
        view = x
        to_last = None
        restore_axis = False
        n_search_dims = None
        out_shape = k
    else:
        axis = normalize_axis_index(axis, ndim=ndim, msg_prefix="axis")
        search_len = x.shape[axis]
        restore_axis = axis != ndim - 1
        if restore_axis:
            # Move the searched axis to the end for the implementation.
            to_last = [d for d in range(ndim) if d != axis]
            to_last.append(axis)
            view = dpt.permute_dims(x, to_last)
        else:
            to_last = None
            view = x
        n_search_dims = 1
        out_shape = view.shape[: ndim - 1] + (k,)

    if k > search_len:
        raise ValueError(f"`k`={k} is out of bounds {search_len}")

    exec_q = x.sycl_queue
    order_manager = du.SequentialOrderManager[exec_q]
    pending_events = order_manager.submitted_events
    out_usm_type = view.usm_type

    if view.flags.c_contiguous:
        kernel_src = view
        kernel_deps = pending_events
    else:
        # Stage a C-contiguous copy and make the search wait for that copy.
        kernel_src = dpt.empty_like(view, order="C")
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=view, dst=kernel_src, sycl_queue=exec_q, depends=pending_events
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        kernel_deps = [copy_ev]

    vals = dpt.empty(
        out_shape,
        dtype=view.dtype,
        usm_type=out_usm_type,
        order="C",
        sycl_queue=exec_q,
    )
    inds = dpt.empty(
        out_shape,
        dtype=ti.default_device_index_type(exec_q),
        usm_type=out_usm_type,
        order="C",
        sycl_queue=exec_q,
    )
    host_ev, topk_ev = _topk(
        src=kernel_src,
        trailing_dims_to_search=n_search_dims,
        k=k,
        largest=largest,
        vals=vals,
        inds=inds,
        sycl_queue=exec_q,
        depends=kernel_deps,
    )
    order_manager.add_event_pair(host_ev, topk_ev)

    if restore_axis:
        # Undo the permutation so the results follow the layout of `x`.
        restore = [0] * ndim
        for position, dim in enumerate(to_last):
            restore[dim] = position
        vals = dpt.permute_dims(vals, restore)
        inds = dpt.permute_dims(inds, restore)

    return TopKResult(vals, inds)