espnet2.enh.layers.dcunet.ArgsComplexMultiplicationWrapper
class espnet2.enh.layers.dcunet.ArgsComplexMultiplicationWrapper(module_cls, *args, **kwargs)
Bases: Module
Adapted from asteroid’s complex_nn.py, allowing args/kwargs to be passed through forward().
Make a complex-valued module F from a real-valued module f by applying complex multiplication rules:
F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))
where f1, f2 are instances of f that do not share weights.
- Parameters: module_cls (callable) – A class or function that returns a Torch module/functional. Constructor of f in the formula above. Called twice with *args, **kwargs to construct the real and imaginary component modules.
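The complex multiplication rule can be illustrated without any framework. The sketch below is not the library's code: `ScalarGain` and `complex_apply` are hypothetical names, plain Python callables stand in for Torch modules, and the pairing of terms (real part f1(a) - f2(b), imaginary part f1(b) + f2(a)) is assumed from the ordinary complex product. With scalar gains w1 and w2 the result should match multiplying the input by w1 + i w2.

```python
class ScalarGain:
    """Hypothetical stand-in for a real-valued module: multiplies by a weight."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return self.weight * x


def complex_apply(f1, f2, z):
    """Apply real-valued callables f1, f2 to z = a + i b using the
    complex product rule (assumed pairing, see the text above)."""
    a, b = z.real, z.imag
    real = f1(a) - f2(b)  # real part: f1 on a minus f2 on b
    imag = f1(b) + f2(a)  # imaginary part: f1 on b plus f2 on a
    return complex(real, imag)


f1 = ScalarGain(2.0)   # plays the role of the real component
f2 = ScalarGain(-3.0)  # plays the role of the imaginary component
z = complex(1.5, 0.5)

out = complex_apply(f1, f2, z)
expected = complex(2.0, -3.0) * z  # ordinary complex multiplication
print(out, expected)
```

Because the two components are separate objects with their own weights, the pair behaves like a single complex-valued weight; sharing one object for both roles would lose that degree of freedom.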
####### Examples
>>> import torch
>>> from torch import nn
>>> from espnet2.enh.layers.dcunet import ArgsComplexMultiplicationWrapper
>>> wrapper = ArgsComplexMultiplicationWrapper(nn.Linear, 10, 5)
>>> input_tensor = torch.randn(2, 10) + 1j * torch.randn(2, 10)
>>> output = wrapper(input_tensor)
>>> print(output.shape)  # torch.Size([2, 5]) for nn.Linear(10, 5)
forward(x, *args, **kwargs)
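Applies the wrapped real-valued modules to a complex-valued input x using complex multiplication rules. Any extra args/kwargs are passed through to both component modules.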
This module adapts a real-valued function f into a complex-valued module F by applying complex multiplication rules defined as follows:
F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))
where f1 and f2 are instances of f that do not share weights.
- Parameters: module_cls (callable) – A class or function that returns a Torch module/functional. Constructor of f in the formula above. Called twice with *args, **kwargs to construct the real and imaginary component modules.
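re_module
The real component module (f1 in the formula above), built by calling module_cls(*args, **kwargs).
- Type: nn.Module
im_module
The imaginary component module (f2 in the formula above), built by a second, independent call to module_cls(*args, **kwargs).
- Type: nn.Module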
####### Examples
>>> import torch
>>> from torch import nn
>>> from espnet2.enh.layers.dcunet import ArgsComplexMultiplicationWrapper
>>> complex_module = ArgsComplexMultiplicationWrapper(nn.Linear, 10, 5)
>>> input_tensor = torch.randn(2, 10) + 1j * torch.randn(2, 10)
>>> output = complex_module(input_tensor)
>>> print(output.shape)  # torch.Size([2, 5]) for nn.Linear(10, 5)
- Raises: NotImplementedError – If the wrapped module does not support complex operations.
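How forward() passes extra arguments through can be sketched in plain Python. This is an illustration under stated assumptions, not the library implementation: `ArgsWrapperSketch` and `ScaledGain` are hypothetical names, plain callables replace nn.Module instances, and the term pairing follows the ordinary complex product. What it shows is that the constructor is called twice and that every component call receives the same extra `*args, **kwargs`.

```python
class ArgsWrapperSketch:
    """Framework-free sketch: builds two components and forwards extra arguments."""

    def __init__(self, module_cls, *args, **kwargs):
        # The constructor is called twice, so the components do not share state.
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x, *args, **kwargs):
        a, b = x.real, x.imag
        # Each component call receives the same extra arguments.
        real = self.re_module(a, *args, **kwargs) - self.im_module(b, *args, **kwargs)
        imag = self.re_module(b, *args, **kwargs) + self.im_module(a, *args, **kwargs)
        return complex(real, imag)


class ScaledGain:
    """Hypothetical real-valued component whose call takes an extra `scale` argument."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x, scale=1.0):
        return self.weight * scale * x


wrapper = ArgsWrapperSketch(ScaledGain, 2.0)
result = wrapper.forward(complex(1.0, 3.0), scale=0.5)
print(result)
```

In this sketch `scale=0.5` reaches all four component calls, so both effective gains become 1.0 and the result equals (1 + 1j) * (1 + 3j).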