# Source code for dpctl.tensor._searchsorted

from typing import Literal, Union

import dpctl
import dpctl.utils as du

from ._copy_utils import _empty_like_orderK
from ._ctors import empty
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
from ._tensor_impl import _take as ti_take
from ._tensor_impl import (
    default_device_index_type as ti_default_device_index_type,
)
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
from ._type_utils import isdtype, result_type
from ._usmarray import usm_ndarray


def searchsorted(
    x1: usm_ndarray,
    x2: usm_ndarray,
    /,
    *,
    side: Literal["left", "right"] = "left",
    sorter: Union[usm_ndarray, None] = None,
) -> usm_ndarray:
    """searchsorted(x1, x2, side='left', sorter=None)

    Finds the indices into `x1` such that, if the corresponding elements
    in `x2` were inserted before the indices, the order of `x1`, when sorted
    in ascending order, would be preserved.

    If `x1` and `x2` have different data types, both are cast to their
    common type (as given by `result_type`) before searching.

    Args:
        x1 (usm_ndarray):
            input array. Must be a one-dimensional array. If `sorter` is
            `None`, must be sorted in ascending order; otherwise, `sorter`
            must be an array of indices that sort `x1` in ascending order.
        x2 (usm_ndarray):
            array containing search values.
        side (Literal["left", "right"]):
            argument controlling which index is returned if a value lands
            exactly on an edge. If `x2` is an array of rank `N` where
            `v = x2[n, m, ..., j]`, the element `ret[n, m, ..., j]` in the
            return array `ret` contains the position `i` such that
            if `side="left"`, it is the first index such that
            `x1[i-1] < v <= x1[i]`, `0` if `v <= x1[0]`, and `x1.size`
            if `v > x1[-1]`;
            and if `side="right"`, it is the first position `i` such that
            `x1[i-1] <= v < x1[i]`, `0` if `v < x1[0]`, and `x1.size`
            if `v >= x1[-1]`. Default: `"left"`.
        sorter (Optional[usm_ndarray]):
            array of indices that sort `x1` in ascending order. The array
            must have the same shape as `x1` and have an integral data type.
            Out of bound index values of `sorter` array are treated using
            `"wrap"` mode documented in :py:func:`dpctl.tensor.take`.
            Default: `None`.

    Returns:
        usm_ndarray:
            array of insertion positions with the same shape as `x2`. Its
            data type is the default index type of the execution device, and
            its USM type is coerced from the USM types of `x1` and `x2`.

    Raises:
        TypeError:
            if `x1`, `x2` or a non-`None` `sorter` is not a `usm_ndarray`.
        ValueError:
            if `side` is neither `"left"` nor `"right"`, if `x1` is not
            one-dimensional, or if `sorter` has a non-integral data type or
            a shape different from `x1`.
        dpctl.utils.ExecutionPlacementError:
            if the execution queue cannot be inferred from the inputs.
    """
    if not isinstance(x1, usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
    if not isinstance(x2, usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
    if sorter is not None and not isinstance(sorter, usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}"
        )

    if side not in ["left", "right"]:
        raise ValueError(
            "Unrecognized value of 'side' keyword argument. "
            "Expected either 'left' or 'right'"
        )

    # All participating arrays must agree on a single execution queue.
    queues = [x1.sycl_queue, x2.sycl_queue]
    if sorter is not None:
        queues.append(sorter.sycl_queue)
    q = du.get_execution_queue(queues)
    if q is None:
        raise du.ExecutionPlacementError(
            "Execution placement can not be unambiguously "
            "inferred from input arguments."
        )

    if x1.ndim != 1:
        raise ValueError("First argument array must be one-dimensional")

    x1_dt = x1.dtype
    x2_dt = x2.dtype

    # Host-side events returned by every submission; all of them are waited
    # on before the function returns.
    host_evs = []
    # Track readiness of the haystack (x1) and the needles (x2) separately:
    # their preparation steps write to disjoint buffers, so the x2 cast need
    # not wait for the x1 take/cast (the original chained them needlessly).
    x1_ev = dpctl.SyclEvent()
    x2_ev = dpctl.SyclEvent()

    if sorter is not None:
        if not isdtype(sorter.dtype, "integral"):
            raise ValueError(
                f"Sorter array must have integral data type, got {sorter.dtype}"
            )
        if x1.shape != sorter.shape:
            raise ValueError(
                "Sorter array must be one-dimension with the same "
                "shape as the first argument array"
            )
        # Materialize x1[sorter] so the search kernel operates on a sorted
        # haystack. Mode 0 is the "wrap" out-of-bound-index mode.
        res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
        wrap_out_of_bound_indices_mode = 0
        ht_ev, x1_ev = ti_take(
            x1,
            (sorter,),
            res,
            0,  # axis
            wrap_out_of_bound_indices_mode,
            sycl_queue=q,
            depends=[x1_ev],
        )
        host_evs.append(ht_ev)
        x1 = res

    if x1_dt != x2_dt:
        # Cast whichever operand(s) differ from the common data type.
        dt = result_type(x1, x2)
        if x1_dt != dt:
            x1_buf = _empty_like_orderK(x1, dt)
            ht_ev, x1_ev = ti_copy(
                src=x1, dst=x1_buf, sycl_queue=q, depends=[x1_ev]
            )
            host_evs.append(ht_ev)
            x1 = x1_buf
        if x2_dt != dt:
            x2_buf = _empty_like_orderK(x2, dt)
            ht_ev, x2_ev = ti_copy(
                src=x2, dst=x2_buf, sycl_queue=q, depends=[x2_ev]
            )
            host_evs.append(ht_ev)
            x2 = x2_buf

    dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
    index_dt = ti_default_device_index_type(q)
    dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)

    search_impl = _searchsorted_left if side == "left" else _searchsorted_right
    # The search needs both the prepared haystack and the prepared needles.
    ht_ev, _ = search_impl(
        hay=x1,
        needles=x2,
        positions=dst,
        sycl_queue=q,
        depends=[x1_ev, x2_ev],
    )
    host_evs.append(ht_ev)

    dpctl.SyclEvent.wait_for(host_evs)
    return dst