espnet2.enh.separator.tfgridnetv3_separator.LayerNormalization
class espnet2.enh.separator.tfgridnetv3_separator.LayerNormalization(input_dim, dim=1, total_dim=4, eps=1e-05)
Bases: Module
Layer normalization layer.
This layer applies layer normalization to the input tensor along a specified dimension. It normalizes the input by subtracting the mean and dividing by the standard deviation, followed by scaling and shifting with learnable parameters gamma and beta.
dim
The dimension along which to compute the mean and variance.
- Type: int
gamma
Scale parameter for normalization.
- Type: nn.Parameter
beta
Shift parameter for normalization.
- Type: nn.Parameter
eps
A small value added to the variance to avoid division by zero.
- Type: float
Parameters:
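- input_dim (int) – The size of the input tensor along the normalized dimension dim (for example, the number of channels C when dim=1 and the input has shape [B, C, T, F]).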
- dim (int) – The dimension along which to compute the normalization. Default is 1.
- total_dim (int) – The total number of dimensions of the input tensor. Default is 4.
- eps (float) – A small value to prevent division by zero during normalization. Default is 1e-5.
Raises: ValueError – If the input tensor does not have the expected number of dimensions.
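To illustrate how `input_dim`, `dim`, and `total_dim` relate, the sketch below builds scale and shift parameters that broadcast against a `total_dim`-dimensional input. The helper name `make_affine_params` and the parameter shape (size `input_dim` at position `dim`, size 1 elsewhere) are assumptions made for illustration; they are not taken from the ESPnet source.

```python
import torch
from torch import nn


def make_affine_params(input_dim, dim=1, total_dim=4):
    """Hypothetical helper: build gamma/beta that broadcast along `dim`."""
    # Assumption: a negative `dim` counts from the end, as in PyTorch indexing.
    dim = dim if dim >= 0 else total_dim + dim
    # Size `input_dim` at the normalized position, 1 everywhere else.
    shape = [input_dim if i == dim else 1 for i in range(total_dim)]
    gamma = nn.Parameter(torch.ones(shape))   # scale, initialized to 1
    beta = nn.Parameter(torch.zeros(shape))   # shift, initialized to 0
    return gamma, beta


gamma, beta = make_affine_params(input_dim=64)
print(tuple(gamma.shape))  # (1, 64, 1, 1)
```

With this shape, `gamma` and `beta` apply one scale and one shift per index along `dim` and broadcast over every other axis of a `[B, C, T, F]` input.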
####### Examples
>>> import torch
>>> from espnet2.enh.separator.tfgridnetv3_separator import LayerNormalization
>>> layer_norm = LayerNormalization(input_dim=64)
>>> input_tensor = torch.randn(32, 64, 128, 256)  # [B, C, T, F]
>>> output_tensor = layer_norm(input_tensor)
>>> output_tensor.shape  # same shape as the input
torch.Size([32, 64, 128, 256])
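The computation described for this layer can be sketched as a plain function. This is a minimal illustration, assuming the biased (population) variance and `eps` added to the variance before the square root; `layer_norm_along_dim` is a hypothetical name, and the ESPnet implementation may differ in detail. The sketch cross-checks its result against PyTorch's built-in `torch.nn.functional.layer_norm` applied over the same axis.

```python
import torch
import torch.nn.functional as F


def layer_norm_along_dim(x, gamma, beta, dim=1, eps=1e-5):
    """Hypothetical sketch: normalize `x` along `dim`, then scale and shift."""
    mean = x.mean(dim=dim, keepdim=True)
    # Assumption: biased variance, with eps added before the square root.
    var = x.var(dim=dim, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * gamma + beta


x = torch.randn(2, 8, 5, 3)          # [B, C, T, F]
gamma = torch.ones(1, 8, 1, 1)       # identity scale
beta = torch.zeros(1, 8, 1, 1)       # zero shift
y = layer_norm_along_dim(x, gamma, beta)

# Cross-check: move C to the last axis, apply the built-in layer norm
# over it, then move the axis back.
ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), eps=1e-5).permute(0, 3, 1, 2)
print(y.shape, torch.allclose(y, ref, atol=1e-5))
```

With an identity `gamma` and a zero `beta`, each slice along `dim` ends up with approximately zero mean and unit variance.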
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Apply layer normalization to the input tensor.
This method normalizes x along the dimension given by dim: it subtracts the mean, divides by the standard deviation (with eps added to the variance), and then scales and shifts the result with the learnable parameters gamma and beta.
Parameters:
- x (torch.Tensor) – Input tensor to normalize, for example with shape [B, C, T, F] when total_dim is 4.
Returns: The normalized tensor, with the same shape as x.
Return type: torch.Tensor
Raises: ValueError – If the input tensor does not have the expected number of dimensions.
####### Examples
>>> layer_norm = LayerNormalization(input_dim=64)
>>> x = torch.randn(4, 64, 10, 20)  # [B, C, T, F]
>>> y = layer_norm(x)
>>> y.shape  # same shape as the input
torch.Size([4, 64, 10, 20])