espnet2.gan_svs.pits.modules.WN
class espnet2.gan_svs.pits.modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0)
Bases: Module
WN is a WaveNet-like neural network module designed for generative tasks.
This module implements a stack of dilated 1-D convolutions with gated tanh-sigmoid activations and residual/skip connections, allowing for efficient processing of sequential data. It includes options for conditional input and dropout regularization. The dilation grows with layer depth, which helps capture long-range dependencies in the input data.
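How the per-layer dilations grow can be sketched as follows. This is a minimal sketch of the standard VITS-style construction, not the module's exact code; all variable names are illustrative:

```python
import torch.nn as nn

hidden_channels, kernel_size, dilation_rate, n_layers = 64, 3, 2, 4

in_layers = nn.ModuleList()
for i in range(n_layers):
    dilation = dilation_rate ** i  # dilation grows exponentially per layer
    padding = dilation * (kernel_size - 1) // 2  # keeps sequence length unchanged
    in_layers.append(
        nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
                  dilation=dilation, padding=padding)
    )
```

Each layer outputs 2 * hidden_channels so that the gated activation can split the channels into a tanh branch and a sigmoid branch.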
hidden_channels
The number of hidden channels in the network.
- Type: int
kernel_size
The size of the convolution kernel.
- Type: int
dilation_rate
The rate of dilation for the convolutions.
- Type: int
n_layers
The number of convolutional layers in the network.
- Type: int
gin_channels
The number of channels for conditional input (default 0).
- Type: int
p_dropout
The dropout probability for regularization (default 0).
- Type: float
in_layers
List of input convolutional layers.
- Type: ModuleList
res_skip_layers
List of residual and skip connection layers.
- Type: ModuleList
drop
Dropout layer for regularization.
- Type: Dropout
cond_layer
Conditional layer for processing additional input.
- Type: Conv1d
Parameters:
- hidden_channels (int) – Number of hidden channels in the network.
- kernel_size (int) – Size of the convolution kernel (must be odd).
- dilation_rate (int) – Dilation rate for convolutions.
- n_layers (int) – Number of layers in the network.
- gin_channels (int , optional) – Number of input channels for conditional input. Defaults to 0 (no conditional input).
- p_dropout (float , optional) – Probability of dropout. Defaults to 0.
Returns: The output tensor after processing through the network.
Return type: Tensor
Raises: AssertionError – If kernel_size is not odd.
#### Examples
>>> import torch
>>> model = WN(hidden_channels=64, kernel_size=3, dilation_rate=2,
...            n_layers=5, gin_channels=10, p_dropout=0.1)
>>> x = torch.randn(1, 64, 100)  # batch size 1, 64 channels, length 100
>>> x_mask = torch.ones(1, 1, 100)  # no masking
>>> output = model(x, x_mask)
#### NOTE

The input tensor x should have shape (batch_size, hidden_channels, sequence_length). The mask tensor x_mask should have shape (batch_size, 1, sequence_length); it is multiplied into the activations so that padded time steps do not contribute to the output.
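When the model is constructed with gin_channels > 0, a conditioning tensor g can be passed as well. A minimal sketch continuing the example above (tensor sizes are illustrative):

```python
g = torch.randn(1, 10, 100)     # conditioning input: (batch, gin_channels, length)
output = model(x, x_mask, g=g)  # output shape is unchanged: (1, 64, 100)
```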
forward(x, x_mask, g=None, **kwargs)
Performs the forward pass of the WN model.
This method computes the output of the WN model given the input tensor x, an input mask x_mask, and an optional conditioning tensor g. It applies a stack of dilated convolutions with gated activations and dropout; each layer contributes a skip connection, and the output is the accumulated sum of these skip contributions, multiplied by the mask. A sketch of the per-layer computation follows.
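The sketch below assumes the standard VITS-style WN forward pass; it is illustrative, not the module's verbatim code. It uses only attributes documented above (in_layers, res_skip_layers, drop, cond_layer) plus the module-level fused_add_tanh_sigmoid_multiply function:

```python
import torch
from espnet2.gan_svs.pits.modules import WN, fused_add_tanh_sigmoid_multiply

def wn_forward_sketch(wn: WN, x, x_mask, g=None):
    """Simplified sketch of a VITS-style WN forward pass."""
    output = torch.zeros_like(x)
    n_channels_tensor = torch.IntTensor([wn.hidden_channels])
    if g is not None:
        g = wn.cond_layer(g)  # one (2 * hidden_channels)-wide slice per layer
    for i in range(wn.n_layers):
        x_in = wn.in_layers[i](x)  # dilated conv -> 2 * hidden_channels
        if g is not None:
            offset = i * 2 * wn.hidden_channels
            g_l = g[:, offset : offset + 2 * wn.hidden_channels, :]
        else:
            g_l = torch.zeros_like(x_in)
        acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
        acts = wn.drop(acts)
        res_skip = wn.res_skip_layers[i](acts)
        if i < wn.n_layers - 1:
            x = (x + res_skip[:, : wn.hidden_channels, :]) * x_mask  # residual path
            output = output + res_skip[:, wn.hidden_channels :, :]  # skip path
        else:
            output = output + res_skip  # final layer contributes skip only
    return output * x_mask
```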
- Parameters:
- x (torch.Tensor) – The input tensor of shape (batch_size, hidden_channels, sequence_length).
- x_mask (torch.Tensor) – A binary mask tensor of shape (batch_size, 1, sequence_length) to control the contribution of each time step in the input.
- g (torch.Tensor , optional) – An optional conditioning tensor of shape (batch_size, gin_channels, sequence_length). If provided, it is passed through a conditioning layer.
- **kwargs – Additional keyword arguments for future extension.
- Returns: The output tensor of shape (batch_size, hidden_channels, sequence_length), after applying the WN model transformations.
- Return type: torch.Tensor
#### Examples
>>> import torch
>>> model = WN(hidden_channels=64, kernel_size=3, dilation_rate=2,
...            n_layers=4)
>>> x = torch.randn(10, 64, 50)  # example input
>>> x_mask = torch.ones(10, 1, 50)  # example mask
>>> output = model(x, x_mask)
>>> output.shape
torch.Size([10, 64, 50])
#### NOTE

This method uses weight normalization on the convolutional layers to stabilize training.
- Raises: ValueError – If the input tensor x does not match the expected shape.
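For batches of variable-length sequences, x_mask is typically derived from the per-utterance lengths. A minimal sketch (the length values here are illustrative):

```python
import torch

lengths = torch.tensor([50, 38, 24])  # per-utterance lengths (illustrative)
max_len = int(lengths.max())
# (batch, 1, max_len): 1.0 for valid frames, 0.0 for padding
x_mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1).float()
```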
fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels)
Computes the fused operation of addition, tanh, sigmoid, and multiplication.
This function adds the two input tensors, applies tanh to the first n_channels channels of the sum and sigmoid to the remaining channels, and multiplies the two results together (the WaveNet gated activation).
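Functionally, the operation reduces to the following sketch, assuming the standard VITS-style implementation:

```python
import torch

def fused_add_tanh_sigmoid_multiply_sketch(input_a, input_b, n_channels):
    """Equivalent computation written out step by step (illustrative)."""
    n_ch = int(n_channels[0])
    in_act = input_a + input_b                  # fused addition
    t_act = torch.tanh(in_act[:, :n_ch, :])     # tanh branch (first n_ch channels)
    s_act = torch.sigmoid(in_act[:, n_ch:, :])  # sigmoid gate (remaining channels)
    return t_act * s_act                        # gated activation output
```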
input_a
The first input tensor.
- Type: torch.Tensor
input_b
The second input tensor.
- Type: torch.Tensor
n_channels
A tensor containing the number of channels.
- Type: torch.IntTensor
Parameters:
- input_a (torch.Tensor) – The first input tensor of shape (batch_size, 2 * n_channels, sequence_length).
- input_b (torch.Tensor) – The second input tensor of shape (batch_size, 2 * n_channels, sequence_length).
- n_channels (torch.IntTensor) – A tensor whose first element is the number of channels n_channels used for the tanh/sigmoid split.
Returns: The result of the fused operation, with shape (batch_size, n_channels, sequence_length).
Return type: torch.Tensor
#### Examples
>>> import torch
>>> input_a = torch.randn(10, 32, 50)  # batch_size=10, 2 * n_channels=32
>>> input_b = torch.randn(10, 32, 50)
>>> n_channels = torch.IntTensor([16])
>>> output = fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels)
>>> output.shape
torch.Size([10, 16, 50])
#### NOTE

This function assumes that the first dimension of both input tensors is the batch size and the second dimension is the channel dimension; the first n_channels channels feed the tanh branch and the remaining channels feed the sigmoid branch.
remove_weight_norm()
Remove weight normalization from the WN model layers.
This method removes weight normalization from all layers of the WN model, including the conditional layer (if it exists), input layers, and residual-skip layers. It is typically called once training is finished, since weight normalization is only needed to stabilize training; removing it folds the normalization parameters back into the weight tensors.
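What the method does can be sketched as follows, assuming the standard pattern of calling torch.nn.utils.remove_weight_norm on each wrapped layer (illustrative, not the module's verbatim code):

```python
import torch

def remove_weight_norm_sketch(wn):
    """Sketch: strip weight normalization from every wrapped layer."""
    if wn.gin_channels != 0:
        torch.nn.utils.remove_weight_norm(wn.cond_layer)
    for layer in wn.in_layers:
        torch.nn.utils.remove_weight_norm(layer)
    for layer in wn.res_skip_layers:
        torch.nn.utils.remove_weight_norm(layer)
```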
gin_channels
Number of conditional input channels. If greater than 0, the conditional layer will also have weight normalization removed.
- Type: int
########### Examples
>>> model = WN(hidden_channels=64, kernel_size=3, dilation_rate=2,
... n_layers=4, gin_channels=0)
>>> model.remove_weight_norm() # Removes weight normalization from model
#### NOTE

This method does not return a value; it modifies the model's layers in place.
- Raises: ValueError – If the model has not been initialized properly.