espnet2.gan_tts.wavenet.wavenet.WaveNet
class espnet2.gan_tts.wavenet.wavenet.WaveNet(in_channels: int = 1, out_channels: int = 1, kernel_size: int = 3, layers: int = 30, stacks: int = 3, base_dilation: int = 2, residual_channels: int = 64, aux_channels: int = -1, gate_channels: int = 128, skip_channels: int = 64, global_channels: int = -1, dropout_rate: float = 0.0, bias: bool = True, use_weight_norm: bool = True, use_first_conv: bool = False, use_last_conv: bool = False, scale_residual: bool = False, scale_skip_connect: bool = False)
Bases: Module
WaveNet with global conditioning.
This class implements a WaveNet model that can be used for generating audio signals with global conditioning. It is built using residual blocks and dilated convolutions.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
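As a rough sketch of how layers, stacks, and base_dilation interact (this follows the typical ParallelWaveGAN dilation schedule and is illustrative, not a quote of this class's internals): the layers are split into `stacks` cycles, and within each cycle the dilation grows as a power of `base_dilation`.

>>> layers, stacks, base_dilation = 30, 3, 2
>>> layers_per_cycle = layers // stacks
>>> [base_dilation ** (i % layers_per_cycle) for i in range(layers)][:10]
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512]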
layers
Number of residual block layers.
- Type: int
stacks
Number of stacks, i.e., dilation cycles.
- Type: int
kernel_size
Kernel size of dilated convolution.
- Type: int
base_dilation
Base dilation factor.
- Type: int
use_first_conv
Whether to use the first convolution layers.
- Type: bool
use_last_conv
Whether to use the last convolution layers.
- Type: bool
scale_skip_connect
Whether to scale the skip connection outputs.
- Type: bool
Parameters:
- in_channels (int) – Number of input channels.
- out_channels (int) – Number of output channels.
- kernel_size (int) – Kernel size of dilated convolution.
- layers (int) – Number of residual block layers.
- stacks (int) – Number of stacks, i.e., dilation cycles.
- base_dilation (int) – Base dilation factor.
- residual_channels (int) – Number of channels in residual conv.
- gate_channels (int) – Number of channels in gated conv.
- skip_channels (int) – Number of channels in skip conv.
- aux_channels (int) – Number of channels for local conditioning feature.
- global_channels (int) – Number of channels for global conditioning feature.
- dropout_rate (float) – Dropout rate. 0.0 means no dropout applied.
- bias (bool) – Whether to use bias parameter in conv layer.
- use_weight_norm (bool) – Whether to use weight norm. If set to true, it will be applied to all of the conv layers.
- use_first_conv (bool) – Whether to use the first conv layers.
- use_last_conv (bool) – Whether to use the last conv layers.
- scale_residual (bool) – Whether to scale the residual outputs.
- scale_skip_connect (bool) – Whether to scale the skip connection outputs.
Returns: Output tensor of shape (B, out_channels, T) if use_last_conv is True, else (B, residual_channels, T).
Return type: Tensor
#### Examples

>>> import torch
>>> from espnet2.gan_tts.wavenet.wavenet import WaveNet
>>> # Create a WaveNet instance; use_first_conv=True allows a (B, 1, T) input
>>> wavenet = WaveNet(in_channels=1, out_channels=1, layers=30, stacks=3,
...                   use_first_conv=True, use_last_conv=True)
>>> # Generate a random input tensor
>>> input_tensor = torch.randn(1, 1, 100)
>>> # Perform forward propagation
>>> output_tensor = wavenet(input_tensor)
>>> output_tensor.shape
torch.Size([1, 1, 100])
#### NOTE

The forward method accepts an optional mask and conditioning features for local and global conditioning.
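A minimal sketch of conditioned usage (the tensor shapes follow the forward parameter descriptions below; aux_channels=80 and global_channels=256 are illustrative values chosen here, not defaults):

>>> import torch
>>> model = WaveNet(aux_channels=80, global_channels=256,
...                 use_first_conv=True, use_last_conv=True)
>>> x = torch.randn(1, 1, 100)      # input noise (B, 1, T)
>>> x_mask = torch.ones(1, 1, 100)  # mask (B, 1, T)
>>> c = torch.randn(1, 80, 100)     # local conditioning (B, aux_channels, T)
>>> g = torch.randn(1, 256, 1)      # global conditioning (B, global_channels, 1)
>>> y = model(x, x_mask=x_mask, c=c, g=g)
>>> y.shape
torch.Size([1, 1, 100])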
Initialize WaveNet module. See the class parameters above.
apply_weight_norm()
Apply weight normalization module to all convolutional layers.
This method applies weight normalization to all Conv1d and Conv2d layers in the WaveNet model. Weight normalization helps in stabilizing the training process and can lead to faster convergence.
It uses the torch.nn.utils.weight_norm function; each application is logged for debugging purposes.
#### NOTE

This method should be called after the model’s layers have been initialized. When use_weight_norm=True, weight normalization is already applied automatically during initialization.
#### Examples

>>> model = WaveNet(use_weight_norm=False)
>>> model.apply_weight_norm()  # Manually apply weight normalization
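For reference, torch.nn.utils.weight_norm reparameterizes each weight into a magnitude (weight_g) and a direction (weight_v). A quick way to verify the reparameterization took effect on a single conv layer:

>>> import torch
>>> conv = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
>>> sorted(name for name, _ in conv.named_parameters())
['bias', 'weight_g', 'weight_v']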
forward(x: Tensor, x_mask: Tensor | None = None, c: Tensor | None = None, g: Tensor | None = None) → Tensor
Calculate forward propagation.
This method performs the forward pass through the WaveNet model. It takes an input tensor and optional conditioning features, and returns the output tensor after processing through the residual blocks and any specified convolutional layers.
- Parameters:
- x (Tensor) – Input noise signal of shape (B, 1, T) if use_first_conv is True, else shape (B, residual_channels, T).
- x_mask (Optional[Tensor]) – Mask tensor of shape (B, 1, T) to apply masking during the forward pass.
- c (Optional[Tensor]) – Local conditioning features of shape (B, aux_channels, T).
- g (Optional[Tensor]) – Global conditioning features of shape (B, global_channels, 1).
- Returns: Output tensor of shape (B, out_channels, T) if use_last_conv is True, else shape (B, residual_channels, T).
- Return type: Tensor
#### Examples

>>> model = WaveNet(use_first_conv=True, use_last_conv=True)
>>> input_tensor = torch.randn(8, 1, 100)  # Batch size 8, T=100
>>> output = model(input_tensor)
>>> output.shape
torch.Size([8, 1, 100])
property receptive_field_size: int
Return receptive field size.
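#### Examples

The receptive field grows with the kernel size and the dilation schedule. A hedged check, assuming the standard dilated-convolution arithmetic (kernel_size - 1) * sum(dilations) + 1, which this property is expected to follow:

>>> model = WaveNet(kernel_size=3, layers=30, stacks=3, base_dilation=2)
>>> layers_per_cycle = 30 // 3
>>> dilations = [2 ** (i % layers_per_cycle) for i in range(30)]
>>> (3 - 1) * sum(dilations) + 1  # standard formula
6139
>>> model.receptive_field_size   # should agree if the formula holds
6139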
remove_weight_norm()
Remove weight normalization module from all of the layers.
This method traverses all layers of the WaveNet model and removes the weight normalization applied to the convolutional layers. It is typically called before inference, since the weight-norm reparameterization is only needed during training.
#### NOTE

Once weight normalization is removed, it is not restored automatically; call apply_weight_norm() to reapply it.
#### Examples
>>> wavenet = WaveNet(use_weight_norm=True)
>>> wavenet.remove_weight_norm() # Weight normalization is removed.
#### NOTE

If a layer does not have weight normalization applied, the resulting ValueError is caught and logged internally; no exception propagates to the caller.