espnet2.enh.layers.uses.ChannelAttention
class espnet2.enh.layers.uses.ChannelAttention(input_dim, att_heads=4, att_dim=256, activation='relu', eps=1e-05)
Bases: Module
Channel Attention module.
This module applies multi-head self-attention across the channel dimension C of a (batch, C, N, freq, time) feature tensor. Multiple attention heads compute the attention scores between channels, and the features are transformed by a linear layer followed by a non-linear activation function and layer normalization.
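The mechanism described above can be illustrated with a small NumPy sketch. It is not ESPnet's implementation, which may pool, scale, or project differently. The sketch treats the channel axis C as the "sequence" axis of standard scaled dot-product attention and computes it independently at every (batch, freq, time) position. It omits the activation and layer normalization the module applies. The function name and the plain weight matrices `w_q`, `w_k`, `w_v` are hypothetical stand-ins for the module's learned projections.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def channel_self_attention(
    x: np.ndarray, w_q: np.ndarray, w_k: np.ndarray, w_v: np.ndarray, heads: int
) -> np.ndarray:
    """Hypothetical multi-head self-attention over the channel axis C.

    x:        (B, C, N, F, T) input features
    w_q, w_k: (N, D) query/key projections (D plays the role of att_dim)
    w_v:      (N, N) value projection, so the output keeps N features
    D and N must both be divisible by `heads`.
    """
    B, C, N, F, T = x.shape
    # Reorder to (B, F, T, C, N) so C acts as the attention "sequence" axis.
    h = x.transpose(0, 3, 4, 1, 2)
    q, k, v = h @ w_q, h @ w_k, h @ w_v

    def split_heads(a: np.ndarray) -> np.ndarray:
        # (B, F, T, C, E) -> (B, F, T, heads, C, E // heads)
        *lead, c, e = a.shape
        return a.reshape(*lead, c, heads, e // heads).swapaxes(-3, -2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    # Scaled dot-product scores between every pair of channels: (..., C, C)
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(q.shape[-1])
    out = softmax(scores, axis=-1) @ v  # (B, F, T, heads, C, N // heads)
    # Merge the heads back into N features: (B, F, T, C, N)
    out = out.swapaxes(-3, -2).reshape(B, F, T, C, N)
    # Restore the original (B, C, N, F, T) layout.
    return out.transpose(0, 3, 4, 1, 2)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8, 5, 6))  # (B, C, N, F, T)
w_q, w_k = rng.standard_normal((8, 16)), rng.standard_normal((8, 16))
w_v = rng.standard_normal((8, 8))
y = channel_self_attention(x, w_q, w_k, w_v, heads=4)
print(y.shape)  # (2, 4, 8, 5, 6)
```

The sketch uses no positional information across channels. As a result, permuting the input channels permutes the output channels in the same way. This is a natural property for attending over an unordered set of channels.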
att_heads
Number of attention heads in self-attention.
- Type: int
att_dim
Projection dimension for query and key before self-attention.
- Type: int
activation
Non-linear activation function.
- Type: str
Parameters:
- input_dim (int) – Dimension of the input feature.
- att_heads (int) – Number of attention heads in self-attention.
- att_dim (int) – Projection dimension for query and key before self-attention.
- activation (str) – Non-linear activation function.
- eps (float) – Epsilon for layer normalization.
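The sketch below shows how `att_heads` interacts with `att_dim` and `input_dim`. It assumes the usual multi-head convention that the query/key projection (`att_dim`) and the value features (`input_dim`) are split evenly across the heads. The helper `per_head_dims` is hypothetical and not part of ESPnet.

```python
def per_head_dims(input_dim: int, att_dim: int, att_heads: int) -> tuple[int, int]:
    """Hypothetical helper: per-head (query/key, value) sizes.

    Assumes att_dim and input_dim are each split evenly across att_heads.
    """
    for name, dim in (("att_dim", att_dim), ("input_dim", input_dim)):
        if dim % att_heads != 0:
            raise ValueError(f"{name}={dim} is not divisible by att_heads={att_heads}")
    return att_dim // att_heads, input_dim // att_heads


# Default att_dim=256 with the input_dim=128, att_heads=4 used in the example below.
print(per_head_dims(input_dim=128, att_dim=256, att_heads=4))  # (64, 32)
```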
####### Examples
>>> import torch
>>> from espnet2.enh.layers.uses import ChannelAttention
>>> channel_attention = ChannelAttention(input_dim=128, att_heads=4)
>>> input_tensor = torch.randn(32, 8, 128, 64, 10) # (batch, C, N, F, T)
>>> output_tensor = channel_attention(input_tensor)
>>> output_tensor.shape
torch.Size([32, 8, 128, 64, 10])
- Returns: Output feature (batch, C, N, freq, time) after applying channel attention.
- Return type: output (torch.Tensor)
forward(x, ref_channel=None)
Applies multi-head self-attention across the channel dimension of the input.
- Parameters:
- x (torch.Tensor) – Input feature tensor of shape (batch, C, N, freq, time).
- ref_channel (None or int) – Index of the reference channel.
- Returns: Output feature tensor of shape (batch, C, N, freq, time).
- Return type: output (torch.Tensor)
####### Examples
>>> import torch
>>> from espnet2.enh.layers.uses import ChannelAttention
>>> layer = ChannelAttention(input_dim=64, att_heads=4)
>>> x = torch.randn(2, 4, 64, 33, 50)  # (batch, C, N, freq, time)
>>> layer(x).shape
torch.Size([2, 4, 64, 33, 50])
NOTE
The input must be a 5-D tensor of shape (batch, C, N, freq, time) whose feature dimension N equals input_dim.
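The shape requirement in the note can be checked before calling the module. The helper below is a hypothetical sketch and not an ESPnet function. It checks only the two conditions stated for the input: the tensor has five dimensions, and axis 2 (N) matches `input_dim`.

```python
from typing import Sequence


def check_channel_attention_input(shape: Sequence[int], input_dim: int) -> tuple[int, ...]:
    """Hypothetical helper: validate a (batch, C, N, freq, time) input shape."""
    shape = tuple(shape)
    if len(shape) != 5:
        raise ValueError(
            f"expected a 5-D input (batch, C, N, freq, time), got {len(shape)}-D"
        )
    if shape[2] != input_dim:
        raise ValueError(f"feature axis N={shape[2]} does not match input_dim={input_dim}")
    return shape


print(check_channel_attention_input((32, 8, 128, 64, 10), input_dim=128))
```

Because `torch.Size` is a subclass of `tuple`, `x.shape` from a PyTorch tensor can be passed to such a helper directly.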