espnet2.train.spk_trainer.SpkTrainer
class espnet2.train.spk_trainer.SpkTrainer
Bases: Trainer
Trainer module for speaker recognition.
In speaker recognition (embedding extractor training/inference), calculating validation loss in a closed set is not informative since generalization to unseen utterances from known speakers is often good. Thus, we measure the open set equal error rate (EER) using unknown speakers by overriding the validate_one_epoch method.
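To make the validation metric concrete, here is a minimal NumPy sketch of how an EER can be computed from a list of scored verification trials. It is an illustration only, not ESPnet's implementation. The function name `compute_eer` and the label convention (1 = same-speaker trial, 0 = different-speaker trial) are assumptions.

```python
import numpy as np


def compute_eer(scores, labels):
    """Equal error rate for verification trials (illustrative sketch).

    scores: higher means "more likely the same speaker".
    labels: 1 for target (same-speaker) trials, 0 for non-target trials.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n_target = labels.sum()
    n_nontarget = labels.size - n_target

    # Sort trials by descending score; accepting the top-k trials corresponds
    # to a threshold placed just below the k-th highest score.
    sorted_labels = labels[np.argsort(-scores, kind="stable")]

    # For k = 0..N accepted trials:
    #   false acceptances = non-target trials among the accepted ones
    #   false rejections  = target trials not yet accepted
    false_accepts = np.concatenate([[0], np.cumsum(1 - sorted_labels)])
    false_rejects = n_target - np.concatenate([[0], np.cumsum(sorted_labels)])
    far = false_accepts / n_nontarget
    frr = false_rejects / n_target

    # The EER is where the two error curves cross; on a finite trial list we
    # take the operating point with the smallest gap and average the two rates.
    idx = np.argmin(np.abs(far - frr))
    return float((far[idx] + frr[idx]) / 2)
```

For example, compute_eer([0.9, 0.8, 0.3, 0.2], [1, 1, 0, 0]) returns 0.0, because every same-speaker trial scores higher than every different-speaker trial.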
validate_one_epoch()
Validates the model for one epoch and computes EER.
extract_embed()
Extracts speaker embeddings from the provided iterator.
######### Examples
Creating an instance of SpkTrainer is not allowed and raises an error:
>>> trainer = SpkTrainer()  # Raises RuntimeError
Validate one epoch (both methods are classmethods, so call them on the class):
>>> SpkTrainer.validate_one_epoch(model, iterator, reporter, options,
...                               distributed_option)
Extract embeddings:
>>> SpkTrainer.extract_embed(model, iterator, reporter, options,
...                          distributed_option, output_dir, custom_bs, average)
####### NOTE This class is not meant to be instantiated. Its methods are classmethods and are called directly on the class, e.g. SpkTrainer.validate_one_epoch(...).
classmethod extract_embed(model: Module, iterator: Iterable[Dict[str, Tensor]], reporter: SubReporter, options: TrainerOptions, distributed_option: DistributedOption, output_dir: str, custom_bs: int, average: bool = False) → None
Extract speaker embeddings from the model and save them to a file.
This method processes audio data in batches, extracts speaker embeddings using the provided model, and saves the embeddings to a specified output directory. The extraction can be done in a distributed manner if configured.
Parameters:
- model (torch.nn.Module) – The neural network model used to extract embeddings.
- iterator (Iterable[Dict[str, torch.Tensor]]) – An iterable that yields batches of input audio data.
- reporter (SubReporter) – An instance for reporting statistics during training or evaluation.
- options (TrainerOptions) – Configuration options for the trainer.
- distributed_option (DistributedOption) – Options for distributed training.
- output_dir (str) – The directory where the extracted embeddings are saved.
- custom_bs (int) – The batch size used when processing the input data.
- average (bool, default False) – If True, averages the embeddings over the time axis for each utterance.
Returns: This method does not return any value.
Return type: None
######### Examples
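>>> model = YourModelClass()  # placeholder: a trained speaker model
>>> iterator = your_data_iterator  # placeholder: the data iterator
>>> # SubReporter and TrainerOptions cannot be built without arguments: a
>>> # SubReporter comes from Reporter.observe(...), options from build_options.
>>> reporter = your_sub_reporter
>>> options = SpkTrainer.build_options(args)  # args: parsed argparse.Namespace
>>> distributed_option = DistributedOption()
>>> output_dir = "path/to/output"
>>> custom_bs = 32
>>> SpkTrainer.extract_embed(model, iterator, reporter, options,
...                          distributed_option, output_dir, custom_bs)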
####### NOTE Ensure that the output directory exists and is writable. The embeddings will be saved in a compressed .npz format.
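The note above says the embeddings are written as a compressed .npz archive. As an illustration of working with that format, the sketch below round-trips a dictionary of embeddings with np.savez_compressed and np.load. The file name embeddings.npz, the utterance IDs and the 192-dimensional random vectors are hypothetical and do not describe the exact files extract_embed writes.

```python
import os
import tempfile

import numpy as np

# Hypothetical embeddings keyed by utterance ID (random values for illustration).
rng = np.random.default_rng(0)
embeddings = {f"utt{i}": rng.standard_normal(192).astype(np.float32) for i in range(3)}

with tempfile.TemporaryDirectory() as output_dir:
    # Store one array per utterance ID; "embeddings.npz" is an assumed file name.
    path = os.path.join(output_dir, "embeddings.npz")
    np.savez_compressed(path, **embeddings)

    # Load it back: the NpzFile behaves like a read-only mapping of arrays.
    with np.load(path) as loaded:
        restored = {key: loaded[key] for key in loaded.files}

print(sorted(restored), restored["utt0"].shape)
```

Keying the archive by utterance ID keeps the lookup needed for trial scoring (embedding of utterance X) a plain dictionary access after loading.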
classmethod validate_one_epoch(model: Module, iterator: Iterable[Dict[str, Tensor]], reporter: SubReporter, options: TrainerOptions, distributed_option: DistributedOption) → None
Validate one epoch of the speaker recognition model.
This method evaluates the model on a validation dataset by calculating the open set equal error rate (EER) using unknown speakers. It performs inference and computes similarity scores between speaker embeddings.
- Parameters:
- model (torch.nn.Module) – The speaker recognition model to evaluate.
- iterator (Iterable[Dict[str, torch.Tensor]]) – An iterable that yields batches of input data for validation.
- reporter (SubReporter) – An object for logging and reporting statistics.
- options (TrainerOptions) – Configuration options for the trainer.
- distributed_option (DistributedOption) – Configuration for distributed training options.
- Returns: This function does not return a value.
- Return type: None
- Raises: ValueError – If an unexpected label is encountered in the data.
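The description above says validation computes similarity scores between speaker embeddings before deriving the EER. As a sketch only, assuming cosine similarity as the scoring function (a common choice, not confirmed by this page) and a hypothetical trial list of (enroll_id, test_id) pairs, the per-trial scores could be produced like this. The function name cosine_trial_scores is hypothetical.

```python
import numpy as np


def cosine_trial_scores(embeddings, trials):
    """Score (enroll_id, test_id) trial pairs by cosine similarity.

    embeddings: mapping from utterance ID to a 1-D embedding vector.
    trials: iterable of (enroll_id, test_id) pairs.
    """
    scores = []
    for enroll_id, test_id in trials:
        a = np.asarray(embeddings[enroll_id], dtype=float)
        b = np.asarray(embeddings[test_id], dtype=float)
        # Cosine similarity: dot product divided by the product of the L2 norms.
        scores.append(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))
    return np.array(scores)


embeddings = {"utt1": [1.0, 0.0], "utt2": [2.0, 0.0], "utt3": [0.0, 3.0]}
print(cosine_trial_scores(embeddings, [("utt1", "utt2"), ("utt1", "utt3")]))  # -> [1. 0.]
```

Paired with a same-speaker/different-speaker label per trial, these scores are the input from which an open-set EER is computed.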
######### Examples
To use this method, call it on the class with the appropriate arguments:
```python
SpkTrainer.validate_one_epoch(model, data_iterator, reporter, options,
                              distributed_option)
```
####### NOTE This method is decorated with @torch.no_grad() to prevent gradient calculation during validation, which saves memory and speeds up computation.