espnet2.enh.separator.tfgridnet_separator.LayerNormalization4D
class espnet2.enh.separator.tfgridnet_separator.LayerNormalization4D(input_dimension, eps=1e-05)
Bases: Module
4D Layer Normalization.
This class implements layer normalization for 4-dimensional tensors. It normalizes the input tensor over the channel dimension and then applies a learnable per-channel scale and shift.
- Parameters:
- input_dimension (int) – The size of the input feature dimension.
- eps (float , optional) – A small constant added to the variance for numerical stability. Default is 1e-5.
- Raises: ValueError – If the input tensor does not have 4 dimensions.
####### Examples
>>> layer_norm = LayerNormalization4D(input_dimension=64)
>>> input_tensor = torch.randn(10, 64, 32, 32) # [B, C, H, W]
>>> output_tensor = layer_norm(input_tensor)
>>> print(output_tensor.shape)
torch.Size([10, 64, 32, 32])
NOTE
The input tensor is expected to have 4 dimensions, where the dimensions represent batch size, number of channels, height, and width respectively.
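The computation can be sketched as follows. This is a minimal, functional illustration (hypothetical names, not the espnet implementation), assuming statistics are taken over the channel dimension with per-channel scale `gamma` and shift `beta`:

```python
import torch


def layer_norm_4d(x, gamma, beta, eps=1e-5):
    # x: [B, C, H, W]; gamma, beta: [1, C, 1, 1] learnable parameters
    if x.ndim != 4:
        raise ValueError(f"Expected a 4D tensor, got {x.ndim} dimensions")
    # Mean and (biased) variance over the channel dimension at each position
    mu = x.mean(dim=1, keepdim=True)
    std = torch.sqrt(x.var(dim=1, unbiased=False, keepdim=True) + eps)
    return (x - mu) / std * gamma + beta


x = torch.randn(10, 64, 32, 32)
gamma = torch.ones(1, 64, 1, 1)   # scale, initialized to ones
beta = torch.zeros(1, 64, 1, 1)   # shift, initialized to zeros
y = layer_norm_4d(x, gamma, beta)
```

With the default initialization (`gamma` all ones, `beta` all zeros), the output has approximately zero mean and unit variance along the channel axis at every spatial position.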
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Applies 4D layer normalization to the input tensor.
Parameters:
- x (torch.Tensor) – Input tensor of shape [B, C, H, W], where B is the batch size, C the channel (feature) dimension, H the height, and W the width.
Returns: The normalized tensor, with the same shape as the input [B, C, H, W].
Return type: torch.Tensor
Raises: ValueError – If the input tensor does not have 4 dimensions.
####### Examples
>>> layer_norm = LayerNormalization4D(input_dimension=128)
>>> x = torch.randn(4, 128, 100, 65)  # [B, C, H, W]
>>> y = layer_norm(x)
>>> y.shape
torch.Size([4, 128, 100, 65])
NOTE
TFGridNet, the separator in which this layer is used, works best when trained with variance-normalized mixture input and targets. For example, normalize a mixture of shape [batch, samples, microphones] by dividing it by torch.std(mixture, (1, 2)). Apply the same normalization to the target signals, especially when not using a scale-invariant loss function such as SI-SDR.
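The variance normalization described in the note can be sketched as follows; tensor shapes and names are illustrative:

```python
import torch

# Toy mixture and target signals: [batch, samples, microphones]
mixture = torch.randn(4, 16000, 2)
target = torch.randn(4, 16000, 2)

# One standard deviation per batch item, computed over samples and microphones
std = torch.std(mixture, dim=(1, 2), keepdim=True)  # [batch, 1, 1]

# Scale mixture and target by the SAME factor so their relative levels are preserved
mixture_normalized = mixture / std
target_normalized = target / std
```

Dividing the targets by the mixture's standard deviation (rather than their own) keeps the mixture and targets on a consistent scale, which matters when the training loss is not scale-invariant.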