espnet2.asr.encoder.avhubert_encoder.index_put
espnet2.asr.encoder.avhubert_encoder.index_put(tensor, indices, value)
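Replaces the elements of a tensor selected by a boolean mask with a given value.
On non-XLA devices the update is done in place via boolean-mask assignment (tensor[indices] = value), so the input tensor is modified and the same object is returned. On XLA devices (e.g., TPU) the mask is broadcast to the tensor's shape and the result is computed out of place as tensor * ~indices + value * indices, so the input tensor is left unchanged.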
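The in-place behaviour on non-XLA devices can be checked with an ordinary CPU tensor. The helper put_in_place below is hypothetical: it reproduces only the boolean-mask assignment, assuming that is what the non-XLA path performs, and is not the library function itself.

```python
import torch


def put_in_place(tensor, mask, value):
    # Hypothetical stand-in for the non-XLA path: assign through a boolean mask.
    tensor[mask] = value
    return tensor


t = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[True, False], [False, True]])
out = put_in_place(t, mask, 0)

print(out is t)     # True: the same tensor object is returned
print(t.tolist())   # [[0, 2], [3, 0]]: the caller's tensor was modified
```

Because the write goes through the caller's tensor, any other reference to it observes the change; clone the input first if the original values are still needed.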
- Parameters:
- tensor (torch.Tensor) – The input tensor to be updated.
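- indices (torch.Tensor) – A boolean mask marking the positions to overwrite. Its shape must match the leading dimensions of tensor; it may have fewer dimensions (e.g., a (B, T) mask for a (B, T, C) tensor), in which case each True entry selects a whole trailing slice.
- value (torch.Tensor) – The value to write at the masked positions. It must be broadcastable to the selected elements, e.g., a scalar, or a (C,) vector when a (B, T) mask is applied to a (B, T, C) tensor.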
- Returns: The updated tensor. On non-XLA devices this is the input tensor itself, modified in place; on XLA devices it is a new tensor.
- Return type: torch.Tensor
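A sketch of the lower-rank mask case, assuming a (batch, time) mask over a (batch, time, feature) tensor and a feature-sized replacement vector, as when masked frames are replaced by an embedding. The names frames and mask_emb are hypothetical, and the sketch uses plain boolean-mask assignment to show how value broadcasts.

```python
import torch

# Hypothetical shapes: batch B=2, time T=3, feature C=4.
B, T, C = 2, 3, 4
frames = torch.zeros(B, T, C)
mask = torch.tensor([[True, False, True],
                     [False, False, True]])      # (B, T) boolean mask
mask_emb = torch.arange(C, dtype=frames.dtype)   # (C,) replacement vector

# A (B, T) mask selects an (N, C) block, where N is the number of True
# entries; the (C,) value broadcasts across those N rows.
print(frames[mask].shape)   # torch.Size([3, 4])
frames[mask] = mask_emb

print(frames[0, 0].tolist())  # [0.0, 1.0, 2.0, 3.0]  (masked frame)
print(frames[0, 1].tolist())  # [0.0, 0.0, 0.0, 0.0]  (unmasked frame)
```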
Examples
>>> import torch
>>> from espnet2.asr.encoder.avhubert_encoder import index_put
>>> tensor = torch.tensor([[1, 2], [3, 4]])
>>> indices = torch.tensor([[True, False], [False, True]])
>>> value = torch.tensor(0)
>>> updated_tensor = index_put(tensor, indices, value)
>>> print(updated_tensor)
tensor([[0, 2],
        [3, 0]])
NOTE
If the input tensor is an XLA tensor, it is not modified in place, so the caller must use the returned tensor. This path expects indices to be a boolean mask, because it relies on the logical negation ~indices.
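A minimal sketch of an indexing-free update that can be emulated on CPU, assuming the XLA handling combines the tensor and value with elementwise mask arithmetic. put_by_mask_arithmetic is a hypothetical helper written for this illustration (not the library's code); it assumes a boolean mask covering the leading dimensions of the tensor, and checks its result against ordinary boolean-mask assignment.

```python
import torch


def put_by_mask_arithmetic(tensor, mask, value):
    """Hypothetical out-of-place update using only elementwise ops."""
    # Append trailing singleton dims so the mask has the tensor's rank,
    # e.g. (B, T) -> (B, T, 1); broadcasting then covers the last dim.
    for _ in range(mask.dim(), tensor.dim()):
        mask = mask.unsqueeze(-1)
    # Keep old entries where mask is False, take `value` where it is True.
    return tensor * ~mask + value * mask


x = torch.arange(12.0).reshape(2, 3, 2)
m = torch.tensor([[True, False, False], [False, True, True]])
v = torch.tensor([-1.0, -2.0])

reference = x.clone()
reference[m] = v                          # boolean-mask assignment
result = put_by_mask_arithmetic(x, m, v)  # elementwise formulation

print(torch.equal(result, reference))                        # True
print(torch.equal(x, torch.arange(12.0).reshape(2, 3, 2)))   # True: x untouched
```

The arithmetic form produces a tensor of fixed shape regardless of how many entries are masked; note that it would propagate NaN or inf from masked-out positions (0 * inf is NaN), which boolean assignment does not.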
- Raises: IndexError – On non-XLA devices, if the shape of indices does not match the dimensions of tensor it indexes (or, for integer indices, if they are out of bounds).
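Both failure modes can be reproduced with plain PyTorch indexing, assuming the non-XLA path uses ordinary tensor[indices] = value assignment. The exact error messages depend on the PyTorch version, so only the exception type is relied on here.

```python
import torch

t = torch.zeros(2, 3)

cases = [
    # Boolean mask whose shape (3,) does not match dimension 0 of size 2.
    ("mask mismatch", torch.ones(3, dtype=torch.bool)),
    # Integer index pointing past the end of dimension 0.
    ("out of bounds", torch.tensor([5])),
]

caught = []
for label, idx in cases:
    try:
        t[idx] = 1.0
    except IndexError as err:
        caught.append(label)
        print(f"{label}: {err}")
```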