espnet2.enh.diffusion.abs_diffusion.AbsDiffusion
class espnet2.enh.diffusion.abs_diffusion.AbsDiffusion(*args, **kwargs)
Bases: Module, ABC
Abstract base class for diffusion models used in audio enhancement.
This class defines the interface for diffusion models, providing methods for forward propagation and audio enhancement. Derived classes must implement the abstract methods defined in this class.
- Parameters: None
- Methods:
- forward(input: torch.Tensor, ilens: torch.Tensor): Abstract method for performing the forward pass.
- enhance(input: torch.Tensor): Abstract method for enhancing the audio input.
- Raises:
- NotImplementedError – If the base implementation of an abstract method is called (for example via super()) instead of being overridden in the derived class.
######### Examples
>>> import torch
>>> from espnet2.enh.diffusion.abs_diffusion import AbsDiffusion
>>> class MyDiffusionModel(AbsDiffusion):
...     def forward(self, input: torch.Tensor, ilens: torch.Tensor):
...         pass  # Implementation of the forward method
...     def enhance(self, input: torch.Tensor):
...         pass  # Implementation of the enhance method
...
>>> model = MyDiffusionModel()
>>> output = model(torch.randn(1, 16000), torch.tensor([16000]))  # nn.Module.__call__ dispatches to forward()
>>> enhanced_output = model.enhance(torch.randn(1, 16000))
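####### NOTE This class should not be instantiated directly. Because it derives from ABC and declares forward() and enhance() as abstract methods, calling AbsDiffusion() raises TypeError, as does instantiating a subclass that does not override both methods.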
The constructor is inherited from torch.nn.Module: it initializes internal Module state, shared by both nn.Module and ScriptModule.
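A minimal sketch of the subclassing contract described above. So that it runs without ESPnet installed, it defines a local stand-in for AbsDiffusion that mirrors the documented interface (assuming, as the Raises entries state, that both abstract methods simply raise NotImplementedError). ForwardOnly, Complete, and can_instantiate are hypothetical names used only for illustration.

```python
import torch
from abc import ABC, abstractmethod


# Local stand-in mirroring the documented interface (assumption); in real code,
# import AbsDiffusion from espnet2.enh.diffusion.abs_diffusion instead.
class AbsDiffusion(torch.nn.Module, ABC):
    @abstractmethod
    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        raise NotImplementedError

    @abstractmethod
    def enhance(self, input: torch.Tensor):
        raise NotImplementedError


class ForwardOnly(AbsDiffusion):
    """Hypothetical subclass that forgets to override enhance()."""

    def forward(self, input, ilens):
        return input, ilens


class Complete(AbsDiffusion):
    """Hypothetical subclass that overrides both abstract methods."""

    def forward(self, input, ilens):
        return input, ilens

    def enhance(self, input):
        return input


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if ABC rejects it with TypeError."""
    try:
        cls()
    except TypeError:
        # ABCMeta refuses to instantiate classes with unimplemented abstract methods.
        return False
    return True


print(can_instantiate(AbsDiffusion))  # False
print(can_instantiate(ForwardOnly))   # False
print(can_instantiate(Complete))      # True
```

The TypeError is raised at construction time, so a missing override is caught before forward() or enhance() is ever called.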
abstract enhance(input: Tensor)
Enhances the input tensor using the model's specific enhancement method.
Subclasses must override this method to provide their enhancement procedure.
- Parameters:
- input (torch.Tensor) – The input tensor to be enhanced.
- Raises:
- NotImplementedError – If the method is not overridden in a derived class.
######### Examples
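>>> import torch
>>> from espnet2.enh.diffusion.abs_diffusion import AbsDiffusion
>>> class MyDiffusionModel(AbsDiffusion):
...     def forward(self, input, ilens):
...         pass  # Implementation of the forward method
...     def enhance(self, input):
...         pass  # Implementation of the enhance method
...
>>> model = MyDiffusionModel()
>>> enhanced_output = model.enhance(torch.randn(1, 3, 64, 64))  # input shape depends on the concrete model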
####### NOTE This method must be overridden in any subclass of AbsDiffusion.
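A hedged sketch of calling enhance() at inference time. ToyDiffusion and its gain parameter are hypothetical, and its "enhancement" is only a placeholder scaling, not a diffusion sampler; the base class is a local stand-in mirroring the documented interface so the snippet runs without ESPnet. Switching to eval() and disabling gradient tracking with torch.no_grad() is common PyTorch practice, not something AbsDiffusion itself requires.

```python
import torch
from abc import ABC, abstractmethod


class AbsDiffusion(torch.nn.Module, ABC):
    """Local stand-in mirroring the documented interface (assumption)."""

    @abstractmethod
    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        raise NotImplementedError

    @abstractmethod
    def enhance(self, input: torch.Tensor):
        raise NotImplementedError


class ToyDiffusion(AbsDiffusion):
    """Hypothetical subclass: a single learnable gain stands in for a real model."""

    def __init__(self):
        super().__init__()
        self.gain = torch.nn.Parameter(torch.tensor(0.5))

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        # Placeholder training-time pass: scaled input plus unchanged lengths.
        return self.gain * input, ilens

    def enhance(self, input: torch.Tensor):
        # Placeholder "enhancement": scale the input by the learned gain.
        return self.gain * input


model = ToyDiffusion()
model.eval()  # switch off training-only behavior (e.g. dropout) in submodules
with torch.no_grad():  # no autograd graph is built at inference time
    enhanced = model.enhance(torch.randn(1, 16000))
print(enhanced.shape, enhanced.requires_grad)  # torch.Size([1, 16000]) False
```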
abstract forward(input: Tensor, ilens: Tensor)
Computes the forward pass of the diffusion model.
This method takes a batch of input tensors together with their sequence lengths and returns the output of the diffusion model.
- Parameters:
- input (torch.Tensor) – The input tensor containing data to be processed.
- ilens (torch.Tensor) – A tensor representing the lengths of each input sequence in the batch.
- Returns: The output tensor after applying the diffusion process to the input.
- Return type: torch.Tensor
- Raises: NotImplementedError – If the method is not implemented in a derived class.
######### Examples
>>> import torch
>>> model = MyDiffusionModel()  # MyDiffusionModel is a user-defined subclass of AbsDiffusion
>>> input_data = torch.randn(32, 10) # Example input tensor
>>> input_lengths = torch.tensor([10] * 32) # Lengths of each sequence
>>> output = model.forward(input_data, input_lengths)
>>> print(output.shape) # Expected output shape depends on the model
####### NOTE This method must be overridden in any subclass of AbsDiffusion.
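The interface leaves it to each subclass to decide how ilens is used. One common approach, sketched here as an assumption rather than ESPnet's own implementation, is to turn the lengths into a boolean padding mask so that padded positions do not contribute to a training loss. lengths_to_mask and masked_mse are hypothetical helpers written for illustration.

```python
import torch


def lengths_to_mask(ilens: torch.Tensor, max_len: int) -> torch.Tensor:
    """Hypothetical helper: boolean mask that is True at valid (non-padded) positions.

    ilens has shape (B,); the returned mask has shape (B, max_len).
    """
    positions = torch.arange(max_len, device=ilens.device)  # (max_len,)
    return positions.unsqueeze(0) < ilens.unsqueeze(1)  # broadcasts to (B, max_len)


def masked_mse(estimate: torch.Tensor, target: torch.Tensor, ilens: torch.Tensor) -> torch.Tensor:
    """Hypothetical loss: mean squared error over valid positions only.

    estimate and target have shape (B, T); padded positions are ignored.
    """
    mask = lengths_to_mask(ilens, estimate.size(-1)).to(estimate.dtype)
    squared_error = (estimate - target) ** 2 * mask
    return squared_error.sum() / mask.sum()


ilens = torch.tensor([4, 2])
print(lengths_to_mask(ilens, 4).tolist())
# [[True, True, True, True], [True, True, False, False]]

estimate = torch.zeros(2, 4)
# The last two entries of the second row are padding and hold junk values.
target = torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 100.0, 100.0]])
print(masked_mse(estimate, target, ilens).item())  # 1.0
```

Without the mask, the two junk padding values would dominate the average; with it, only the six valid positions count.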