espnet2.asr.encoder.beats_encoder.Swish
class espnet2.asr.encoder.beats_encoder.Swish
Bases: Module
Swish activation function.
The Swish activation function is defined as: Swish(x) = x * sigmoid(x)
This activation function has been shown to outperform ReLU in some deep learning applications.
act
The sigmoid activation function instance.
- Type: torch.nn.Sigmoid
forward(x: torch.Tensor) -> torch.Tensor: Computes the Swish activation for the input tensor.
Examples:
>>> import torch
>>> from espnet2.asr.encoder.beats_encoder import Swish
>>> swish = Swish()
>>> input_tensor = torch.tensor([-1.0, 0.0, 1.0])
>>> output_tensor = swish(input_tensor)
>>> print(output_tensor)
tensor([-0.2689, 0.0000, 0.7311])
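The values in the example above can be checked by hand from the definition Swish(x) = x * sigmoid(x). The following is a minimal, framework-free sketch; the scalar `swish` helper is a hypothetical stand-in written for illustration and is not part of ESPnet.

```python
import math


def swish(x: float) -> float:
    """Scalar Swish (hypothetical helper): x * sigmoid(x).

    sigmoid(x) = 1 / (1 + exp(-x)), so x * sigmoid(x) = x / (1 + exp(-x)).
    """
    return x / (1.0 + math.exp(-x))


# Same inputs as the tensor example, rounded to four decimal places.
values = [round(swish(x), 4) for x in (-1.0, 0.0, 1.0)]
print(values)  # [-0.2689, 0.0, 0.7311]
```

Note that the function is not symmetric: negative inputs are damped toward zero rather than clipped, which is the main difference from ReLU.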
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Forward pass for the Beats encoder.
This method processes the input tensor xs_pad and its corresponding lengths ilens to produce audio representations. It is designed to be compatible with the AbsEncoder interface in ESPnet.
Parameters:
- xs_pad (torch.Tensor) – Input tensor of shape (B, T, D), where B is the batch size, T is the sequence length, and D is the feature dimension.
- ilens (torch.Tensor) – Tensor of shape (B,) representing the actual lengths of the sequences in xs_pad.
- prev_states (torch.Tensor, optional) – Optional tensor representing previous states. Default is None.
Returns:
- audio_representation (torch.Tensor): The output audio representation of shape (B, T, D).
- output_lens (torch.Tensor): Tensor of shape (B,) containing the output lengths for each sequence.
- masks (Optional[torch.Tensor]): Currently set to None.
Return type: Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
NOTE
If xs_pad is not provided, this operation may be costly, as it attempts to create a tensor of size maxlen x maxlen. Thus, tensors are unsqueezed and then squeezed to optimize performance.
Examples:
>>> beats_encoder = BeatsEncoder(...)
>>> xs_pad = torch.randn(32, 100, 512) # Batch of 32, 100 time steps
>>> ilens = torch.randint(1, 101, (32,)) # Random lengths
>>> audio_representation, output_lens, masks = beats_encoder.forward(xs_pad, ilens)
Raises: ValueError – If xs_pad or ilens have incompatible shapes.
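As an illustration of what "incompatible shapes" could mean for these arguments, the sketch below validates a (B, T, D) input against a (B,) length tensor. The helper `check_encoder_inputs` and its three checks are assumptions made for this example; the checks the encoder actually performs are not shown in this reference.

```python
import torch


def check_encoder_inputs(xs_pad: torch.Tensor, ilens: torch.Tensor) -> None:
    """Hypothetical helper: validate a (B, T, D) input against (B,) lengths."""
    # xs_pad is documented as a 3-D tensor of shape (B, T, D).
    if xs_pad.dim() != 3:
        raise ValueError(f"xs_pad must be 3-D (B, T, D), got {tuple(xs_pad.shape)}")
    # ilens is documented as one length per batch item, shape (B,).
    if ilens.dim() != 1 or ilens.size(0) != xs_pad.size(0):
        raise ValueError(
            f"ilens must have shape ({xs_pad.size(0)},), got {tuple(ilens.shape)}"
        )
    # Assumption: no sequence may be longer than the padded length T.
    if int(ilens.max()) > xs_pad.size(1):
        raise ValueError("ilens contains a length greater than T")


xs_pad = torch.randn(4, 100, 512)        # B=4, T=100, D=512
ilens = torch.tensor([100, 80, 60, 40])  # one length per batch item
check_encoder_inputs(xs_pad, ilens)      # passes silently

try:
    # Batch size mismatch: 2 lengths for a batch of 4.
    check_encoder_inputs(xs_pad, torch.tensor([100, 80]))
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

Running a check like this before calling forward surfaces shape problems with a clear message instead of a failure deeper in the model.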