espnet2.enh.layers.dptnet.DPTNet
class espnet2.enh.layers.dptnet.DPTNet(rnn_type, input_size, hidden_size, output_size, att_heads=4, dropout=0, activation='relu', num_layers=1, bidirectional=True, norm_type='gLN')
Bases: Module
Dual-path transformer network.
This implementation of DPTNet follows J. Chen, Q. Mao, and D. Liu, “Dual-path transformer network: Direct context-aware modeling for end-to-end monaural speech separation,” presented at ISCA Interspeech, 2020. Each layer applies an improved Transformer layer, whose feed-forward network begins with an RNN, first within each chunk (intra-chunk) and then across chunks (inter-chunk).
input_size
Dimension of the input feature.
- Type: int
hidden_size
Dimension of the hidden state.
- Type: int
output_size
Dimension of the output.
- Type: int
row_transformer
List of transformer layers for row processing.
- Type: nn.ModuleList
col_transformer
List of transformer layers for column processing.
- Type: nn.ModuleList
output
Final output layer: a PReLU followed by a Conv2d that maps input_size channels to output_size channels.
- Type: nn.Sequential
Parameters:
- rnn_type (str) – Select from ‘RNN’, ‘LSTM’, and ‘GRU’.
- input_size (int) – Dimension of the input feature. Input size must be a multiple of att_heads.
- hidden_size (int) – Dimension of the hidden state.
- output_size (int) – Dimension of the output.
- att_heads (int) – Number of attention heads.
- dropout (float) – Dropout ratio. Default is 0.
- activation (str) – Activation function applied at the output of RNN.
- num_layers (int) – Number of stacked dual-path layers; the intra-chunk and inter-chunk processes are applied once per layer. Default is 1.
- bidirectional (bool) – Whether the RNN layers are bidirectional. Default is True.
- norm_type (str) – Type of normalization to use after each inter- or intra-chunk Transformer block.
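The argument constraints listed above can be checked before building the model. The sketch below is a hypothetical helper (`check_dptnet_args` and `VALID_RNN_TYPES` are illustrative names, not part of ESPnet) that applies the two documented checks and returns the feature size a PyTorch RNN produces, which doubles when bidirectional=True because forward and backward states are concatenated.

```python
# Hypothetical helper: NOT part of ESPnet, only mirrors the documented constraints.
VALID_RNN_TYPES = ("RNN", "LSTM", "GRU")


def check_dptnet_args(rnn_type: str, input_size: int, hidden_size: int,
                      att_heads: int = 4, bidirectional: bool = True) -> int:
    """Validate DPTNet-style arguments and return the RNN output feature size."""
    # Documented: rnn_type must be one of 'RNN', 'LSTM', 'GRU'.
    assert rnn_type in VALID_RNN_TYPES, (
        f"rnn_type must be one of {VALID_RNN_TYPES}, got {rnn_type!r}"
    )
    # Documented: multi-head attention needs input_size divisible by att_heads.
    assert input_size % att_heads == 0, (
        f"input_size ({input_size}) must be a multiple of att_heads ({att_heads})"
    )
    # A bidirectional PyTorch RNN concatenates forward and backward hidden states.
    return 2 * hidden_size if bidirectional else hidden_size


print(check_dptnet_args("LSTM", 256, 128))  # 256
print(check_dptnet_args("GRU", 64, 32, att_heads=8, bidirectional=False))  # 32
```

For example, input_size=250 with att_heads=4 would fail the second check, since 250 is not a multiple of 4.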
########### Examples
>>> import torch
>>> from espnet2.enh.layers.dptnet import DPTNet
>>> model = DPTNet(
... rnn_type='LSTM',
... input_size=256,
... hidden_size=128,
... output_size=256,
... att_heads=4,
... dropout=0.1,
... activation='relu',
... num_layers=2,
... bidirectional=True,
... norm_type='gLN'
... )
>>> input_tensor = torch.randn(8, 256, 10, 5)  # (batch, N=input_size, dim1, dim2)
>>> output_tensor = model(input_tensor)
>>> print(output_tensor.shape)
torch.Size([8, 256, 10, 5])
######## NOTE The input tensor must be of shape (batch, N, dim1, dim2), where N equals input_size.
- Raises: AssertionError – If the provided rnn_type is not one of ‘RNN’, ‘LSTM’, or ‘GRU’.
forward(input)
Perform the forward pass of the DPTNet model.
This method passes the input tensor through num_layers dual-path layers. In each layer, the intra-chunk transformer operates along dim1 and then the inter-chunk transformer operates along dim2. A final output layer (PReLU followed by Conv2d) maps the N feature channels to output_size.
- Parameters: input (torch.Tensor) – Input tensor of shape (batch, N, dim1, dim2), where batch is the batch size, N is the feature dimension (must equal input_size), dim1 is the first (intra-chunk) dimension, and dim2 is the second (inter-chunk) dimension.
- Returns: Output tensor of shape (batch, output_size, dim1, dim2).
- Return type: torch.Tensor
########### Examples
>>> model = DPTNet(rnn_type='LSTM', input_size=256, hidden_size=128,
... output_size=10)
>>> input_tensor = torch.randn(32, 256, 64, 128)  # (batch, N=input_size, dim1, dim2)
>>> output_tensor = model(input_tensor)
>>> print(output_tensor.shape)
torch.Size([32, 10, 64, 128])
######## NOTE The second dimension (N) of the input tensor must equal input_size. The model applies the intra-chunk and inter-chunk processes in a loop, once per layer (num_layers times in total).
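The forward pass can be sketched at the shape level in NumPy. This is an illustration under stated assumptions, not the ESPnet code: `identity_block` and `sketch_forward` are hypothetical names, each intra- or inter-chunk pass is replaced by a shape-preserving identity (the documentation states both passes keep the (batch, N, dim1, dim2) shape), and the output layer is modeled as a single-slope PReLU (0.25, PyTorch's default initial slope) followed by a 1x1 convolution written as an `einsum`; the real Conv2d kernel size is assumed to be 1.

```python
import numpy as np


def identity_block(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for one intra- or inter-chunk Transformer pass."""
    return x


def sketch_forward(x, blocks, weight, bias, prelu_slope=0.25):
    """x: (batch, N, dim1, dim2); weight: (output_size, N); bias: (output_size,)."""
    assert x.shape[1] == weight.shape[1], "dim 1 of the input must equal input_size"
    out = x
    for intra, inter in blocks:  # one (intra, inter) pair per layer
        out = inter(intra(out))
        assert out.shape == x.shape  # both passes keep (batch, N, dim1, dim2)
    out = np.where(out > 0, out, prelu_slope * out)  # PReLU with a single slope
    # Assumed 1x1 Conv2d: a linear map over the channel axis at every (dim1, dim2).
    return np.einsum("oc,bcij->boij", weight, out) + bias[None, :, None, None]


x = np.ones((2, 8, 5, 3))  # batch=2, N=input_size=8, dim1=5, dim2=3
weight, bias = np.ones((4, 8)), np.zeros(4)  # output_size=4
y = sketch_forward(x, [(identity_block, identity_block)] * 2, weight, bias)
print(y.shape)  # (2, 4, 5, 3)
print(y[0, 0, 0, 0])  # 8.0
```

Only the channel axis changes size, from N to output_size; dim1 and dim2 pass through unchanged, matching the documented return shape.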
inter_chunk_process(x, layer_index)
Apply the inter-chunk transformer layer to the output of the intra-chunk transformer layer.
This method rearranges the input tensor into sequences of shape (batch * chunk_size, n_chunks, N), so that col_transformer[layer_index] attends across chunks at each within-chunk position, and then restores the (batch, N, chunk_size, n_chunks) layout.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (batch, N, chunk_size, n_chunks), where batch is the batch size, N is the feature dimension, chunk_size is the size of each chunk, and n_chunks is the number of chunks.
- layer_index (int) – The index of the current layer being processed.
- Returns: Output tensor of shape (batch, N, chunk_size, n_chunks) after applying the column transformer.
- Return type: torch.Tensor
########### Examples
>>> import torch
>>> model = DPTNet('LSTM', 128, 64, 32)
>>> input_tensor = torch.randn(16, 128, 10, 5) # batch_size=16
>>> output_tensor = model.inter_chunk_process(input_tensor, 0)
>>> output_tensor.shape
torch.Size([16, 128, 10, 5])
######## NOTE The input tensor is expected to be in the format (batch, N, chunk_size, n_chunks) prior to calling this method.
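The inter-chunk rearrangement can be illustrated with NumPy. This sketch assumes the tensor is flattened to (batch * chunk_size, n_chunks, N), with sequence index batch_index * chunk_size + position; the helper names are hypothetical and not part of ESPnet. It checks that each sequence collects one within-chunk position from every chunk and that the inverse reshape restores the input exactly.

```python
import numpy as np


def to_inter_chunk_sequences(x: np.ndarray) -> np.ndarray:
    """(batch, N, chunk_size, n_chunks) -> (batch * chunk_size, n_chunks, N)."""
    b, n, k, s = x.shape
    return x.transpose(0, 2, 3, 1).reshape(b * k, s, n)


def from_inter_chunk_sequences(seq: np.ndarray, batch: int, chunk_size: int) -> np.ndarray:
    """Inverse of to_inter_chunk_sequences."""
    _, s, n = seq.shape
    return seq.reshape(batch, chunk_size, s, n).transpose(0, 3, 1, 2)


# batch=2, N=4, chunk_size=3, n_chunks=5
x = np.arange(2 * 4 * 3 * 5, dtype=float).reshape(2, 4, 3, 5)
seq = to_inter_chunk_sequences(x)
print(seq.shape)  # (6, 5, 4)
# Sequence 1 * 3 + 2 holds position 2 of every chunk in batch item 1.
print(np.array_equal(seq[1 * 3 + 2], x[1, :, 2, :].T))  # True
print(np.array_equal(from_inter_chunk_sequences(seq, 2, 3), x))  # True
```

Because each sequence runs over n_chunks, self-attention on these sequences models long-range context across chunks.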
intra_chunk_process(x, layer_index)
Process the input tensor with the intra-chunk transformer layer.
This method rearranges the input tensor into sequences of shape (batch * n_chunks, chunk_size, N), applies row_transformer[layer_index] within each chunk, and then restores the (batch, N, chunk_size, n_chunks) layout.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (batch, N, chunk_size, n_chunks).
- layer_index (int) – Index of the transformer layer to be used for processing.
- Returns: Transformed tensor of shape (batch, N, chunk_size, n_chunks).
- Return type: torch.Tensor
########### Examples
>>> model = DPTNet('LSTM', 128, 64, 32)
>>> input_tensor = torch.randn(10, 128, 32, 4)  # (batch, N=input_size, chunk_size, n_chunks)
>>> output_tensor = model.intra_chunk_process(input_tensor, 0)
>>> output_tensor.shape
torch.Size([10, 128, 32, 4])
######## NOTE The input tensor is expected to have four dimensions: batch size, feature dimension N (which must equal input_size), chunk size, and number of chunks.
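The intra-chunk rearrangement can be illustrated the same way. This NumPy sketch assumes the tensor is flattened to (batch * n_chunks, chunk_size, N), with sequence index batch_index * n_chunks + chunk_index; the helper names are hypothetical and not part of ESPnet. It checks that each sequence is exactly one chunk and that the inverse reshape restores the input.

```python
import numpy as np


def to_intra_chunk_sequences(x: np.ndarray) -> np.ndarray:
    """(batch, N, chunk_size, n_chunks) -> (batch * n_chunks, chunk_size, N)."""
    b, n, k, s = x.shape
    # Swap the feature and chunk-index axes, then flatten (batch, n_chunks).
    return x.transpose(0, 3, 2, 1).reshape(b * s, k, n)


def from_intra_chunk_sequences(seq: np.ndarray, batch: int, n_chunks: int) -> np.ndarray:
    """Inverse of to_intra_chunk_sequences."""
    _, k, n = seq.shape
    return seq.reshape(batch, n_chunks, k, n).transpose(0, 3, 2, 1)


# batch=2, N=4, chunk_size=3, n_chunks=5
x = np.arange(2 * 4 * 3 * 5, dtype=float).reshape(2, 4, 3, 5)
seq = to_intra_chunk_sequences(x)
print(seq.shape)  # (10, 3, 4)
# Sequence 1 * 5 + 4 is chunk 4 of batch item 1, laid out as (chunk_size, N).
print(np.array_equal(seq[1 * 5 + 4], x[1, :, :, 4].T))  # True
print(np.array_equal(from_intra_chunk_sequences(seq, 2, 5), x))  # True
```

Here each sequence runs over chunk_size, so self-attention models local context inside a single chunk; the inter-chunk pass then handles context across chunks.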