espnet2.train.abs_espnet_model.AbsESPnetModel
class espnet2.train.abs_espnet_model.AbsESPnetModel(*args, **kwargs)
Bases: Module, ABC
The common abstract base class shared by all tasks.
AbsESPnetModel is an abstract base class that inherits from torch.nn.Module and serves as the blueprint for task-specific deep neural network models in the ESPnet framework. It follows a delegate pattern: a subclass holds the underlying network modules as members and delegates to them in its forward pass, and it defines the “loss”, “stats”, and “weight” values for the associated task.
When implementing a new task in ESPnet, it is essential to inherit from this class. The interaction between the training system and your task class is mediated through the loss, stats, and weight values.
Example
>>> import argparse
>>> from espnet2.tasks.abs_task import AbsTask
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel
>>> class YourESPnetModel(AbsESPnetModel):
...     def forward(self, input, input_lengths):
...         ...
...         return loss, stats, weight
>>> class YourTask(AbsTask):
...     @classmethod
...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
...         ...
- Parameters: None
- Returns: None
- Yields: None
- Raises: NotImplementedError – If the abstract methods are not implemented in a subclass.
NOTE
This class is meant to be subclassed, and cannot be instantiated directly.
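The note above can be illustrated with a minimal sketch. It assumes a torch-free stand-in, `AbsToyModel`, in place of AbsESPnetModel; the class names and the `can_instantiate` helper are hypothetical and exist only for this illustration. Python's `abc` module raises `TypeError` when a class that still has unimplemented abstract methods is instantiated, which is what the sketch checks.

```python
from abc import ABC, abstractmethod


class AbsToyModel(ABC):
    """Hypothetical stand-in for AbsESPnetModel (no torch dependency)."""

    @abstractmethod
    def forward(self, **batch):
        raise NotImplementedError

    @abstractmethod
    def collect_feats(self, **batch):
        raise NotImplementedError


class IncompleteModel(AbsToyModel):
    # Only one of the two abstract methods is overridden.
    def forward(self, **batch):
        return 0.0, {}, 1.0


class CompleteModel(AbsToyModel):
    # Both abstract methods are overridden, so instantiation is allowed.
    def forward(self, **batch):
        return 0.0, {}, 1.0

    def collect_feats(self, **batch):
        return dict(batch)


def can_instantiate(cls):
    """Return True if cls() succeeds, False if abc rejects it with TypeError."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

In this sketch only `CompleteModel` can be constructed; the abstract base and the partially implemented subclass are both rejected at instantiation time.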
Initialize internal Module state, shared by both nn.Module and ScriptModule.
abstract collect_feats(**batch: Tensor) → Dict[str, Tensor]
Collect features from the given batch of input tensors.
This method processes the input tensors and extracts relevant features that can be used for further computations or model training. The exact implementation of feature extraction will depend on the specific model derived from this abstract class.
- Parameters: **batch – Variable-length keyword arguments representing input tensors.
- Returns: A dictionary containing the extracted features, where keys are feature names and values are the corresponding tensors.
- Raises: NotImplementedError – If the method is not implemented in the derived class.
Example
>>> model = YourESPnetModel()
>>> features = model.collect_feats(input_tensor=input_data)
>>> print(features.keys()) # Should print the feature names
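As a rough illustration of the keyword-arguments-in, dictionary-out shape of collect_feats, here is a minimal sketch that avoids torch by letting plain lists play the role of tensors. The class `ToyFeatureModel`, the argument names `speech` and `speech_lengths`, and the returned keys `feats` and `feats_lengths` are assumptions chosen for this example; a real subclass decides its own inputs and feature names.

```python
from typing import Dict, List


class ToyFeatureModel:
    """Hypothetical torch-free stand-in; lists play the role of tensors."""

    def collect_feats(self, **batch: List) -> Dict[str, List]:
        speech = batch["speech"]            # padded sequences
        lengths = batch["speech_lengths"]   # valid length of each sequence
        # "Feature extraction" here is only stripping the padding.
        feats = [seq[:n] for seq, n in zip(speech, lengths)]
        return {"feats": feats, "feats_lengths": list(lengths)}


model = ToyFeatureModel()
features = model.collect_feats(
    speech=[[0.1, 0.2, 0.0], [0.3, 0.0, 0.0]],
    speech_lengths=[2, 1],
)
print(sorted(features.keys()))
```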
abstract forward(**batch: Tensor) → Tuple[Tensor, Dict[str, Tensor], Tensor]
The forward method that defines the computation performed at every call.
This method must be implemented by subclasses of AbsESPnetModel. It takes a variable number of tensor inputs and returns a tuple consisting of the computed loss, statistics, and weight for the task.
- Parameters: **batch – A variable number of keyword arguments containing tensors that represent the input data for the model.
- Returns:
  - A tensor representing the computed loss.
  - A dictionary containing statistics relevant to the task.
  - A tensor representing the weight associated with the loss.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: NotImplementedError – If the method is not implemented by the subclass.
Example
>>> class YourESPnetModel(AbsESPnetModel):
...     def forward(self, input, input_lengths):
...         # Implement your forward logic here
...         return loss, stats, weight
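To make the (loss, stats, weight) return contract more concrete, the following is a minimal torch-free sketch in which floats stand in for tensors. `ToyModel`, the batch keys `preds` and `targets`, the choice of batch size as the weight, and the `weighted_mean_loss` helper are all assumptions made for illustration. This page does not describe how the ESPnet training system aggregates these values, so the weighted average below is only one plausible way a caller might consume them.

```python
from typing import Dict, List, Tuple


class ToyModel:
    """Hypothetical torch-free stand-in; floats play the role of tensors."""

    def forward(self, **batch) -> Tuple[float, Dict[str, float], float]:
        preds, targets = batch["preds"], batch["targets"]
        # Mean squared error over the batch serves as the loss.
        loss = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
        stats = {"loss": loss}
        weight = float(len(preds))  # assumption: weight equals the batch size
        return loss, stats, weight


def weighted_mean_loss(model: ToyModel, batches: List[dict]) -> float:
    """Average the reported loss over batches, weighting each by `weight`."""
    total, total_weight = 0.0, 0.0
    for batch in batches:
        _, stats, weight = model.forward(**batch)
        total += stats["loss"] * weight
        total_weight += weight
    return total / total_weight


batches = [
    {"preds": [1.0, 2.0], "targets": [1.0, 0.0]},
    {"preds": [3.0], "targets": [1.0]},
]
result = weighted_mean_loss(ToyModel(), batches)
```

Weighting by batch size keeps a small final batch from counting as much as a full one when statistics are averaged across batches.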