espnet2.enh.layers.uses2_swin.USES2_Swin
class espnet2.enh.layers.uses2_swin.USES2_Swin(input_size, output_size, bottleneck_size=64, num_blocks=3, num_spatial_blocks=2, swin_block_depth=(4, 4, 4, 4), input_resolution=(130, 256), window_size=(10, 8), mlp_ratio=4, qkv_bias=True, qk_scale=None, att_heads=4, dropout=0.0, att_dropout=0.0, drop_path=0.0, activation='relu', use_checkpoint=False, ch_mode='att_tac', ch_att_dim=256, eps=1e-05)
Bases: Module
Unconstrained Speech Enhancement and Separation v2 (USES2-Swin) Network.
Reference:
[1] W. Zhang, J.-w. Jung, and Y. Qian, "Improving Design of Input Condition Invariant Speech Enhancement," in Proc. ICASSP, 2024.
[2] W. Zhang, K. Saijo, Z.-Q. Wang, S. Watanabe, and Y. Qian, "Toward Universal Speech Enhancement for Diverse Input Conditions," in Proc. ASRU, 2023.
- Parameters:
- input_size (int) – dimension of the input feature.
- output_size (int) – dimension of the output.
- bottleneck_size (int) – dimension of the bottleneck feature. Must be a multiple of att_heads.
- num_blocks (int) – number of ResSwinBlock blocks.
- num_spatial_blocks (int) – number of ResSwinBlock blocks with channel modeling.
- swin_block_depth (Tuple[int]) – depth of each ResSwinBlock.
- input_resolution (tuple) – frequency and time dimensions of the input feature. Only used for efficient training. Should be close to the actual spectrum size (F, T) of training samples.
- window_size (tuple) – size of the time-frequency window in Swin-Transformer.
- mlp_ratio (int) – ratio of the MLP hidden size to the embedding size in BasicLayer.
- qkv_bias (bool) – if True, add a learnable bias to query, key, and value in BasicLayer.
- qk_scale (float) – if set, override the default qk scale of head_dim ** -0.5 in BasicLayer.
- att_heads (int) – number of attention heads in Transformer.
- dropout (float) – dropout ratio in BasicLayer. Default is 0.
- att_dropout (float) – attention dropout ratio in BasicLayer.
- drop_path (float) – drop-path ratio in BasicLayer.
- activation (str) – non-linear activation function applied in each block.
- use_checkpoint (bool) – whether to use checkpointing to save memory.
- ch_mode (str) – mode of channel modeling. Select from "att", "tac", and "att_tac".
- ch_att_dim (int) – dimension of the channel attention.
- eps (float) – epsilon for layer normalization.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(input, ref_channel=None)
USES2-Swin forward.
- Parameters:
- input (torch.Tensor) – input feature (batch, mics, input_size, freq, time)
- ref_channel (None or int) – index of the reference channel.
- Returns: output feature (batch, output_size, freq, time)
- Return type: output (torch.Tensor)
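As a minimal shape sketch (not the actual ESPnet forward pass), the documented tensor layout can be checked with plain torch. The model instantiation appears only as a comment, and all sizes below are illustrative assumptions, not the class defaults:

```python
import torch

# Illustrative sizes (assumptions): 2 utterances, 3 microphones,
# feature dim 16, 32 frequency bins, 40 frames.
batch, mics, input_size, freq, time = 2, 3, 16, 32, 40

# Multichannel input in the documented layout:
# (batch, mics, input_size, freq, time).
x = torch.randn(batch, mics, input_size, freq, time)

# The actual model call would look like (hypothetical sizes):
#   model = USES2_Swin(input_size=16, output_size=16)
#   out = model(x, ref_channel=0)  # -> (batch, output_size, freq, time)
# Here we only mimic the channel reduction: indexing the reference channel
# collapses the mics axis, matching the documented output shape when
# output_size == input_size.
ref_channel = 0
out = x[:, ref_channel]

print(tuple(out.shape))  # (2, 16, 32, 40)
```

Note that the network itself maps input_size to output_size along the feature axis; only the (freq, time) dimensions are guaranteed to be preserved between input and output.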
