espnet2.enh.layers.dpmulcat.DPMulCat
class espnet2.enh.layers.dpmulcat.DPMulCat(input_size: int, hidden_size: int, output_size: int, num_spk: int, dropout: float = 0.0, num_layers: int = 4, bidirectional: bool = True, input_normalize: bool = False)
Bases: Module
Dual-path RNN module with MulCat blocks.
This module stacks num_layers MulCat blocks for rows and the same number for columns, so the input is processed along two dimensions in turn (dim1 and then dim2 of the tensor passed to forward). It supports multi-speaker output through num_spk and optional GroupNorm through input_normalize.
rows_grnn
List of MulCat blocks for row processing.
- Type: nn.ModuleList
cols_grnn
List of MulCat blocks for column processing.
- Type: nn.ModuleList
rows_normalization
List of normalization layers for rows.
- Type: nn.ModuleList
cols_normalization
List of normalization layers for columns.
- Type: nn.ModuleList
output
Final layer for producing the output.
- Type: nn.Sequential
Parameters:
- input_size (int) – Dimension of the input feature. This should match the N axis of the (batch, N, dim1, dim2) tensor passed to forward.
- hidden_size (int) – Dimension of the hidden state.
- output_size (int) – Dimension of the output feature.
- num_spk (int) – The number of speakers in the output.
- dropout (float) – The dropout rate in the LSTM layer. (Default: 0.0)
- num_layers (int) – Number of stacked MulCat blocks. (Default: 4)
- bidirectional (bool) – Whether the RNN layers are bidirectional. (Default: True)
- input_normalize (bool) – Whether to apply GroupNorm on the input Tensor. (Default: False)
####### Examples
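>>> import torch
>>> from espnet2.enh.layers.dpmulcat import DPMulCat
>>> dp_mul_cat = DPMulCat(input_size=64, hidden_size=128, output_size=10,
...                       num_spk=2)
>>> input_tensor = torch.randn(32, 64, 10, 20)  # (batch, N, dim1, dim2), N == input_size
>>> output = dp_mul_cat(input_tensor)  # a newly built module is in training mode
>>> len(output)  # Should return 4 in training mode if num_layers is 4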
- Returns: In training mode, the output of every DPMulCat block. In eval mode, only the output of the last block.
- Return type: list(torch.Tensor) or list(list(torch.Tensor))
forward(input)
Compute output after DPMulCat module.
- Parameters: input (torch.Tensor) – The input feature, a tensor of shape (batch, N, dim1, dim2). The RNN is applied along dim1 first and then along dim2.
- Returns: In training mode, the output of every DPMulCat block. In eval mode, only the output of the last block.
- Return type: list(torch.Tensor) or list(list(torch.Tensor))
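The "dim1 first and then dim2" order can be pictured as two ways of cutting the same 4-D tensor into sequences. The sketch below only does the shape bookkeeping in plain Python. It assumes, as is common in dual-path models but not stated on this page, that each pass folds the other dimension into the batch axis and feeds (num_sequences, seq_len, features) to a batch-first RNN. The helper name `dual_path_views` is hypothetical and not part of ESPnet.

```python
def dual_path_views(batch: int, n: int, dim1: int, dim2: int):
    """Shapes of the sequence batches seen by the two passes.

    Assumption (not confirmed by this page): the dimension that is not
    being traversed is folded into the batch axis.
    """
    row_view = (batch * dim2, dim1, n)  # sequences run along dim1
    col_view = (batch * dim1, dim2, n)  # sequences run along dim2
    return row_view, col_view


rows, cols = dual_path_views(batch=32, n=128, dim1=20, dim2=30)
print(rows)  # (960, 20, 128)
print(cols)  # (640, 30, 128)
```

Under this assumption, the sequence length each RNN sees is dim1 or dim2 rather than dim1 * dim2, which is the usual motivation for a dual-path layout.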
####### Examples
>>> import torch
>>> from espnet2.enh.layers.dpmulcat import DPMulCat
>>> model = DPMulCat(input_size=128, hidden_size=64, output_size=10, num_spk=2)
>>> input_tensor = torch.randn(32, 128, 20, 20)  # (batch, N, dim1, dim2), N == input_size
>>> output = model(input_tensor)
>>> print(len(output))  # Output length in training mode: one entry per block
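The length of the returned list depends on whether the module is in training or eval mode. The following is a minimal sketch of that rule as stated in the Returns note. It uses stand-in strings instead of tensors so that it runs without torch or ESPnet. The helper names `expected_num_outputs` and `last_block_output` are hypothetical and not part of ESPnet.

```python
def expected_num_outputs(num_layers: int, training: bool) -> int:
    """Length of the list returned by forward, per the Returns note."""
    return num_layers if training else 1


def last_block_output(outputs):
    """Pick the final block's output; valid for training and eval lists."""
    return outputs[-1]


# Stand-ins for the per-block tensors of a 4-block module.
train_outputs = [f"block{i}" for i in range(expected_num_outputs(4, training=True))]
eval_outputs = train_outputs[-1:]  # eval mode keeps only the last block

print(len(train_outputs), len(eval_outputs))  # 4 1
print(last_block_output(train_outputs) == last_block_output(eval_outputs))  # True
```

Indexing with `[-1]` selects the last block's output in either mode, which may be convenient when one code path serves both training and evaluation.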