espnet2.enh.layers.dcunet.ComplexLinear
class espnet2.enh.layers.dcunet.ComplexLinear(input_dim, output_dim, complex_valued)
Bases: Module
A potentially complex-valued linear layer.
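This layer operates in complex-valued space when complex_valued=True; when complex_valued=False, it behaves as a standard linear layer. In complex-valued mode, it processes the real and imaginary parts of the input with two real-valued linear layers and combines the results according to the rules of complex multiplication, so the cross terms between the real and imaginary parts are preserved.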
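The complex-valued mode can be written out as follows. This assumes that re and im are ordinary nn.Linear layers with weights W_re, W_im and biases b_re, b_im, combined by the usual complex-product rule. The last line shows that the result is a single complex affine map with weight W_re + i W_im.

```latex
\begin{aligned}
x &= x_{\mathrm{re}} + i\,x_{\mathrm{im}}, \qquad
\mathrm{re}(z) = W_{\mathrm{re}} z + b_{\mathrm{re}}, \qquad
\mathrm{im}(z) = W_{\mathrm{im}} z + b_{\mathrm{im}} \\
y &= \bigl[\mathrm{re}(x_{\mathrm{re}}) - \mathrm{im}(x_{\mathrm{im}})\bigr]
   + i\,\bigl[\mathrm{re}(x_{\mathrm{im}}) + \mathrm{im}(x_{\mathrm{re}})\bigr] \\
  &= (W_{\mathrm{re}} + i\,W_{\mathrm{im}})\,x
   + (b_{\mathrm{re}} - b_{\mathrm{im}}) + i\,(b_{\mathrm{re}} + b_{\mathrm{im}})
\end{aligned}
```

Under this assumption, the effective complex bias is (b_re - b_im) + i(b_re + b_im), not simply b_re + i b_im.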
complex_valued
Indicates if the layer operates in complex-valued space.
- Type: bool
re
Linear layer holding the real part of the complex weight (present only when complex_valued=True).
- Type: nn.Linear
im
Linear layer holding the imaginary part of the complex weight (present only when complex_valued=True).
- Type: nn.Linear
lin
Linear layer used when operating in real-valued space (present only when complex_valued=False).
- Type: nn.Linear
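The attribute list implies that only one set of sublayers exists for a given flag. The following hypothetical, constructor-only sketch assumes that re, im and lin are plain nn.Linear layers with bias. ComplexLinearSketch and count_params are illustrative names, not ESPnet APIs. The sketch shows that the complex-valued mode doubles the parameter count.

```python
import torch.nn as nn


class ComplexLinearSketch(nn.Module):
    """Hypothetical constructor-only sketch mirroring the documented attributes."""

    def __init__(self, input_dim: int, output_dim: int, complex_valued: bool):
        super().__init__()
        self.complex_valued = complex_valued
        if complex_valued:
            # Two real-valued layers: real and imaginary parts of the complex weight.
            self.re = nn.Linear(input_dim, output_dim)
            self.im = nn.Linear(input_dim, output_dim)
        else:
            # One ordinary real-valued layer.
            self.lin = nn.Linear(input_dim, output_dim)


def count_params(module: nn.Module) -> int:
    """Total number of scalar parameters (weights plus biases)."""
    return sum(p.numel() for p in module.parameters())


cplx = ComplexLinearSketch(4, 2, complex_valued=True)
real = ComplexLinearSketch(4, 2, complex_valued=False)
print(count_params(cplx), count_params(real))  # 20 10
print(hasattr(cplx, "lin"), hasattr(real, "re"))  # False False
```

Each nn.Linear(4, 2) holds 4 * 2 weights plus 2 biases (10 parameters), so the complex-valued variant holds 20.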
Parameters:
- input_dim (int) – The number of input features.
- output_dim (int) – The number of output features.
- complex_valued (bool) – Flag to determine if the layer is complex-valued.
Returns: The transformed output tensor.
Return type: Tensor
####### Examples
>>> import torch
>>> from espnet2.enh.layers.dcunet import ComplexLinear
>>> # Example for complex-valued operation
>>> layer = ComplexLinear(input_dim=4, output_dim=2, complex_valued=True)
>>> input_tensor = torch.randn(10, 4) + 1j * torch.randn(10, 4)
>>> output_tensor = layer(input_tensor)
>>> print(output_tensor.shape) # Output: torch.Size([10, 2])
>>> # Example for real-valued operation
>>> layer = ComplexLinear(input_dim=4, output_dim=2, complex_valued=False)
>>> input_tensor = torch.randn(10, 4)
>>> output_tensor = layer(input_tensor)
>>> print(output_tensor.shape) # Output: torch.Size([10, 2])
forward(x)
Perform the forward pass of the ComplexLinear layer.
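This method computes the output of the linear layer. If complex_valued=True, it treats re and im as the real and imaginary parts of a complex weight and applies complex multiplication: the real part of the output is re(x.real) - im(x.imag), and the imaginary part is re(x.imag) + im(x.real). If complex_valued=False, it applies the standard linear transformation lin(x).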
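The following sketch checks the complex-multiplication rule numerically. It uses two stand-alone nn.Linear layers in place of re and im, which is an assumption about the internals and not a call into ESPnet. The helper complex_linear is a hypothetical name. The sketch compares the split real/imaginary computation against a single complex affine map.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_dim, out_dim = 3, 2

# Stand-ins for the `re` and `im` sublayers (assumed to be plain nn.Linear).
re = nn.Linear(in_dim, out_dim)
im = nn.Linear(in_dim, out_dim)


def complex_linear(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply (re + i*im) to x via the complex-product rule."""
    x_re, x_im = x.real, x.imag
    return torch.complex(re(x_re) - im(x_im), re(x_im) + im(x_re))


with torch.no_grad():
    x = torch.randn(4, in_dim, dtype=torch.cfloat)
    y = complex_linear(x)

    # Equivalent single complex affine map:
    # W = W_re + i W_im and b = (b_re - b_im) + i (b_re + b_im).
    W = torch.complex(re.weight, im.weight)
    b = torch.complex(re.bias - im.bias, re.bias + im.bias)
    y_ref = x @ W.T + b

print(y.shape, y.dtype)  # torch.Size([4, 2]) torch.complex64
print(torch.allclose(y, y_ref, atol=1e-6))  # True
```

The two-layer formulation keeps all parameters real-valued, which lets standard real-valued optimizers and initializers be used unchanged.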
- Parameters: x (Tensor) – Input tensor of shape (batch_size, input_dim). Additional leading dimensions are allowed, because nn.Linear acts on the last dimension. When complex_valued=True, x must have a complex dtype (e.g., torch.complex64), as in the examples below. When complex_valued=False, x is a real-valued tensor.
- Returns: The output tensor after applying the linear transformation, of shape (batch_size, output_dim). It is complex-valued when complex_valued=True and real-valued otherwise.
- Return type: Tensor
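Your features may be stored in a real tensor with a trailing real/imaginary axis of size 2. In that case, a minimal approach is to convert with torch.view_as_complex before the layer and torch.view_as_real after it. This assumes the complex-valued mode reads x.real and x.imag, as the complex-tensor examples suggest. The sketch also shows that a plain real tensor has no .imag, which is why complex mode needs a complex dtype.

```python
import torch

# Real/imaginary parts stacked in a trailing axis: (batch_size, input_dim, 2).
stacked = torch.randn(4, 3, 2)

# view_as_complex reinterprets each trailing pair as one complex number (no copy).
x = torch.view_as_complex(stacked)
print(x.shape, x.dtype)  # torch.Size([4, 3]) torch.complex64

# ... a complex-valued layer would be applied to x here ...

# view_as_real restores the stacked layout.
back = torch.view_as_real(x)
print(back.shape)  # torch.Size([4, 3, 2])

# A real-dtype tensor has no imaginary part, so .imag raises an error.
try:
    torch.randn(4, 3).imag
    imag_failed = False
except RuntimeError:
    imag_failed = True
print(imag_failed)  # True
```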
####### Examples
>>> import torch
>>> from espnet2.enh.layers.dcunet import ComplexLinear
>>> linear_layer = ComplexLinear(3, 2, complex_valued=True)
>>> input_tensor = torch.randn(4, 3) + 1j * torch.randn(4, 3)
>>> output = linear_layer(input_tensor)
>>> print(output.shape) # Should print torch.Size([4, 2])
>>> linear_layer_real = ComplexLinear(3, 2, complex_valued=False)
>>> input_tensor_real = torch.randn(4, 3)
>>> output_real = linear_layer_real(input_tensor_real)
>>> print(output_real.shape) # Should print torch.Size([4, 2])