# Source code for dpctl.tensor._set_functions

#                       Data Parallel Control (dpctl)
#
#  Copyright 2020-2024 Intel Corporation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from typing import NamedTuple

import dpctl.tensor as dpt
import dpctl.utils as du

from ._tensor_elementwise_impl import _not_equal, _subtract
from ._tensor_impl import (
    _copy_usm_ndarray_into_usm_ndarray,
    _extract,
    _full_usm_ndarray,
    _linspace_step,
    _take,
    default_device_index_type,
    mask_positions,
)
from ._tensor_sorting_impl import (
    _argsort_ascending,
    _searchsorted_left,
    _sort_ascending,
)

# Public names exported by this module: the four set functions followed by
# the named-tuple result types that three of them return.
__all__ = [
    "unique_values",
    "unique_counts",
    "unique_inverse",
    "unique_all",
    "UniqueAllResult",
    "UniqueCountsResult",
    "UniqueInverseResult",
]


class UniqueAllResult(NamedTuple):
    """Named 4-tuple `(values, indices, inverse_indices, counts)` returned
    by :func:`unique_all`.
    """

    # unique elements of the input array
    values: dpt.usm_ndarray
    # indices of first occurrences in the (flattened) input producing `values`
    indices: dpt.usm_ndarray
    # indices into `values` that reconstruct the input array
    inverse_indices: dpt.usm_ndarray
    # number of times each unique element occurs in the input array
    counts: dpt.usm_ndarray


class UniqueCountsResult(NamedTuple):
    """Named 2-tuple `(values, counts)` returned by :func:`unique_counts`."""

    # unique elements of the input array
    values: dpt.usm_ndarray
    # number of times each unique element occurs in the input array
    counts: dpt.usm_ndarray


class UniqueInverseResult(NamedTuple):
    """Named 2-tuple `(values, inverse_indices)` returned by
    :func:`unique_inverse`.
    """

    # unique elements of the input array
    values: dpt.usm_ndarray
    # indices into `values` that reconstruct the input array
    inverse_indices: dpt.usm_ndarray


def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
    """unique_values(x)

    Computes the set of distinct elements found in the input array `x`.

    Args:
        x (usm_ndarray):
            input array. An input with more than one dimension is
            flattened before it is processed.

    Returns:
        usm_ndarray:
            one-dimensional array in which every distinct element of `x`
            appears exactly once. Its data type is that of `x`. When `x`
            is empty, the (flattened) input itself is returned.

    Raises:
        TypeError:
            if `x` is not an instance of `dpctl.tensor.usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    q = x.device.sycl_queue
    flat = x if x.ndim == 1 else dpt.reshape(x, (x.size,), order="C")
    if flat.size == 0:
        return flat

    sorted_vals = dpt.empty_like(flat, order="C")
    order_mgr = du.SequentialOrderManager[q]
    pending_evs = order_mgr.submitted_events

    # The sort is always given a C-contiguous source: either the input
    # itself or a freshly made contiguous copy of it.
    sort_src, sort_deps = flat, pending_evs
    if not flat.flags.c_contiguous:
        sort_src = dpt.empty_like(flat, order="C")
        host_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=flat, dst=sort_src, sycl_queue=q, depends=pending_evs
        )
        order_mgr.add_event_pair(host_ev, copy_ev)
        sort_deps = [copy_ev]
    host_ev, sort_ev = _sort_ascending(
        src=sort_src,
        trailing_dims_to_sort=1,
        dst=sorted_vals,
        sycl_queue=q,
        depends=sort_deps,
    )
    order_mgr.add_event_pair(host_ev, sort_ev)

    # is_first[i] is True where sorted_vals[i] differs from its predecessor;
    # position 0 has no predecessor and is set to True separately.
    is_first = dpt.empty(flat.shape, dtype="?", sycl_queue=q)
    host_ev, neq_ev = _not_equal(
        src1=sorted_vals[:-1],
        src2=sorted_vals[1:],
        dst=is_first[1:],
        sycl_queue=q,
        depends=[sort_ev],
    )
    order_mgr.add_event_pair(host_ev, neq_ev)
    # the mask is a new allocation, so this fill needs no dependencies
    host_ev, head_ev = _full_usm_ndarray(
        fill_value=True, dst=is_first[0], sycl_queue=q
    )
    order_mgr.add_event_pair(host_ev, head_ev)

    positions = dpt.empty(sorted_vals.shape, dtype=dpt.int64, sycl_queue=q)
    # synchronizing call
    n_distinct = mask_positions(
        is_first, positions, sycl_queue=q, depends=[head_ev, neq_ev]
    )
    if n_distinct == flat.size:
        # every element is distinct, the sorted array is the answer
        return sorted_vals

    distinct = dpt.empty(
        n_distinct, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=q
    )
    host_ev, extract_ev = _extract(
        src=sorted_vals,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=distinct,
        sycl_queue=q,
    )
    order_mgr.add_event_pair(host_ev, extract_ev)
    return distinct
def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
    """unique_counts(x)

    Computes the distinct elements of the input array `x` together with
    the number of occurrences of each of them.

    Args:
        x (usm_ndarray):
            input array. An input with more than one dimension is
            flattened before it is processed.

    Returns:
        tuple[usm_ndarray, usm_ndarray]:
            a namedtuple `(values, counts)` where

            * `values` is an array of the distinct elements of `x` and
              has the data type of `x`.
            * `counts` is an array giving how many times each element of
              `values` occurs in `x`. It has the same shape as `values`
              and the default array index data type.

    Raises:
        TypeError:
            if `x` is not an instance of `dpctl.tensor.usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    q = x.device.sycl_queue
    usm_kind = x.usm_type
    flat = x if x.ndim == 1 else dpt.reshape(x, (x.size,), order="C")
    index_dt = default_device_index_type(q)
    if flat.size == 0:
        return UniqueCountsResult(flat, dpt.empty_like(flat, dtype=index_dt))

    sorted_vals = dpt.empty_like(flat, order="C")
    order_mgr = du.SequentialOrderManager[q]
    pending_evs = order_mgr.submitted_events

    # The sort is always given a C-contiguous source: either the input
    # itself or a freshly made contiguous copy of it.
    sort_src, sort_deps = flat, pending_evs
    if not flat.flags.c_contiguous:
        sort_src = dpt.empty_like(flat, order="C")
        host_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=flat, dst=sort_src, sycl_queue=q, depends=pending_evs
        )
        order_mgr.add_event_pair(host_ev, copy_ev)
        sort_deps = [copy_ev]
    host_ev, sort_ev = _sort_ascending(
        src=sort_src,
        trailing_dims_to_sort=1,
        dst=sorted_vals,
        sycl_queue=q,
        depends=sort_deps,
    )
    order_mgr.add_event_pair(host_ev, sort_ev)

    # is_first[i] is True where sorted_vals[i] differs from its predecessor;
    # position 0 has no predecessor and is set to True separately.
    is_first = dpt.empty(sorted_vals.shape, dtype="?", sycl_queue=q)
    host_ev, neq_ev = _not_equal(
        src1=sorted_vals[:-1],
        src2=sorted_vals[1:],
        dst=is_first[1:],
        sycl_queue=q,
        depends=[sort_ev],
    )
    order_mgr.add_event_pair(host_ev, neq_ev)
    # the mask is a new allocation, so this fill needs no dependencies
    host_ev, head_ev = _full_usm_ndarray(
        fill_value=True, dst=is_first[0], sycl_queue=q
    )
    order_mgr.add_event_pair(host_ev, head_ev)

    positions = dpt.empty(is_first.shape, dtype=dpt.int64, sycl_queue=q)
    # synchronizing call
    n_distinct = mask_positions(
        is_first, positions, sycl_queue=q, depends=[head_ev, neq_ev]
    )
    if n_distinct == flat.size:
        # every element is distinct, so each one occurs exactly once
        all_ones = dpt.ones(
            n_distinct, dtype=index_dt, usm_type=usm_kind, sycl_queue=q
        )
        return UniqueCountsResult(sorted_vals, all_ones)

    distinct = dpt.empty(
        n_distinct, dtype=x.dtype, usm_type=usm_kind, sycl_queue=q
    )
    # gather the distinct values out of the sorted array
    host_ev, vals_ev = _extract(
        src=sorted_vals,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=distinct,
        sycl_queue=q,
    )
    order_mgr.add_event_pair(host_ev, vals_ev)

    # run_bounds[k] is the offset in the sorted array at which the k-th
    # distinct value first appears; the extra last slot holds x.size, so
    # that adjacent differences are the occurrence counts.
    run_bounds = dpt.empty(
        n_distinct + 1, dtype=index_dt, usm_type=usm_kind, sycl_queue=q
    )
    offsets = dpt.empty(x.size, dtype=index_dt, sycl_queue=q)
    # new allocation, no dependency needed
    host_ev, iota_ev = _linspace_step(
        start=0, dt=1, dst=offsets, sycl_queue=q
    )
    order_mgr.add_event_pair(host_ev, iota_ev)
    host_ev, starts_ev = _extract(
        src=offsets,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=run_bounds[:-1],
        sycl_queue=q,
        depends=[iota_ev],
    )
    order_mgr.add_event_pair(host_ev, starts_ev)
    # no dependency, writes a disjoint segment of a new allocation
    host_ev, total_ev = _full_usm_ndarray(
        x.size, dst=run_bounds[-1], sycl_queue=q
    )
    order_mgr.add_event_pair(host_ev, total_ev)

    occurrences = dpt.empty_like(run_bounds[1:])
    host_ev, diff_ev = _subtract(
        src1=run_bounds[1:],
        src2=run_bounds[:-1],
        dst=occurrences,
        sycl_queue=q,
        depends=[total_ev, starts_ev],
    )
    order_mgr.add_event_pair(host_ev, diff_ev)
    return UniqueCountsResult(distinct, occurrences)
def unique_inverse(x: dpt.usm_ndarray) -> UniqueInverseResult:
    """unique_inverse(x)

    Returns the unique elements of an input array `x` and the indices from
    the set of unique elements that reconstruct `x`.

    Args:
        x (usm_ndarray):
            input array. Inputs with more than one dimension are flattened.

    Returns:
        tuple[usm_ndarray, usm_ndarray]:
            a namedtuple `(values, inverse_indices)` whose

            * first element has the field name `values` and is an array
              containing the unique elements of `x`. The array has the
              same data type as `x`.
            * second element has the field name `inverse_indices` and is
              an array containing the indices of values that reconstruct
              `x`. The array has the same shape as `x` and has the default
              array index data type.

    Raises:
        TypeError:
            if `x` is not an instance of `dpctl.tensor.usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    array_api_dev = x.device
    exec_q = array_api_dev.sycl_queue
    x_usm_type = x.usm_type
    ind_dt = default_device_index_type(exec_q)
    if x.ndim == 1:
        fx = x
    else:
        fx = dpt.reshape(x, (x.size,), order="C")
    sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
    unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
    if fx.size == 0:
        return UniqueInverseResult(fx, dpt.reshape(unsorting_ids, x.shape))
    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    if fx.flags.c_contiguous:
        ht_ev, sort_ev = _argsort_ascending(
            src=fx,
            trailing_dims_to_sort=1,
            dst=sorting_ids,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_ev, sort_ev)
    else:
        # argsort is given a C-contiguous copy of the data
        tmp = dpt.empty_like(fx, order="C")
        ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=fx, dst=tmp, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        ht_ev, sort_ev = _argsort_ascending(
            src=tmp,
            trailing_dims_to_sort=1,
            dst=sorting_ids,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_ev, sort_ev)
    # argsort of the sorting permutation; it is returned as the inverse
    # indices when all elements of `x` turn out to be distinct
    ht_ev, argsort_ev = _argsort_ascending(
        src=sorting_ids,
        trailing_dims_to_sort=1,
        dst=unsorting_ids,
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    _manager.add_event_pair(ht_ev, argsort_ev)
    s = dpt.empty_like(fx)
    # s = fx[sorting_ids]
    ht_ev, take_ev = _take(
        src=fx,
        ind=(sorting_ids,),
        dst=s,
        axis_start=0,
        mode=0,
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    _manager.add_event_pair(ht_ev, take_ev)
    # unique_mask[i] is True where s[i] differs from s[i - 1];
    # unique_mask[0] is set to True separately
    unique_mask = dpt.empty(fx.shape, dtype="?", sycl_queue=exec_q)
    ht_ev, uneq_ev = _not_equal(
        src1=s[:-1],
        src2=s[1:],
        dst=unique_mask[1:],
        sycl_queue=exec_q,
        depends=[take_ev],
    )
    _manager.add_event_pair(ht_ev, uneq_ev)
    # no dependency, since we write into new allocation
    ht_ev, one_ev = _full_usm_ndarray(
        fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
    )
    _manager.add_event_pair(ht_ev, one_ev)
    cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
    # synchronizing call
    n_uniques = mask_positions(
        unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev]
    )
    if n_uniques == fx.size:
        return UniqueInverseResult(s, dpt.reshape(unsorting_ids, x.shape))
    unique_vals = dpt.empty(
        n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
    )
    ht_ev, uv_ev = _extract(
        src=s,
        cumsum=cumsum,
        axis_start=0,
        axis_end=1,
        dst=unique_vals,
        sycl_queue=exec_q,
    )
    _manager.add_event_pair(ht_ev, uv_ev)
    # NOTE: per-value occurrence counts are not part of the result of this
    # function, so no cumulative-count arrays are allocated or computed here.
    # The inverse indices are found by locating each element of `x` in the
    # array of unique values.
    inv = dpt.empty_like(x, dtype=ind_dt, order="C")
    ht_ev, ssl_ev = _searchsorted_left(
        hay=unique_vals,
        needles=x,
        positions=inv,
        sycl_queue=exec_q,
        depends=[
            uv_ev,
        ],
    )
    _manager.add_event_pair(ht_ev, ssl_ev)
    return UniqueInverseResult(unique_vals, inv)
def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
    """unique_all(x)

    Computes the distinct elements of the input array `x`, the index of the
    first occurrence of each of them in `x`, the indices into the distinct
    elements that reconstruct `x`, and the number of occurrences of each
    distinct element.

    Args:
        x (usm_ndarray):
            input array. An input with more than one dimension is
            flattened before it is processed.

    Returns:
        tuple[usm_ndarray, usm_ndarray, usm_ndarray, usm_ndarray]:
            a namedtuple `(values, indices, inverse_indices, counts)` where

            * `values` is an array of the distinct elements of `x` and
              has the data type of `x`.
            * `indices` is an array containing the indices (of first
              occurrences) of `x` that result in `values`. It has the
              same shape as `values` and the default array index data
              type.
            * `inverse_indices` is an array containing the indices of
              `values` that reconstruct `x`. It has the same shape as `x`
              and the default array index data type.
            * `counts` is an array giving how many times each element of
              `values` occurs in `x`. It has the same shape as `values`
              and the default array index data type.

    Raises:
        TypeError:
            if `x` is not an instance of `dpctl.tensor.usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    q = x.device.sycl_queue
    usm_kind = x.usm_type
    index_dt = default_device_index_type(q)
    flat = x if x.ndim == 1 else dpt.reshape(x, (x.size,), order="C")
    perm = dpt.empty_like(flat, dtype=index_dt, order="C")
    inv_perm = dpt.empty_like(perm, dtype=index_dt, order="C")
    if flat.size == 0:
        # the input holds no data, so it can safely be returned as values
        return UniqueAllResult(
            flat,
            perm,
            dpt.reshape(inv_perm, x.shape),
            dpt.empty_like(flat, dtype=index_dt),
        )

    order_mgr = du.SequentialOrderManager[q]
    pending_evs = order_mgr.submitted_events

    # The argsort is always given a C-contiguous source: either the input
    # itself or a freshly made contiguous copy of it.
    sort_src, sort_deps = flat, pending_evs
    if not flat.flags.c_contiguous:
        sort_src = dpt.empty_like(flat, order="C")
        host_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=flat, dst=sort_src, sycl_queue=q, depends=pending_evs
        )
        order_mgr.add_event_pair(host_ev, copy_ev)
        sort_deps = [copy_ev]
    host_ev, sort_ev = _argsort_ascending(
        src=sort_src,
        trailing_dims_to_sort=1,
        dst=perm,
        sycl_queue=q,
        depends=sort_deps,
    )
    order_mgr.add_event_pair(host_ev, sort_ev)

    # argsort of the sorting permutation; it is returned as the inverse
    # indices when all elements of `x` turn out to be distinct
    host_ev, inv_ev = _argsort_ascending(
        src=perm,
        trailing_dims_to_sort=1,
        dst=inv_perm,
        sycl_queue=q,
        depends=[sort_ev],
    )
    order_mgr.add_event_pair(host_ev, inv_ev)

    # sorted_vals = flat[perm]
    sorted_vals = dpt.empty_like(flat)
    host_ev, gather_ev = _take(
        src=flat,
        ind=(perm,),
        dst=sorted_vals,
        axis_start=0,
        mode=0,
        sycl_queue=q,
        depends=[sort_ev],
    )
    order_mgr.add_event_pair(host_ev, gather_ev)

    # is_first[i] is True where sorted_vals[i] differs from its predecessor;
    # position 0 has no predecessor and is set to True separately.
    is_first = dpt.empty(flat.shape, dtype="?", sycl_queue=q)
    host_ev, neq_ev = _not_equal(
        src1=sorted_vals[:-1],
        src2=sorted_vals[1:],
        dst=is_first[1:],
        sycl_queue=q,
        depends=[gather_ev],
    )
    order_mgr.add_event_pair(host_ev, neq_ev)
    host_ev, head_ev = _full_usm_ndarray(
        fill_value=True, dst=is_first[0], sycl_queue=q
    )
    order_mgr.add_event_pair(host_ev, head_ev)

    positions = dpt.empty(is_first.shape, dtype=dpt.int64, sycl_queue=q)
    # synchronizing call
    n_distinct = mask_positions(
        is_first, positions, sycl_queue=q, depends=[neq_ev, head_ev]
    )
    if n_distinct == flat.size:
        # every element is distinct, so each one occurs exactly once
        all_ones = dpt.ones(
            n_distinct, dtype=index_dt, usm_type=usm_kind, sycl_queue=q
        )
        return UniqueAllResult(
            sorted_vals,
            perm,
            dpt.reshape(inv_perm, x.shape),
            all_ones,
        )

    distinct = dpt.empty(
        n_distinct, dtype=x.dtype, usm_type=usm_kind, sycl_queue=q
    )
    host_ev, vals_ev = _extract(
        src=sorted_vals,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=distinct,
        sycl_queue=q,
    )
    order_mgr.add_event_pair(host_ev, vals_ev)

    # run_bounds[k] is the offset in the sorted array at which the k-th
    # distinct value first appears; the extra last slot holds x.size, so
    # that adjacent differences are the occurrence counts.
    run_bounds = dpt.empty(
        n_distinct + 1, dtype=index_dt, usm_type=usm_kind, sycl_queue=q
    )
    offsets = dpt.empty(x.size, dtype=index_dt, sycl_queue=q)
    host_ev, iota_ev = _linspace_step(
        start=0, dt=1, dst=offsets, sycl_queue=q
    )
    order_mgr.add_event_pair(host_ev, iota_ev)
    host_ev, starts_ev = _extract(
        src=offsets,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=run_bounds[:-1],
        sycl_queue=q,
        depends=[iota_ev],
    )
    order_mgr.add_event_pair(host_ev, starts_ev)
    host_ev, total_ev = _full_usm_ndarray(
        x.size, dst=run_bounds[-1], sycl_queue=q
    )
    order_mgr.add_event_pair(host_ev, total_ev)

    occurrences = dpt.empty_like(run_bounds[1:])
    host_ev, diff_ev = _subtract(
        src1=run_bounds[1:],
        src2=run_bounds[:-1],
        dst=occurrences,
        sycl_queue=q,
        depends=[total_ev, starts_ev],
    )
    order_mgr.add_event_pair(host_ev, diff_ev)

    # locate every element of `x` in the array of distinct values
    inverse = dpt.empty_like(x, dtype=index_dt, order="C")
    host_ev, search_ev = _searchsorted_left(
        hay=distinct,
        needles=x,
        positions=inverse,
        sycl_queue=q,
        depends=[
            vals_ev,
        ],
    )
    order_mgr.add_event_pair(host_ev, search_ev)

    # entries of the sorting permutation at the start of each run
    first_occurrences = perm[run_bounds[:-1]]
    return UniqueAllResult(
        distinct,
        first_occurrences,
        inverse,
        occurrences,
    )