espnet2.enh.layers.tcn.ChannelwiseLayerNorm
class espnet2.enh.layers.tcn.ChannelwiseLayerNorm(channel_size, shape='BDT')
Bases: Module
Channel-wise Layer Normalization (cLN).
This module applies layer normalization across the channel dimension, computing the mean and variance over the channels independently at each time step of each instance in the batch. This helps stabilize training and can improve convergence.
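A minimal sketch of the underlying computation (the affine parameters gamma and beta are omitted; EPS here is an assumed stability constant, not necessarily the value used in the module):

>>> import torch
>>> EPS = 1e-8  # assumed stability constant
>>> y = torch.randn(2, 4, 5)                          # [M, N, K]
>>> mean = y.mean(dim=1, keepdim=True)                # [M, 1, K]: stats over channels
>>> var = y.var(dim=1, unbiased=False, keepdim=True)  # [M, 1, K]
>>> y_norm = (y - mean) / (var + EPS) ** 0.5          # standardized per (batch, time)
>>> y_norm.shape
torch.Size([2, 4, 5])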
gamma
Scale parameter for normalization.
beta
Shift parameter for normalization.
shape
Specifies the input shape format. It can be either “BDT” (Batch, Dimension/channels, Time) or “BTD” (Batch, Time, Dimension/channels).
- Parameters:
- channel_size (int) – Number of channels for normalization.
- shape (str, optional) – Input shape format. Default is “BDT”. Acceptable values are “BDT” and “BTD” (see the sketch below).
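For the “BTD” layout the channels sit in the last dimension; a sketch assuming the module returns its output in the same layout as the input:

>>> ln_btd = ChannelwiseLayerNorm(channel_size=64, shape="BTD")
>>> x = torch.randn(10, 100, 64)  # [M, K, N]: time-major input
>>> ln_btd(x).shape
torch.Size([10, 100, 64])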
reset_parameters()
Resets the parameters gamma and beta.
forward()
Applies the channel-wise layer normalization.
######### Examples
>>> layer_norm = ChannelwiseLayerNorm(channel_size=64)
>>> input_tensor = torch.randn(10, 64, 100) # [M, N, K]
>>> output_tensor = layer_norm(input_tensor)
>>> output_tensor.shape
torch.Size([10, 64, 100])
NOTE
The input tensor should have three dimensions: [M, N, K], where M is the batch size, N is the number of channels, and K is the length of the sequence.
- Raises: AssertionError – If the input tensor does not have 3 dimensions.
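For example, a 2-D input is rejected:

>>> layer_norm(torch.randn(10, 64))  # missing the time dimension
Traceback (most recent call last):
    ...
AssertionError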
forward(y)
Forward pass of the ChannelwiseLayerNorm.
This method computes the mean and variance across the channel dimension at each time step, standardizes the input with them, and applies the learnable scale gamma and shift beta.
- Parameters: y (torch.Tensor) – A tensor of shape [M, N, K] (or [M, K, N] when shape="BTD"), where M is the batch size, N is the number of channels, and K is the length of the sequence.
- Returns: cLN_y – The normalized tensor, with the same shape as the input.
- Return type: torch.Tensor
- Raises: AssertionError – If the input tensor does not have 3 dimensions.
######### Examples
>>> layer_norm = ChannelwiseLayerNorm(channel_size=64)
>>> y = torch.randn(10, 64, 100)  # Batch of 10, 64 channels, length 100
>>> out = layer_norm(y)
>>> out.shape
torch.Size([10, 64, 100])
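Since gamma and beta start at their defaults of 1 and 0 (see reset_parameters below), the output is approximately standardized across channels at every time step:

>>> out.mean(dim=1).abs().max() < 1e-4  # per-(batch, time) mean ~ 0
tensor(True)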
reset_parameters()
Reset the parameters of the ChannelwiseLayerNorm.
This method initializes the learnable parameters gamma and beta of the channel-wise layer normalization to their default values: gamma is set to 1 and beta is set to 0. It is called when an instance of the class is created and can be invoked again to re-initialize the parameters before training from scratch.
gamma
Scaling parameter for the normalization.
- Type: torch.Tensor
beta
Shifting parameter for the normalization.
- Type: torch.Tensor
######### Examples
>>> layer_norm = ChannelwiseLayerNorm(channel_size=10)
>>> bool(torch.all(layer_norm.gamma == 1))  # gamma initially filled with 1
True
>>> bool(torch.all(layer_norm.beta == 0))   # beta initially filled with 0
True
NOTE
It is important to reset the parameters when the model is being re-initialized or if you want to start training from scratch.
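For example, to restore the defaults on an existing instance:

>>> layer_norm.reset_parameters()
>>> bool(torch.all(layer_norm.gamma == 1)), bool(torch.all(layer_norm.beta == 0))
(True, True)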