espnet2.gan_tts.wavenet.residual_block.ResidualBlock
class espnet2.gan_tts.wavenet.residual_block.ResidualBlock(kernel_size: int = 3, residual_channels: int = 64, gate_channels: int = 128, skip_channels: int = 64, aux_channels: int = 80, global_channels: int = -1, dropout_rate: float = 0.0, dilation: int = 1, bias: bool = True, scale_residual: bool = False)
Bases: Module
Residual block module in WaveNet.
This module implements the residual block used in the WaveNet architecture. It combines a dilated convolution with a gated activation unit and supports both local and global conditioning. This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
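For reference, the gated activation at the core of this block follows the standard WaveNet formulation, z = tanh(W_f ∗ x) ⊙ σ(W_g ∗ x), where ∗ is the dilated convolution, ⊙ is element-wise multiplication, and σ is the sigmoid; when local conditioning c or global conditioning g is given, its projection is added to both the tanh and sigmoid branches before the nonlinearities.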
dropout_rate
The probability of dropout applied to the input.
- Type: float
residual_channels
Number of channels for the residual connection.
- Type: int
skip_channels
Number of channels for the skip connection.
- Type: int
scale_residual
Whether to scale the residual outputs.
- Type: bool
Parameters:
- kernel_size (int) – Kernel size of dilation convolution layer.
- residual_channels (int) – Number of channels for residual connection.
- gate_channels (int) – Number of channels for gating mechanism.
- skip_channels (int) – Number of channels for skip connection.
- aux_channels (int) – Number of local conditioning channels.
- global_channels (int) – Number of global conditioning channels (-1 disables global conditioning).
- dropout_rate (float) – Dropout probability.
- dilation (int) – Dilation factor.
- bias (bool) – Whether to add bias parameter in convolution layers.
- scale_residual (bool) – Whether to scale the residual outputs.
#### Examples
>>> import torch
>>> from espnet2.gan_tts.wavenet.residual_block import ResidualBlock
>>> residual_block = ResidualBlock()
>>> x = torch.randn(1, 64, 100)  # Example input tensor of shape (B, residual_channels, T)
>>> output, skip = residual_block(x)
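A minimal sketch of conditioned usage, assuming the default aux_channels=80 and an illustrative global_channels=32 (global conditioning is disabled by the default of -1):

>>> block = ResidualBlock(global_channels=32)
>>> x = torch.randn(1, 64, 100)   # (B, residual_channels, T)
>>> c = torch.randn(1, 80, 100)   # local conditioning, (B, aux_channels, T)
>>> g = torch.randn(1, 32, 1)     # global conditioning, (B, global_channels, 1)
>>> residual, skip = block(x, c=c, g=g)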
- Raises: AssertionError – If the kernel size is even or the number of gate channels is odd.
Initialize ResidualBlock module.
- Parameters:
- kernel_size (int) – Kernel size of dilation convolution layer.
- residual_channels (int) – Number of channels for residual connection.
- gate_channels (int) – Number of channels for gating mechanism.
- skip_channels (int) – Number of channels for skip connection.
- aux_channels (int) – Number of local conditioning channels.
- global_channels (int) – Number of global conditioning channels.
- dropout_rate (float) – Dropout probability.
- dilation (int) – Dilation factor.
- bias (bool) – Whether to add bias parameter in convolution layers.
- scale_residual (bool) – Whether to scale the residual outputs.
forward(x: Tensor, x_mask: Tensor | None = None, c: Tensor | None = None, g: Tensor | None = None) → Tuple[Tensor, Tensor]
Calculate forward propagation through the ResidualBlock.
This method computes the forward pass of the residual block, taking into account the input tensor, optional local and global conditioning tensors, and an optional mask tensor used to zero out padded frames.
- Parameters:
- x (Tensor) – Input tensor of shape (B, residual_channels, T).
- x_mask (Optional[Tensor]) – Mask tensor of shape (B, 1, T). Used to zero out certain parts of the output.
- c (Optional[Tensor]) – Local conditioning tensor of shape (B, aux_channels, T).
- g (Optional[Tensor]) – Global conditioning tensor of shape (B, global_channels, 1).
- Returns: A tuple containing:
  - Output tensor for residual connection of shape (B, residual_channels, T).
  - Output tensor for skip connection of shape (B, skip_channels, T).
- Return type: Tuple[Tensor, Tensor]
#### Examples
>>> residual_block = ResidualBlock()
>>> x = torch.randn(1, 64, 100) # Example input tensor
>>> output_residual, output_skip = residual_block(x)
>>> print(output_residual.shape) # Should be (1, 64, 100)
>>> print(output_skip.shape) # Should be (1, 64, 100)
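A minimal masking sketch, assuming the convention documented above where zeros in x_mask silence the corresponding output frames:

>>> x_mask = torch.ones(1, 1, 100)
>>> x_mask[:, :, 80:] = 0.0  # treat the last 20 frames as padding
>>> output_residual, output_skip = residual_block(x, x_mask=x_mask)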