Source code for dpctl.tensor._search_functions

#                       Data Parallel Control (dpctl)
#
#  Copyright 2020-2024 Intel Corporation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._manipulation_functions import _broadcast_shapes
from dpctl.utils import ExecutionPlacementError

from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
from ._type_utils import _all_data_types, _can_cast


def _where_result_type(dt1, dt2, dev):
    """
    Pick the output data type for ``where`` on a given device.

    The promoted type ``dpt.result_type(dt1, dt2)`` is returned when it is
    among the data types reported by ``_all_data_types`` for the device's
    ``fp16``/``fp64`` aspects. Otherwise the first of those data types to
    which both ``dt1`` and ``dt2`` can be cast (per ``_can_cast``) is
    returned, or ``None`` when no such type exists.

    Args:
        dt1: data type of the first value array.
        dt2: data type of the second value array.
        dev: device object exposing ``has_aspect_fp16`` and
            ``has_aspect_fp64`` attributes.

    Returns:
        The selected data type, or ``None`` if no candidate qualifies.
    """
    promoted = dpt.result_type(dt1, dt2)
    has_fp16 = dev.has_aspect_fp16
    has_fp64 = dev.has_aspect_fp64

    candidates = _all_data_types(has_fp16, has_fp64)
    if promoted in candidates:
        return promoted

    # Fall back to the first candidate that both inputs can be cast to.
    return next(
        (
            candidate
            for candidate in candidates
            if all(
                _can_cast(dt, candidate, has_fp16, has_fp64)
                for dt in (dt1, dt2)
            )
        ),
        None,
    )


def where(condition, x1, x2, /, *, order="K", out=None):
    """
    Returns :class:`dpctl.tensor.usm_ndarray` with elements chosen
    from ``x1`` or ``x2`` depending on ``condition``.

    Args:
        condition (usm_ndarray): When ``True`` yields from ``x1``,
            and otherwise yields from ``x2``.
            Must be compatible with ``x1`` and ``x2`` according
            to broadcasting rules.
        x1 (usm_ndarray): Array from which values are chosen when
            ``condition`` is ``True``.
            Must be compatible with ``condition`` and ``x2`` according
            to broadcasting rules.
        x2 (usm_ndarray): Array from which values are chosen when
            ``condition`` is not ``True``.
            Must be compatible with ``condition`` and ``x1`` according
            to broadcasting rules.
        order (``"K"``, ``"C"``, ``"F"``, ``"A"``, optional):
            Memory layout of the new output array,
            if parameter ``out`` is ``None``.
            Any other value is treated as ``"K"``.
            Default: ``"K"``.
        out (Optional[usm_ndarray]):
            the array into which the result is written.
            ``out`` must have the expected shape and the expected
            data type of the result.
            If ``None`` then a new array is returned. Default: ``None``.

    Returns:
        usm_ndarray:
            An array with elements from ``x1`` where ``condition`` is
            ``True``, and elements from ``x2`` elsewhere.

        The data type of the returned array is determined by applying
        the Type Promotion Rules to ``x1`` and ``x2``.

    Raises:
        TypeError: if ``condition``, ``x1``, ``x2`` or a provided ``out``
            is not a ``usm_ndarray``, or if no supported result data type
            exists for the data types of ``x1`` and ``x2``.
        ValueError: if ``out`` is read-only, or its shape or data type
            does not match the expected result.
        ExecutionPlacementError: if a common execution queue cannot be
            determined for the inputs, or the queue of ``out`` is not
            compatible with it.
    """
    if not isinstance(condition, dpt.usm_ndarray):
        raise TypeError(
            "Expecting dpctl.tensor.usm_ndarray type, "
            f"got {type(condition)}"
        )
    if not isinstance(x1, dpt.usm_ndarray):
        raise TypeError(
            "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x1)}"
        )
    if not isinstance(x2, dpt.usm_ndarray):
        raise TypeError(
            "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}"
        )
    # Unrecognized layout requests silently fall back to "K".
    if order not in ["K", "C", "F", "A"]:
        order = "K"
    exec_q = dpctl.utils.get_execution_queue(
        (
            condition.sycl_queue,
            x1.sycl_queue,
            x2.sycl_queue,
        )
    )
    if exec_q is None:
        raise ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments."
        )
    out_usm_type = dpctl.utils.get_coerced_usm_type(
        (
            condition.usm_type,
            x1.usm_type,
            x2.usm_type,
        )
    )

    x1_dtype = x1.dtype
    x2_dtype = x2.dtype
    out_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device)
    if out_dtype is None:
        raise TypeError(
            "function 'where' does not support input "
            f"types ({x1_dtype}, {x2_dtype}), "
            "and the inputs could not be safely coerced "
            "to any supported types according to the casting rule ''safe''."
        )

    res_shape = _broadcast_shapes(condition, x1, x2)

    # Remember the caller's array: `out` may be swapped for a temporary below.
    orig_out = out
    if out is not None:
        if not isinstance(out, dpt.usm_ndarray):
            raise TypeError(
                "output array must be of usm_ndarray type, got "
                f"{type(out)}"
            )

        if not out.flags.writable:
            raise ValueError("provided `out` array is read-only")

        if out.shape != res_shape:
            raise ValueError(
                "The shape of input and output arrays are "
                f"inconsistent. Expected output shape is {res_shape}, "
                f"got {out.shape}"
            )

        if out_dtype != out.dtype:
            raise ValueError(
                f"Output array of type {out_dtype} is needed, "
                f"got {out.dtype}"
            )

        if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
            raise ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        # If `out` overlaps an input without being the same logical tensor,
        # compute into a temporary and copy back at the end.
        for arr in (condition, x1, x2):
            if ti._array_overlap(arr, out) and not ti._same_logical_tensors(
                arr, out
            ):
                out = dpt.empty_like(out)
                # A freshly allocated array does not need further checks.
                break

    if order == "A":
        order = (
            "F"
            if all(
                arr.flags.f_contiguous
                for arr in (
                    condition,
                    x1,
                    x2,
                )
            )
            else "C"
        )

    if condition.size == 0:
        # Nothing to compute. Hand back the caller's array (not a
        # temporary that may have replaced it), or allocate an empty result.
        if orig_out is not None:
            return orig_out
        if order == "K":
            return _empty_like_triple_orderK(
                condition,
                x1,
                x2,
                out_dtype,
                res_shape,
                out_usm_type,
                exec_q,
            )
        return dpt.empty(
            res_shape,
            dtype=out_dtype,
            order=order,
            usm_type=out_usm_type,
            sycl_queue=exec_q,
        )

    # Events the selection kernel depends on, and host events to wait for.
    deps = []
    wait_list = []
    if x1_dtype != out_dtype:
        # Cast x1 to the result data type through a temporary copy.
        if order == "K":
            _x1 = _empty_like_orderK(x1, out_dtype)
        else:
            _x1 = dpt.empty_like(x1, dtype=out_dtype, order=order)
        ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x1, dst=_x1, sycl_queue=exec_q
        )
        x1 = _x1
        deps.append(copy1_ev)
        wait_list.append(ht_copy1_ev)

    if x2_dtype != out_dtype:
        # Cast x2 to the result data type through a temporary copy.
        if order == "K":
            _x2 = _empty_like_orderK(x2, out_dtype)
        else:
            _x2 = dpt.empty_like(x2, dtype=out_dtype, order=order)
        ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x2, dst=_x2, sycl_queue=exec_q
        )
        x2 = _x2
        deps.append(copy2_ev)
        wait_list.append(ht_copy2_ev)

    if out is None:
        if order == "K":
            out = _empty_like_triple_orderK(
                condition, x1, x2, out_dtype, res_shape, out_usm_type, exec_q
            )
        else:
            out = dpt.empty(
                res_shape,
                dtype=out_dtype,
                order=order,
                usm_type=out_usm_type,
                sycl_queue=exec_q,
            )

    condition = dpt.broadcast_to(condition, res_shape)
    x1 = dpt.broadcast_to(x1, res_shape)
    x2 = dpt.broadcast_to(x2, res_shape)

    hev, where_ev = ti._where(
        condition=condition,
        x1=x1,
        x2=x2,
        dst=out,
        sycl_queue=exec_q,
        depends=deps,
    )
    if not (orig_out is None or orig_out is out):
        # Copy the out data from temporary buffer to original memory
        ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
            src=out,
            dst=orig_out,
            sycl_queue=exec_q,
            depends=[where_ev],
        )
        ht_copy_out_ev.wait()
        out = orig_out
    dpctl.SyclEvent.wait_for(wait_list)
    hev.wait()

    return out