espnet2.enh.layers.ncsnpp_utils.layers.NIN
class espnet2.enh.layers.ncsnpp_utils.layers.NIN(in_dim, num_units, init_scale=0.1)
Bases: Module
NIN (Network in Network) layer.
This layer applies a learnable linear map (x · W + b) over the channel dimension independently at every spatial position. This is equivalent to a 1x1 convolution, the building block popularized by the Network in Network architecture.
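A per-position linear map over channels is mathematically the same as a 1x1 convolution. The sketch below checks this numerically with plain PyTorch. The tensors `W` and `b` are random stand-ins laid out like the layer's parameters (`W` as `(in_dim, num_units)`, `b` as `(num_units,)`); they are not taken from an actual `NIN` instance.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_dim, num_units = 4, 6
x = torch.randn(2, in_dim, 5, 5)    # (batch, channels, height, width)
W = torch.randn(in_dim, num_units)  # stand-in laid out like NIN.W: (in_dim, num_units)
b = torch.randn(num_units)          # stand-in laid out like NIN.b: (num_units,)

# Channels-last contraction with W, plus bias, then back to channels-first.
y_nin = (torch.einsum("bhwc,cu->bhwu", x.permute(0, 2, 3, 1), W) + b).permute(0, 3, 1, 2)

# A 1x1 convolution expects a weight of shape (out_channels, in_channels, 1, 1).
y_conv = F.conv2d(x, W.t().reshape(num_units, in_dim, 1, 1), b)

print(y_nin.shape, torch.allclose(y_nin, y_conv, atol=1e-5))
```

The conv weight is just the transpose of `W` with two singleton spatial axes appended, which is why the two computations agree.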
W
Weight matrix of shape (in_dim, num_units), initialized by the module's default variance-scaling initializer with scale init_scale.
- Type: torch.nn.Parameter
b
Bias vector of shape (num_units,), initialized to zero.
- Type: torch.nn.Parameter
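The exact initializer details are not given here. Assuming the weight initializer follows the score-SDE convention of variance scaling with `fan_avg` and a uniform distribution (an assumption, not confirmed by this page), weights would be drawn from U(-a, a) with a = sqrt(6 * init_scale / (in_dim + num_units)). The helpers `variance_scaling_uniform` and `make_nin_params` are hypothetical and only sketch that scheme.

```python
import math
import torch

def variance_scaling_uniform(shape, scale):
    """Hypothetical helper: uniform variance scaling with fan_avg (assumed, not the ESPnet code)."""
    fan_in, fan_out = shape[1], shape[0]  # with fan_avg the axis assignment does not matter
    variance = scale / ((fan_in + fan_out) / 2)
    bound = math.sqrt(3 * variance)       # Uniform(-a, a) has variance a**2 / 3
    return (torch.rand(*shape) * 2.0 - 1.0) * bound

def make_nin_params(in_dim, num_units, init_scale=0.1):
    """Hypothetical: build W and b as the attribute docs describe."""
    scale = 1e-10 if init_scale == 0 else init_scale  # assumed guard so the scale is never exactly 0
    W = torch.nn.Parameter(variance_scaling_uniform((in_dim, num_units), scale))
    b = torch.nn.Parameter(torch.zeros(num_units))
    return W, b

torch.manual_seed(0)
W, b = make_nin_params(64, 128)
bound = math.sqrt(6 * 0.1 / (64 + 128))
print(W.shape, b.shape, W.abs().max().item() <= bound + 1e-6)
```

A small init_scale such as the default 0.1 keeps the initial output small, which is a common choice for layers inside residual score networks.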
Parameters:
- in_dim (int) – The number of input channels (size of dimension 1 of the input).
- num_units (int) – The number of output units (channels).
- init_scale (float) – The initial scale for weight initialization. Default is 0.1.
Returns: Output tensor of shape (batch_size, num_units, height, width).
Return type: torch.Tensor
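As a quick worked check of the parameter shapes: `W` holds in_dim * num_units values and `b` holds num_units, so the layer has in_dim * num_units + num_units trainable parameters, the same count as a 1x1 `torch.nn.Conv2d` with a bias. `nin_param_count` is a hypothetical helper.

```python
import torch

def nin_param_count(in_dim, num_units):
    # W contributes in_dim * num_units weights; b adds one bias per output unit.
    return in_dim * num_units + num_units

# Compare with a 1x1 convolution that has the same input/output channels.
conv = torch.nn.Conv2d(64, 128, kernel_size=1)
conv_count = sum(p.numel() for p in conv.parameters())

print(nin_param_count(64, 128), conv_count)  # 8320 8320
```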
Examples:
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import NIN
>>> nin_layer = NIN(in_dim=64, num_units=128)
>>> input_tensor = torch.randn(32, 64, 8, 8)  # (batch_size, channels, height, width)
>>> output_tensor = nin_layer(input_tensor)
>>> output_tensor.shape
torch.Size([32, 128, 8, 8])
NOTE
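The input is permuted from (batch_size, channels, height, width) to (batch_size, height, width, channels) so that the channel axis can be contracted with W. The result is then permuted back to channels-first layout before being returned.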
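The permute, contract, and permute-back sequence described in the note can be sketched as a small module. `NINSketch` is a hypothetical illustration, not the ESPnet source; in particular, its weight initialization is simplified to a scaled normal draw rather than the layer's own initializer.

```python
import torch
from torch import nn

class NINSketch(nn.Module):
    """Hypothetical re-implementation for illustration; not the ESPnet source."""

    def __init__(self, in_dim, num_units):
        super().__init__()
        # Simplified initialization (assumption): a scaled normal draw.
        self.W = nn.Parameter(torch.randn(in_dim, num_units) * 0.1)
        self.b = nn.Parameter(torch.zeros(num_units))

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)                               # (B, C, H, W) -> (B, H, W, C)
        y = torch.einsum("bhwc,cu->bhwu", x, self.W) + self.b  # contract channels; bias broadcasts
        return y.permute(0, 3, 1, 2)                            # (B, H, W, U) -> (B, U, H, W)

layer = NINSketch(64, 128)
out = layer(torch.randn(32, 64, 8, 8))
print(out.shape)  # torch.Size([32, 128, 8, 8])
```

Moving channels last lets the bias of shape (num_units,) broadcast over the trailing axis without any reshaping.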
forward(x)
Apply the NIN layer to an input tensor.
The input x is expected to have shape (batch_size, in_dim, height, width). It is permuted to (batch_size, height, width, in_dim), contracted with the learnable weight W over the channel axis, offset by the bias b, and permuted back, giving an output of shape (batch_size, num_units, height, width).
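Put differently, the forward pass applies the same affine map to the channel vector at every (height, width) position, so it matches `torch.nn.functional.linear` applied to the channels-last view (`F.linear` expects a weight of shape `(out_features, in_features)`, hence the transpose). The sketch also shows that, because the last step is a permute, the result is a non-contiguous view, so a downstream `.view()` would need `.contiguous()` first. `W` and `b` are random stand-ins laid out like the layer's parameters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 5, 5)  # (batch_size, in_dim, height, width)
W = torch.randn(3, 6)        # stand-in for NIN.W: (in_dim, num_units)
b = torch.randn(6)           # stand-in for NIN.b: (num_units,)

x_last = x.permute(0, 2, 3, 1)  # (batch_size, height, width, in_dim)

# Matrix product over the channel axis, then bias, then back to channels-first.
y = (x_last @ W + b).permute(0, 3, 1, 2)

# The same affine map expressed with F.linear, which takes weight as (out, in).
y_linear = F.linear(x_last, W.t(), b).permute(0, 3, 1, 2)

print(torch.allclose(y, y_linear, atol=1e-5), y.is_contiguous())  # True False
```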
Parameters:
- x (torch.Tensor) – Input tensor of shape (batch_size, in_dim, height, width).
Returns: Output tensor of shape (batch_size, num_units, height, width).
Return type: torch.Tensor
Examples:
>>> import torch
>>> from espnet2.enh.layers.ncsnpp_utils.layers import NIN
>>> model = NIN(in_dim=3, num_units=10)
>>> input_tensor = torch.randn(4, 3, 32, 32)  # batch of 4 three-channel 32x32 images
>>> output_tensor = model(input_tensor)
>>> output_tensor.shape
torch.Size([4, 10, 32, 32])
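The forward contract can be summarized as (batch_size, in_dim, H, W) -> (batch_size, num_units, H, W): only the channel axis changes, and the spatial sizes need not be square. The hypothetical `nin_forward` function below mirrors that computation and adds an explicit channel check with a readable error. This validation is an addition of the sketch, not documented behavior of the layer.

```python
import torch

def nin_forward(x, W, b):
    """Hypothetical functional form of the NIN forward pass, with an added shape check."""
    if x.dim() != 4 or x.shape[1] != W.shape[0]:
        raise ValueError(
            f"expected input of shape (B, {W.shape[0]}, H, W), got {tuple(x.shape)}"
        )
    y = torch.einsum("bhwc,cu->bhwu", x.permute(0, 2, 3, 1), W) + b
    return y.permute(0, 3, 1, 2)

W, b = torch.randn(3, 10), torch.zeros(10)
print(nin_forward(torch.randn(4, 3, 32, 17), W, b).shape)  # torch.Size([4, 10, 32, 17])

try:
    nin_forward(torch.randn(4, 5, 32, 32), W, b)  # 5 channels, but W expects 3
except ValueError as err:
    print(err)
```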