dpctl.tensor.put

dpctl.tensor.put(x, indices, vals, axis=None, mode='wrap')

Puts values into an array along a given axis at given indices.

Parameters:
  • x (usm_ndarray) – The array the values will be put into.

  • indices (usm_ndarray) – One-dimensional array of indices.

  • vals (usm_ndarray) – Array of values to be put into x. Must be broadcastable to the result shape x.shape[:axis] + indices.shape + x.shape[axis+1:].

  • axis (int, optional) – The axis along which the values will be placed. If x is one-dimensional, this argument may be omitted; otherwise it must be specified. Default: None.

  • mode (str, optional) –

    How out-of-bounds indices are handled, where n is the size of x along axis. Possible values are:

    • "wrap": clamps indices to (-n <= i < n), then wraps negative indices by adding n.

    • "clip": clips indices to (0 <= i < n).

    Default: "wrap".
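The index-handling rules above can be modeled in plain Python. The sketch below is a hypothetical reference model, not dpctl's implementation: `normalize_index` applies the documented "wrap" and "clip" rules to a single index for an axis of length `n` (assumed here to be `x.shape[axis]`), and `result_shape` computes the shape that `vals` must broadcast to. Note that "wrap" clamps before wrapping, so it is not modular wrap-around: index 7 on an axis of length 5 maps to 4, not 2.

```python
def normalize_index(i, n, mode="wrap"):
    """Hypothetical model: map index i onto an axis of length n per the documented mode rules."""
    if n <= 0:
        raise ValueError("axis length must be positive")
    if mode == "wrap":
        # Clamp to -n <= i < n, then wrap negative indices by adding n.
        i = max(-n, min(i, n - 1))
        return i + n if i < 0 else i
    if mode == "clip":
        # Clip directly to 0 <= i < n.
        return max(0, min(i, n - 1))
    raise ValueError(f"unknown mode: {mode!r}")


def result_shape(x_shape, axis, indices_len):
    """Shape vals must broadcast to: x.shape[:axis] + indices.shape + x.shape[axis+1:]."""
    return tuple(x_shape[:axis]) + (indices_len,) + tuple(x_shape[axis + 1:])


print([normalize_index(i, 5, "wrap") for i in (-7, -1, 2, 7)])  # [0, 4, 2, 4]
print([normalize_index(i, 5, "clip") for i in (-7, -1, 2, 7)])  # [0, 0, 2, 4]
print(result_shape((2, 3, 4), axis=1, indices_len=6))  # (2, 6, 4)
```

Because indices is one-dimensional, the result shape is simply x.shape with the entry at axis replaced by the number of indices.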

Note

If the input array indices contains duplicates, a race condition occurs, and the value written into the corresponding positions of x may vary from run to run. Preserving sequential semantics when handling the duplicates, so that the value from the last occurrence of each index is written deterministically, requires additional work, as illustrated below.

Example:
from dpctl import tensor as dpt

def put_vec_duplicates(vec, ind, vals):
    "Put values into vec, handling possible duplicates in ind"
    assert (vec.ndim, ind.ndim, vals.ndim) == (1, 1, 1)

    # find positions of last occurrences of each
    # unique index: the first occurrence in the flipped
    # array is the last occurrence in the original one
    ind_flipped = dpt.flip(ind)
    ind_uniq = dpt.unique_all(ind_flipped).indices
    has_dups = len(ind) != len(ind_uniq)

    if has_dups:
        # map positions in ind_flipped back to positions in ind
        ind_uniq = dpt.subtract(ind.size - 1, ind_uniq)
        ind = dpt.take(ind, ind_uniq)
        vals = dpt.take(vals, ind_uniq)

    dpt.put(vec, ind, vals)

n = 512
ind = dpt.concat((dpt.arange(n), dpt.arange(n, -1, step=-1)))
x = dpt.zeros(ind.size, dtype="int32")
vals = dpt.arange(ind.size, dtype=x.dtype)

# Values corresponding to last positions of
# duplicate indices are written into the vector x
put_vec_duplicates(x, ind, vals)

parts = (vals[-1:-n-2:-1], dpt.zeros(n, dtype=x.dtype))
expected = dpt.concat(parts)
assert dpt.all(x == expected)
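For intuition, the deduplication step can be checked against a plain-Python model that uses lists instead of usm_ndarray. In this hypothetical sketch, `put_sequential` writes values one at a time, so the last write to a repeated index wins; `keep_last_occurrences` mirrors the flip, find-first-unique, map-back trick. Writing only the kept positions gives the same result as the sequential writes, which is the deterministic outcome the deduplication aims for.

```python
def put_sequential(vec, ind, vals):
    """Hypothetical model: write vals[k] into vec[ind[k]] in order; later writes win."""
    for i, v in zip(ind, vals):
        vec[i] = v


def keep_last_occurrences(ind):
    """Positions in ind of the last occurrence of each distinct index."""
    m = len(ind)
    flipped = ind[::-1]
    seen = set()
    first_in_flipped = []
    for pos, i in enumerate(flipped):
        if i not in seen:
            seen.add(i)
            first_in_flipped.append(pos)
    # Map positions in the flipped sequence back to positions in ind.
    # Order is irrelevant afterwards, since the kept indices are distinct.
    return [m - 1 - pos for pos in first_in_flipped]


ind = [0, 1, 2, 1, 0]
vals = [10, 11, 12, 13, 14]

a = [0] * 3
put_sequential(a, ind, vals)

keep = keep_last_occurrences(ind)
b = [0] * 3
put_sequential(b, [ind[k] for k in keep], [vals[k] for k in keep])

print(a, b)  # [14, 13, 12] [14, 13, 12]
```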