espnet2.enh.layers.uses.ATFBlock
class espnet2.enh.layers.uses.ATFBlock(input_size, rnn_type='lstm', hidden_size=128, att_heads=4, dropout=0.0, activation='relu', bidirectional=True, norm_type='cLN', ch_mode='att', ch_att_dim=256, eps=1e-05, with_channel_modeling=True)
Bases: Module
ATFBlock is a container module for a single Attentive Time-Frequency Block.
This block is designed to process time-frequency representations of audio signals using a combination of improved transformer layers and channel modeling techniques.
input_size
Dimension of the input feature.
- Type: int
rnn_type
Type of the RNN cell in the improved Transformer layer.
- Type: str
hidden_size
Hidden dimension of the RNN cell.
- Type: int
att_heads
Number of attention heads in the Transformer.
- Type: int
dropout
Dropout ratio. Default is 0.
- Type: float
activation
Non-linear activation function applied in each block.
- Type: str
bidirectional
Whether the RNN layers are bidirectional.
- Type: bool
norm_type
Normalization type in the improved Transformer layer.
- Type: str
ch_mode
Mode of channel modeling. Select from “att” and “tac”.
- Type: str
ch_att_dim
Dimension of the channel attention.
- Type: int
eps
Epsilon for layer normalization.
- Type: float
with_channel_modeling
Whether to use channel modeling.
- Type: bool
Parameters:
- input_size (int) – Dimension of the input feature.
- rnn_type (str) – Type of the RNN cell in the improved Transformer layer.
- hidden_size (int) – Hidden dimension of the RNN cell.
- att_heads (int) – Number of attention heads in Transformer.
- dropout (float) – Dropout ratio. Default is 0.
- activation (str) – Non-linear activation function applied in each block.
- bidirectional (bool) – Whether the RNN layers are bidirectional.
- norm_type (str) – Normalization type in the improved Transformer layer.
- ch_mode (str) – Mode of channel modeling. Select from “att” and “tac”.
- ch_att_dim (int) – Dimension of the channel attention.
- eps (float) – Epsilon for layer normalization.
- with_channel_modeling (bool) – Whether to use channel modeling.
########### Examples
>>> import torch
>>> atf_block = ATFBlock(input_size=64, hidden_size=128)
>>> input_tensor = torch.randn(32, 2, 64, 128, 256) # (batch, C, N, F, T)
>>> output_tensor = atf_block(input_tensor)
>>> output_tensor.shape
torch.Size([32, 2, 64, 128, 256])
- Returns: Output sequence (batch, C, N, freq, time).
- Return type: output (torch.Tensor)
- Raises: NotImplementedError – If an unsupported channel modeling mode is specified.
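The NotImplementedError above is raised when ch_mode is neither “att” nor “tac”. A minimal sketch of that dispatch (the function and returned names here are hypothetical placeholders, not the actual ESPnet internals):

```python
# Hypothetical sketch of ch_mode validation in ATFBlock.__init__.
# `select_channel_module` and the returned names are illustrative only.
def select_channel_module(ch_mode: str) -> str:
    if ch_mode == "att":
        return "channel-attention"  # attention across the channel dimension
    elif ch_mode == "tac":
        return "transform-average-concatenate"  # TAC-style channel fusion
    else:
        raise NotImplementedError(f"Unsupported channel modeling mode: {ch_mode}")
```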
forward(input, ref_channel=None)
Processes the input through the attentive time-frequency block.
- Parameters:
- input (torch.Tensor) – Input feature tensor of shape (batch, C, N, freq, time), where C is the number of channels and N equals input_size.
- ref_channel (None or int) – Index of the reference channel. If None, all channels are averaged. If an int, the specified channel is used instead of averaging.
- Returns: Output feature tensor of shape (batch, C, N, freq, time).
- Return type: output (torch.Tensor)
########### Examples
>>> block = ATFBlock(input_size=128)
>>> input_tensor = torch.randn(8, 2, 128, 10, 20) # (batch, C, N, freq, time)
>>> output = block(input_tensor, ref_channel=0)
>>> output.shape
torch.Size([8, 2, 128, 10, 20])
NOTE
When with_channel_modeling is False, the channel dimension is folded into the batch dimension and each channel is processed independently.
freq_path_process(x)
Processes the input tensor through the frequency path of the model.
This method reshapes and permutes the input tensor to fit the expected input shape for the frequency neural network, applies the frequency neural network, and then reshapes the output back to the original dimensions.
- Parameters:x (torch.Tensor) – Input tensor of shape (batch, N, freq, time), where batch is the batch size, N is the number of features, freq is the number of frequency bins, and time is the number of time steps.
- Returns: Output tensor of the same shape (batch, N, freq, time) as the input.
- Return type: torch.Tensor
########### Examples
>>> model = ATFBlock(input_size=128)
>>> input_tensor = torch.randn(32, 128, 64, 100) # (batch, N, freq, time)
>>> output_tensor = model.freq_path_process(input_tensor)
>>> output_tensor.shape
torch.Size([32, 128, 64, 100])
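The reshape-and-permute round trip described above can be illustrated in NumPy; the frequency network itself is elided (comment placeholder), so this sketch is shape-preserving by construction:

```python
import numpy as np

# Illustrative sketch of the freq_path_process reshaping (assumption: the real
# method runs its frequency network over `seq` between the two reshapes).
def freq_path_reshape(x: np.ndarray) -> np.ndarray:
    batch, N, freq, time = x.shape
    # fold time into batch; freq becomes the sequence axis: (batch*time, freq, N)
    seq = x.transpose(0, 3, 2, 1).reshape(batch * time, freq, N)
    # ... frequency network would process `seq` here ...
    # restore the original (batch, N, freq, time) layout
    return seq.reshape(batch, time, freq, N).transpose(0, 3, 2, 1)
```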
time_path_process(x)
Processes the input tensor through the temporal path.
This method takes the input tensor, permutes its dimensions to prepare it for processing through a temporal neural network, and reshapes it back to its original dimensions after processing.
- Parameters:x (torch.Tensor) – Input tensor with shape (batch, N, freq, time).
- Returns: Output tensor with the same shape as the input tensor, processed through the temporal neural network.
- Return type: torch.Tensor
########### Examples
>>> model = ATFBlock(input_size=64)
>>> input_tensor = torch.randn(32, 64, 10, 128) # (batch, N, freq, time)
>>> output_tensor = model.time_path_process(input_tensor)
>>> output_tensor.shape
torch.Size([32, 64, 10, 128])
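The temporal-path reshaping mirrors the frequency path, with time as the sequence axis. A NumPy sketch with the temporal network elided (so the round trip preserves the input):

```python
import numpy as np

# Illustrative sketch of the time_path_process reshaping (assumption: the real
# method runs its temporal network over `seq` between the two reshapes).
def time_path_reshape(x: np.ndarray) -> np.ndarray:
    batch, N, freq, time = x.shape
    # fold freq into batch; time becomes the sequence axis: (batch*freq, time, N)
    seq = x.transpose(0, 2, 3, 1).reshape(batch * freq, time, N)
    # ... temporal network would process `seq` here ...
    # restore the original (batch, N, freq, time) layout
    return seq.reshape(batch, freq, time, N).transpose(0, 3, 1, 2)
```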