espnet2.enh.separator.tfgridnetv3_separator.AllHeadPReLULayerNormalization4DC
class espnet2.enh.separator.tfgridnetv3_separator.AllHeadPReLULayerNormalization4DC(input_dimension, eps=1e-05)
Bases: Module
Applies PReLU activation followed by layer normalization across heads.
This layer normalizes the input tensor along specified dimensions after applying the PReLU activation function. It is designed for multi-dimensional inputs, particularly suited for use in attention mechanisms where inputs are structured with heads and embedding dimensions.
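As a rough illustration of the described behavior only, the sketch below applies a per-head PReLU and then normalizes over the embedding axis of a hypothetical [B, H, E, T, F] tensor. The tensor layout, normalization axis, and parameter shapes here are assumptions for illustration and may differ from the actual ESPnet implementation.

```python
import torch
import torch.nn as nn

# Hypothetical shapes for illustration only.
B, H, E, T, F = 4, 8, 64, 10, 33
x = torch.randn(B, H, E, T, F)

act = nn.PReLU(num_parameters=H)      # one PReLU slope per head (channel dim 1)
gamma = torch.ones(1, H, E, 1, 1)     # scale, broadcast over B, T, F
beta = torch.zeros(1, H, E, 1, 1)     # shift
eps = 1e-5

y = act(x)                            # PReLU first
mu = y.mean(dim=2, keepdim=True)      # statistics over the embedding axis
var = y.var(dim=2, unbiased=False, keepdim=True)
out = (y - mu) / torch.sqrt(var + eps) * gamma + beta
print(out.shape)                      # torch.Size([4, 8, 64, 10, 33])
```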
gamma
Scale parameter for layer normalization.
- Type: torch.nn.Parameter
beta
Shift parameter for layer normalization.
- Type: torch.nn.Parameter
act
PReLU activation function applied per head.
- Type: nn.PReLU
eps
Small value added for numerical stability during normalization.
- Type: float
H
Number of heads in the input dimension.
- Type: int
E
Embedding dimension in the input.
- Type: int
Parameters:
- input_dimension (Tuple[int, int]) – A tuple containing the number of heads (H) and the embedding dimension (E).
- eps (float) – Small value to prevent division by zero in normalization.
Raises: AssertionError – If the input_dimension does not have a length of 2.
### Examples
>>> layer_norm = AllHeadPReLULayerNormalization4DC((8, 64))
>>> input_tensor = torch.randn(32, 8, 128, 64) # [B, H, T, F]
>>> output = layer_norm(input_tensor)
>>> output.shape
torch.Size([32, 8, 128, 64]) # Normalized output shape remains the same.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Perform the forward pass of the TFGridNetV3 model.
This method processes the input audio tensor and returns the enhanced audio signals along with their corresponding lengths and any additional data.
- Parameters:
- input (torch.Tensor) – A batched multi-channel audio tensor with shape [B, T, F], where B is the batch size, T is the number of time frames, and F is the number of frequency bins.
- ilens (torch.Tensor) – A tensor containing the lengths of each input sequence in the batch, shape [B].
- additional (Dict or None) – Additional data that may be required for processing. Currently unused in this model.
- Returns:
- enhanced (List[torch.Tensor]): A list of enhanced mono audio tensors of shape [(B, T), …] where the length of the list is equal to n_srcs (number of output sources).
- ilens (torch.Tensor): A tensor of shape [B] containing the lengths of the enhanced audio signals.
- additional (OrderedDict): The additional data returned as-is, currently unused in this model.
- Return type: Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]
### Examples
>>> model = TFGridNetV3(n_srcs=2)
>>> input_tensor = torch.randn(8, 100, 2) # 8 samples, 100 time frames, 2 channels
>>> ilens = torch.tensor([100] * 8) # All samples have length 100
>>> enhanced, lengths, _ = model(input_tensor, ilens)
>>> print([e.shape for e in enhanced]) # Output shapes of enhanced signals
NOTE
Ensure the input tensor is properly normalized. This model works best when the input mixture and target signals are variance normalized.
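A minimal sketch of one way to variance-normalize the mixture before separation, using the per-utterance standard deviation; the waveform shape and normalization scheme here are assumptions, and the exact normalization used during training may differ.

```python
import torch

# Hypothetical waveform batch [B, samples]; normalize each utterance by its std.
mix = torch.randn(8, 16000)
std = mix.std(dim=1, keepdim=True).clamp(min=1e-8)
mix_norm = mix / std
# Run the front-end / separator on mix_norm, then undo the scaling on the
# separated outputs (e.g. est * std) if the original signal level is needed.
```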
- Raises: AssertionError – If the input tensor does not have the correct number of channels, or if any other assertion within the method fails.