espnet2.enh.layers.dpmulcat.MulCatBlock
class espnet2.enh.layers.dpmulcat.MulCatBlock(input_size: int, hidden_size: int, dropout: float = 0.0, bidirectional: bool = True)
Bases: Module
The MulCat block.
This module implements a multiplicative concatenation block using LSTM layers. It processes input sequences through two separate LSTM networks: one for the main processing and another to create a gating mechanism. The outputs are combined to enhance the feature representation.
rnn
The primary LSTM layer for feature extraction.
- Type: nn.LSTM
rnn_proj
Linear layer to project the output of the RNN.
- Type: nn.Linear
gate_rnn
The gating LSTM layer.
- Type: nn.LSTM
gate_rnn_proj
Linear layer to project the output of the gate RNN.
- Type: nn.Linear
block_projection
Final linear projection to match input size.
- Type: nn.Linear
Parameters:
- input_size (int) – Dimension of the input feature. The input should have shape (batch, seq_len, input_size).
- hidden_size (int) – Dimension of the hidden state.
- dropout (float, optional) – The dropout rate in the LSTM layer. Defaults to 0.0.
- bidirectional (bool, optional) – Whether the RNN layers are bidirectional. Defaults to True.
####### Examples
>>> input_tensor = torch.randn(32, 10, 64) # (batch_size, seq_len, input_size)
>>> mulcat_block = MulCatBlock(input_size=64, hidden_size=32)
>>> output_tensor = mulcat_block(input_tensor)
>>> output_tensor.shape
torch.Size([32, 10, 64]) # Output has the same shape as input
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(input)
Compute the output of the MulCat block.

The input is passed through the main LSTM and the gating LSTM in parallel. The two projected outputs are multiplied element-wise, concatenated with the original input, and projected back to the input size.
- Parameters:
- input (torch.Tensor) – The input feature of shape (batch, seq_len, input_size).
- Returns:
The output feature of shape (batch, seq_len, input_size).
####### Examples
>>> mul_cat_block = MulCatBlock(input_size=128, hidden_size=64)
>>> input_tensor = torch.randn(32, 10, 128) # (batch, seq_len, input_size)
>>> output_tensor = mul_cat_block(input_tensor)
>>> print(output_tensor.shape) # (32, 10, 128)
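The forward computation described above can be sketched as a small PyTorch module. This is a minimal illustration of the multiply-then-concatenate pattern, not the exact ESPnet implementation; the layer hyperparameters (single LSTM layer, batch-first tensors) are assumptions for this sketch.

```python
import torch
import torch.nn as nn


class MulCatSketch(nn.Module):
    """Hedged sketch of a MulCat block: two parallel LSTMs with
    multiplicative gating, residual concatenation, and a projection."""

    def __init__(self, input_size, hidden_size, dropout=0.0, bidirectional=True):
        super().__init__()
        num_direction = 2 if bidirectional else 1
        # Main LSTM path and its projection back to input_size.
        self.rnn = nn.LSTM(input_size, hidden_size, 1, batch_first=True,
                           dropout=dropout, bidirectional=bidirectional)
        self.rnn_proj = nn.Linear(hidden_size * num_direction, input_size)
        # Gating LSTM path and its projection back to input_size.
        self.gate_rnn = nn.LSTM(input_size, hidden_size, 1, batch_first=True,
                                dropout=dropout, bidirectional=bidirectional)
        self.gate_rnn_proj = nn.Linear(hidden_size * num_direction, input_size)
        # Final projection: concatenated (gated, input) back to input_size.
        self.block_projection = nn.Linear(input_size * 2, input_size)

    def forward(self, input):
        # Main feature path.
        rnn_output, _ = self.rnn(input)
        rnn_output = self.rnn_proj(rnn_output)
        # Gate path.
        gate_output, _ = self.gate_rnn(input)
        gate_output = self.gate_rnn_proj(gate_output)
        # Multiplicative gating, then concatenate with the raw input.
        gated = torch.mul(rnn_output, gate_output)
        gated = torch.cat([gated, input], dim=2)
        # Project back so output shape matches the input shape.
        return self.block_projection(gated)


# Usage: output keeps the input shape (batch, seq_len, input_size).
x = torch.randn(2, 5, 16)
block = MulCatSketch(input_size=16, hidden_size=8)
y = block(x)
```

Keeping the output at `input_size` is what lets MulCat blocks be stacked in a dual-path model without extra reshaping.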