espnet2.enh.diffusion.sdes.batch_broadcast
espnet2.enh.diffusion.sdes.batch_broadcast(a, x)
Broadcasts a over all dimensions of x, except the batch dimension.
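The tensor a holds one value per batch element. It is reshaped so that its single dimension lines up with the batch dimension of x and every remaining dimension of x is matched by a singleton dimension. The result can then be combined elementwise with x through standard broadcasting. After squeezing, a must be one-dimensional, and its length must equal the batch size of x.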
- Parameters:
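- a (torch.Tensor) – Per-sample values of shape (B,), or any shape that squeezes to (B,), such as (B, 1).
- x (torch.Tensor) – Reference tensor of shape (B, ...). Only its batch size and number of dimensions are used.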
- Returns: A view of a with shape (B, 1, ..., 1), with one singleton dimension for each non-batch dimension of x, so that it broadcasts against x.
- Return type: torch.Tensor
- Raises:
- ValueError – If a has more than one effective dimension after squeezing, or if the batch dimensions of a and x do not match.
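The checks and reshaping described above can be sketched in a few lines. This is a minimal illustration, not the ESPnet source. The name `batch_broadcast_sketch` is hypothetical, and the sketch assumes the result is a view of `a` with shape `(B, 1, ..., 1)` rather than a tensor with the full shape of `x`.

```python
import torch


def batch_broadcast_sketch(a: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation: reshape per-sample values `a` (shape (B,))
    to (B, 1, ..., 1) so that they broadcast against `x` (shape (B, ...))."""
    # Drop singleton dimensions, e.g. (B, 1) -> (B,)
    if a.dim() != 1:
        a = a.squeeze()
        if a.dim() != 1:
            raise ValueError(
                f"`a` must have exactly one effective dimension, got shape {tuple(a.shape)}"
            )
    # The remaining dimension must be the batch dimension of `x`
    if a.shape[0] != x.shape[0]:
        raise ValueError(
            f"batch size of `a` ({a.shape[0]}) does not match batch size of `x` ({x.shape[0]})"
        )
    # Append one singleton axis per non-batch dimension of `x`
    return a.view(x.shape[0], *([1] * (x.dim() - 1)))


x = torch.zeros(5, 4, 2)
scale = torch.arange(1.0, 6.0)  # one value per batch element
print(batch_broadcast_sketch(scale, x).shape)  # torch.Size([5, 1, 1])
```

Returning a view rather than an expanded copy keeps memory use at B elements; the actual expansion happens lazily when the result is combined with `x`.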
Examples
>>> import torch
>>> from espnet2.enh.diffusion.sdes import batch_broadcast
>>> x = torch.zeros((5, 4, 2))  # shape: (5, 4, 2), batch size 5
>>> a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])  # shape: (5,)
>>> result = batch_broadcast(a, x)
>>> result.shape
torch.Size([5, 1, 1])
>>> (result * x).shape
torch.Size([5, 4, 2])
>>> a = torch.ones((5, 1))  # squeezed to shape (5,)
>>> batch_broadcast(a, x).shape
torch.Size([5, 1, 1])
>>> a = torch.tensor([1.0, 2.0, 3.0])  # shape: (3,), does not match batch size 5
>>> try:
...     batch_broadcast(a, x)
... except ValueError:
...     print("batch size mismatch")
batch size mismatch
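The reshaping is needed because PyTorch aligns shapes from the trailing dimension, so a per-sample tensor of shape (B,) cannot be multiplied directly with x of shape (B, T, F). The sketch below illustrates this with a per-sample noise scale. The names `sigma` and `z` and the perturbation `x + sigma * z` are assumptions chosen for illustration, not the SDE implemented in espnet2.enh.diffusion.sdes.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 4, 2)              # batch of 5 samples, shape (5, 4, 2)
sigma = torch.linspace(0.1, 0.5, 5)   # hypothetical per-sample noise scale, shape (5,)
z = torch.randn_like(x)               # Gaussian noise with the same shape as x

# Direct broadcasting compares trailing dims: 5 (sigma) vs 2 (z), which fails
try:
    sigma * z
except RuntimeError as err:
    print("direct broadcast failed:", type(err).__name__)

# Reshape to (5, 1, 1): one singleton axis per non-batch dimension of x
sigma_b = sigma.view(-1, *([1] * (x.dim() - 1)))
x_noisy = x + sigma_b * z             # each sample i is scaled by its own sigma[i]
print(sigma_b.shape, x_noisy.shape)
```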