espnet2.diar.espnet_model.ESPnetDiarizationModel
class espnet2.diar.espnet_model.ESPnetDiarizationModel(frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, label_aggregator: Module, encoder: AbsEncoder, decoder: AbsDecoder, attractor: AbsAttractor | None, diar_weight: float = 1.0, attractor_weight: float = 1.0)
Bases: AbsESPnetModel
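ESPnetDiarizationModel is an end-to-end speaker diarization model built from an optional frontend, optional data augmentation and normalization modules, a label aggregator, an encoder, and either a decoder or an attractor. If attractor is None, the model implements SA-EEND, in which the decoder predicts per-speaker activities from the encoder output. Otherwise it implements EEND-EDA, in which encoder-decoder attractors are used instead.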
For more details on the methodologies used, refer to the following papers:
- SA-EEND: https://arxiv.org/pdf/1909.06247.pdf
- EEND-EDA: https://arxiv.org/pdf/2005.09921.pdf, https://arxiv.org/pdf/2106.10654.pdf
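The two modes differ in how per-speaker logits are produced. Below is a rough NumPy sketch of the EEND-EDA case. The shapes, the random inputs, and the function name eda_speaker_logits are illustrative assumptions, not ESPnet API. Each frame embedding is scored against each speaker attractor by a dot product. The extra final attractor only marks "no more speakers" and is excluded. In SA-EEND, the decoder produces such logits directly.

```python
import numpy as np


def eda_speaker_logits(encoder_out: np.ndarray, attractors: np.ndarray) -> np.ndarray:
    """Score every frame embedding against every speaker attractor.

    encoder_out: (batch, frames, dim) frame embeddings from the encoder.
    attractors:  (batch, num_spk + 1, dim); the last attractor only signals
                 "no further speakers" and is not matched to any speaker.
    Returns speaker-activity logits of shape (batch, frames, num_spk).
    """
    speaker_attractors = attractors[:, :-1, :]  # drop the final "stop" attractor
    # Batched dot products: (B, T, D) @ (B, D, S) -> (B, T, S)
    return encoder_out @ speaker_attractors.transpose(0, 2, 1)


rng = np.random.default_rng(0)
enc = rng.standard_normal((2, 50, 8))  # 2 utterances, 50 frames, 8-dim embeddings
att = rng.standard_normal((2, 4, 8))   # 3 speaker attractors + 1 stop attractor
logits = eda_speaker_logits(enc, att)
activity = 1.0 / (1.0 + np.exp(-logits)) > 0.5  # sigmoid, then threshold
print(logits.shape)  # (2, 50, 3)
```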
encoder
The encoder component used for processing input speech.
- Type: AbsEncoder
normalize
The normalization component for features.
- Type: Optional[AbsNormalize]
frontend
The frontend feature extractor.
- Type: Optional[AbsFrontend]
specaug
The data augmentation component.
- Type: Optional[AbsSpecAug]
label_aggregator
Aggregates speaker labels.
- Type: torch.nn.Module
diar_weight
Weight for the diarization loss.
- Type: float
attractor_weight
Weight for the attractor loss.
- Type: float
attractor
The attractor component for EEND-EDA.
- Type: Optional[AbsAttractor]
decoder
The decoder component used for predictions in SA-EEND mode (not used when an attractor is given).
- Type: Optional[AbsDecoder]
Parameters:
- frontend (Optional[AbsFrontend]) – The frontend feature extractor.
- specaug (Optional[AbsSpecAug]) – The spec augmentation module.
- normalize (Optional[AbsNormalize]) – The feature normalization module.
- label_aggregator (torch.nn.Module) – Module that aggregates speaker labels.
- encoder (AbsEncoder) – The encoder module.
- decoder (AbsDecoder) – The decoder module, used for prediction when attractor is None (SA-EEND).
- attractor (Optional[AbsAttractor]) – The attractor module; if given, EEND-EDA is used.
- diar_weight (float) – Weight for the diarization loss (default: 1.0).
- attractor_weight (float) – Weight for the attractor loss (default: 1.0).
##################### Examples
>>> model = ESPnetDiarizationModel(
...     frontend=my_frontend,
...     specaug=my_specaug,
...     normalize=my_normalize,
...     label_aggregator=my_label_aggregator,
...     encoder=my_encoder,
...     decoder=my_decoder,
...     attractor=my_attractor,
...     diar_weight=1.0,
...     attractor_weight=1.0,
... )
>>> loss, stats, weight = model(speech_data, speech_lengths, speaker_labels)
- Raises: NotImplementedError – If both attractor and decoder are None.
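The sketch below combines the Raises rule with the SA-EEND/EEND-EDA choice described above. The function resolve_mode is hypothetical and only mirrors the documented behavior. It assumes that a supplied attractor takes precedence over a decoder, because the description says an attractor switches the model to EEND-EDA.

```python
from typing import Optional


def resolve_mode(decoder: Optional[object], attractor: Optional[object]) -> str:
    """Pick the diarization method from the components that were supplied."""
    if attractor is not None:
        # EEND-EDA: speaker logits come from the encoder output and attractors.
        return "EEND-EDA"
    if decoder is not None:
        # SA-EEND: the decoder maps the encoder output to per-speaker logits.
        return "SA-EEND"
    # Neither component given: nothing can produce speaker predictions.
    raise NotImplementedError("an attractor or a decoder is required")


print(resolve_mode(decoder=object(), attractor=None))  # SA-EEND
```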
attractor_loss(att_prob, label)
Calculate the attractor existence loss.
The loss is the binary cross-entropy between the predicted attractor existence scores and a target in which each of the num_spk speakers is labeled as present (1) and the one additional attractor is labeled as absent (0), i.e. the target is [1, 1, ..., 1, 0]. The scores are treated as logits (BCE with logits), and the loss is averaged over attractors and utterances.
- Parameters:
- att_prob (torch.Tensor) – The predicted attractor existence logits of shape (Batch, num_spk + 1, 1).
- label (torch.Tensor) – The ground truth speaker labels of shape (Batch, Length, num_spk); only the number of speakers (the last dimension) is used to build the target.
- Returns: The attractor loss as a scalar tensor.
- Return type: torch.Tensor
##################### Examples
>>> model = ESPnetDiarizationModel(...)
>>> att_prob = torch.tensor([[[4.0], [3.0], [-5.0]]])  # (1, num_spk + 1, 1) logits
>>> label = torch.zeros(1, 100, 2)  # (Batch, Length, num_spk) with num_spk = 2
>>> loss = model.attractor_loss(att_prob, label)
>>> print(loss)
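Here is a minimal NumPy sketch of the target construction and loss described above. It assumes three things: the existence scores are logits, so a numerically stable BCE-with-logits is used; label has shape (Batch, Length, num_spk); and the loss is a plain mean. The helper names are hypothetical.

```python
import numpy as np


def bce_with_logits(logits: np.ndarray, targets: np.ndarray) -> np.ndarray:
    # Numerically stable element-wise BCE on logits: log(1 + e^x) - t * x
    return np.logaddexp(0.0, logits) - targets * logits


def attractor_loss_sketch(att_logits: np.ndarray, label: np.ndarray) -> float:
    """att_logits: (B, num_spk + 1, 1); label: (B, Length, num_spk)."""
    batch, _, num_spk = label.shape
    # Target [1, ..., 1, 0]: every real speaker "exists", the extra attractor does not.
    att_label = np.zeros((batch, num_spk + 1, 1))
    att_label[:, :num_spk, :] = 1.0
    return float(bce_with_logits(att_logits, att_label).mean())


label = np.zeros((1, 100, 2))               # 2 speakers; only the last axis matters
good = np.array([[[4.0], [3.0], [-5.0]]])   # confident: 2 speakers, then stop
bad = np.array([[[-5.0], [3.0], [4.0]]])    # stop signalled in the wrong place
print(attractor_loss_sketch(good, label) < attractor_loss_sketch(bad, label))  # True
```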
static calc_diarization_error(pred, label, length)
Calculate diarization error statistics from predicted and ground truth labels.
This method counts, over the valid (unpadded) frames, the quantities needed for the speech activity detection (SAD) error and the speaker diarization error: speech misses and false alarms, and speaker misses, false alarms, and confusions. A speaker is considered active in a frame when its predicted logit is greater than 0.
- Parameters:
- pred (torch.Tensor) – The predicted logits with shape (batch_size, max_len, num_output).
- label (torch.Tensor) – The ground truth labels with shape (batch_size, max_len, num_output).
- length (torch.Tensor) – The number of valid frames in each sequence, shape (batch_size,).
- Returns: A tuple containing the following statistics:
- correct: The number of correctly predicted labels over valid frames, divided by num_output.
- num_frames: The total number of valid frames.
- speech_scored: The number of frames in which at least one reference speaker is active.
- speech_miss: The number of reference speech frames in which no speaker is predicted.
- speech_falarm: The number of reference non-speech frames in which a speaker is predicted.
- speaker_scored: The number of active reference speakers, summed over frames.
- speaker_miss: The number of missed reference speaker activities, summed over frames.
- speaker_falarm: The number of falsely predicted speaker activities, summed over frames.
- speaker_error: The number of speaker confusions (active speakers attributed to the wrong identity), summed over frames.
- Return type: Tuple[float, float, float, float, float, float, float, float, float]
########## NOTE The implementation of this method is credited to the EEND project (https://github.com/hitachi-speech/EEND).
##################### Examples
>>> pred = torch.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]])
>>> label = torch.tensor([[[1, 0], [0, 1]], [[0, 0], [1, 1]]])
>>> length = torch.tensor([2, 2])
>>> (correct, num_frames, speech_scored, speech_miss, speech_falarm,
...  speaker_scored, speaker_miss, speaker_falarm,
...  speaker_error) = ESPnetDiarizationModel.calc_diarization_error(pred, label, length)
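The NumPy sketch below relates these counts to the usual diarization error rate (DER). It reimplements the counts for the example tensors above and combines them as (speaker_miss + speaker_falarm + speaker_error) / speaker_scored. The function name is hypothetical. The DER formula is the standard definition, assumed here rather than taken from ESPnet's reporting code.

```python
import numpy as np


def diarization_counts(pred_logits: np.ndarray, label: np.ndarray, length) -> dict:
    """EEND-style frame/speaker counts. pred_logits, label: (B, T, S); length: (B,)."""
    _, max_len, _ = label.shape
    # (B, T, 1) mask of valid frames, broadcast over the speaker axis
    mask = (np.arange(max_len)[None, :] < np.asarray(length)[:, None])[:, :, None]
    ref = label.astype(int) * mask
    hyp = (pred_logits > 0).astype(int) * mask         # active where logit > 0
    n_ref, n_hyp = ref.sum(axis=2), hyp.sum(axis=2)    # active speakers per frame
    n_map = np.logical_and(ref == 1, hyp == 1).sum(axis=2)  # correctly matched speakers
    return {
        "speech_scored": float(np.sum(n_ref > 0)),
        "speech_miss": float(np.sum((n_ref > 0) & (n_hyp == 0))),
        "speech_falarm": float(np.sum((n_ref == 0) & (n_hyp > 0))),
        "speaker_scored": float(n_ref.sum()),
        "speaker_miss": float(np.maximum(n_ref - n_hyp, 0).sum()),
        "speaker_falarm": float(np.maximum(n_hyp - n_ref, 0).sum()),
        "speaker_error": float((np.minimum(n_ref, n_hyp) - n_map).sum()),
    }


pred = np.array([[[1, 0], [0, 1]], [[0, 1], [1, 0]]])
label = np.array([[[1, 0], [0, 1]], [[0, 0], [1, 1]]])
c = diarization_counts(pred, label, [2, 2])
# Standard DER: (missed + false-alarm + confused speaker time) / reference speaker time
der = (c["speaker_miss"] + c["speaker_falarm"] + c["speaker_error"]) / c["speaker_scored"]
print(der)  # 0.5
```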
collect_feats(speech: Tensor, speech_lengths: Tensor, spk_labels: Tensor | None = None, spk_labels_lengths: Tensor | None = None, **kwargs) → Dict[str, Tensor]
Collects features from the input speech signal.
This method extracts features from the input speech and returns them together with their lengths. Speaker labels and their lengths are accepted for interface compatibility but are not used for feature extraction.
- Parameters:
- speech (torch.Tensor) – The input speech signal of shape (Batch, Length).
- speech_lengths (torch.Tensor) – A tensor indicating the lengths of each speech signal in the batch, of shape (Batch,).
- spk_labels (torch.Tensor, optional) – Speaker labels of shape (Batch, …); not used. Defaults to None.
- spk_labels_lengths (torch.Tensor, optional) – Lengths of the speaker labels, shape (Batch,); not used. Defaults to None.
- **kwargs – Additional keyword arguments for future extensions.
- Returns: A dictionary containing:
- 'feats': Extracted features of shape (Batch, NFrames, Dim).
- 'feats_lengths': Lengths of the extracted features, of shape (Batch,).
- Return type: Dict[str, torch.Tensor]
##################### Examples
>>> model = ESPnetDiarizationModel(...)
>>> speech = torch.randn(2, 16000) # Batch of 2, 16000 samples
>>> speech_lengths = torch.tensor([16000, 16000])
>>> features = model.collect_feats(speech, speech_lengths)
>>> print(features['feats'].shape) # Expected shape: (2, NFrames, Dim)
########## NOTE This function primarily uses the frontend module to process the input speech and generate features. If no frontend is specified, the raw speech input will be returned as features.
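The fallback described in the note can be illustrated with a short sketch. Both collect_feats_sketch and toy_framing_frontend are hypothetical stand-ins. The toy frontend just slices the waveform into overlapping 400-sample frames with a 160-sample hop; these values are arbitrary. A real frontend would likewise return (Batch, NFrames, Dim) features with their lengths.

```python
from typing import Callable, Dict, Optional, Tuple

import numpy as np

# Hypothetical frontend signature: (speech, lengths) -> (feats, feats_lengths)
Frontend = Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]


def collect_feats_sketch(
    speech: np.ndarray, speech_lengths: np.ndarray, frontend: Optional[Frontend] = None
) -> Dict[str, np.ndarray]:
    if frontend is None:
        # No frontend: the raw waveform is passed through as the "features".
        feats, feats_lengths = speech, speech_lengths
    else:
        feats, feats_lengths = frontend(speech, speech_lengths)
    return {"feats": feats, "feats_lengths": feats_lengths}


def toy_framing_frontend(speech, lengths, frame=400, hop=160):
    """Hypothetical frontend: slice the waveform into overlapping raw frames."""
    frames = np.lib.stride_tricks.sliding_window_view(speech, frame, axis=1)[:, ::hop, :]
    return frames, 1 + (np.asarray(lengths) - frame) // hop


speech = np.random.default_rng(0).standard_normal((2, 16000))
lengths = np.array([16000, 8000])
raw = collect_feats_sketch(speech, lengths)
framed = collect_feats_sketch(speech, lengths, toy_framing_frontend)
print(raw["feats"].shape, framed["feats"].shape)  # (2, 16000) (2, 98, 400)
```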
create_length_mask(length, max_len, num_output)
Create a mask that marks the valid (unpadded) entries of each sequence, so that only those entries of the label tensor contribute to the loss.
The mask has shape (batch_size, max_len, num_output), where max_len is the padded sequence length and num_output is the number of outputs (speakers). It contains ones for frames within each sequence's length and zeros for padded frames.
- Parameters:
- length (torch.Tensor) – A 1D tensor containing the lengths of each sequence in the batch.
- max_len (int) – The maximum length of the sequences to be considered.
- num_output (int) – The number of output channels.
- Returns: A tensor of shape (batch_size, max_len, num_output) containing the length mask.
- Return type: torch.Tensor
##################### Examples
>>> model = ESPnetDiarizationModel(...)
>>> length = torch.tensor([3, 5, 2])
>>> max_len = 5
>>> num_output = 4
>>> mask = model.create_length_mask(length, max_len, num_output)
>>> print(mask)
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],
        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],
        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])
########## NOTE This function assumes that the input tensor length is a 1D tensor containing the lengths of the sequences in the batch.
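The mask depends only on whether each frame index is below the sequence length, so it can also be built without a per-sequence loop. This vectorized NumPy sketch (hypothetical function name) produces the same 0/1 pattern as the example above.

```python
import numpy as np


def create_length_mask_sketch(length, max_len: int, num_output: int) -> np.ndarray:
    """Vectorized length mask of shape (batch_size, max_len, num_output)."""
    length = np.asarray(length)
    # (batch, max_len): True while the frame index is below the sequence length
    valid = np.arange(max_len)[None, :] < length[:, None]
    # Copy the frame-validity pattern to every output channel, as 0./1. floats
    return np.repeat(valid[:, :, None], num_output, axis=2).astype(np.float32)


mask = create_length_mask_sketch([3, 5, 2], max_len=5, num_output=4)
print(mask[:, :, 0])  # one row per sequence: 1 for valid frames, 0 for padding
```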
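encode(speech: Tensor, speech_lengths: Tensor, bottleneck_feats: Tensor, bottleneck_feats_lengths: Tensor) → Tuple[Tensor, Tensor]
Extract features from the input speech (frontend, data augmentation during training, and normalization) and run the encoder on them.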
- Parameters:
- speech (torch.Tensor) – Input tensor of shape (Batch, Length, …).
- speech_lengths (torch.Tensor) – Lengths of the input sequences, shape (Batch,).
- bottleneck_feats (torch.Tensor) – Bottleneck features of shape (Batch, Length, …), used only for the joint speech enhancement and diarization task; None otherwise.
- bottleneck_feats_lengths (torch.Tensor) – Lengths of the bottleneck features, shape (Batch,); None when bottleneck_feats is None.
- Returns: A tuple containing:
- encoder_out (torch.Tensor): Output of the encoder, shape (Batch, Length2, Dim).
- encoder_out_lens (torch.Tensor): Lengths of the encoder outputs, shape (Batch,).
- Return type: Tuple[torch.Tensor, torch.Tensor]
########## NOTE Feature extraction, data augmentation, and normalization run inside an autocast(False) context, so they are computed without mixed precision even when the model is trained with automatic mixed precision.
##################### Examples
>>> model = ESPnetDiarizationModel(...)
>>> speech = torch.randn(32, 16000) # Example batch of 32 audio samples
>>> speech_lengths = torch.tensor([16000] * 32) # Lengths of each sample
>>> bottleneck_feats = torch.randn(32, 100, 40) # Example bottleneck features
>>> bottleneck_feats_lengths = torch.tensor([100] * 32) # Lengths of bottleneck features
>>> encoder_out, encoder_out_lens = model.encode(speech, speech_lengths,
... bottleneck_feats, bottleneck_feats_lengths)
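When both a frontend and bottleneck features are available, the two streams typically have different frame counts. They must be aligned before they can be combined along the feature axis. The NumPy sketch below does this with nearest-neighbour frame repetition. The nearest-neighbour alignment, the ordering of the concatenation, and the function name are assumptions of this sketch, not a description of ESPnet's exact code.

```python
import numpy as np


def align_and_concat(feats: np.ndarray, bottleneck: np.ndarray) -> np.ndarray:
    """feats: (B, T_f, D_f), bottleneck: (B, T_b, D_b) -> (B, T_b, D_b + D_f)."""
    t_f, t_b = feats.shape[1], bottleneck.shape[1]
    # Nearest-neighbour source frame for each target frame: floor(i * T_f / T_b)
    src = (np.arange(t_b) * t_f) // t_b
    feats_aligned = feats[:, src, :]
    # Concatenate along the feature axis (bottleneck features first, by assumption)
    return np.concatenate([bottleneck, feats_aligned], axis=2)


feats = np.arange(2 * 50 * 3, dtype=float).reshape(2, 50, 3)  # 50 frontend frames
bottleneck = np.zeros((2, 100, 4))                             # 100 bottleneck frames
out = align_and_concat(feats, bottleneck)
print(out.shape)  # (2, 100, 7)
```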
forward(speech: Tensor, speech_lengths: Tensor | None = None, spk_labels: Tensor | None = None, spk_labels_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Process input speech through the model and compute the loss.
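This method runs the frontend and encoder and predicts frame-level speaker activities, with the decoder in SA-EEND mode or the attractor in EEND-EDA mode. It computes the permutation-invariant diarization loss, plus the weighted attractor loss in EEND-EDA mode. It also computes statistics describing speaker diarization performance.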
- Parameters:
- speech – Tensor of shape (Batch, samples) representing input speech.
- speech_lengths – Optional; Tensor of shape (Batch,) indicating the lengths of each input sequence. Default is None, which is useful for chunk iterators that do not provide lengths.
- spk_labels – Tensor of shape (Batch, Length, num_spk) containing binary speaker-activity labels; they are aggregated by label_aggregator before the loss is computed.
- spk_labels_lengths – Optional; Tensor of shape (Batch,) indicating the lengths of each speaker label sequence.
- kwargs – Additional arguments; “utt_id” is among the inputs.
- Returns:
- loss: Computed loss value.
- stats: Dictionary of statistics, including the loss components and diarization metrics.
- weight: Weight of the current batch (its batch size).
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises:
- AssertionError – If the batch sizes of speech and spk_labels do not match.
##################### Examples
>>> model = ESPnetDiarizationModel(...)
>>> speech = torch.randn(10, 16000)  # 10 utterances of 16000 samples
>>> speech_lengths = torch.tensor([16000] * 10)
>>> # Binary activity labels for 3 speakers, one row per input sample
>>> spk_labels = torch.randint(0, 2, (10, 16000, 3)).float()
>>> loss, stats, weight = model(speech, speech_lengths, spk_labels)
########## NOTE Ensure that the input tensors are on the same device as the model for proper computation.
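The NumPy sketch below covers two pieces of bookkeeping around the loss; both helper names are hypothetical. First, when the encoder subsamples in time, the aggregated labels can end up a few frames longer than the predictions, so they are trimmed to match. The 2-frame tolerance is an assumption. Second, following the descriptions of diar_weight and attractor_weight, the EEND-EDA loss is assumed to be a weighted sum of the PIT loss and the attractor loss. SA-EEND uses the PIT loss alone.

```python
from typing import Optional

import numpy as np


def trim_labels(labels: np.ndarray, num_pred_frames: int, tolerance: int = 2) -> np.ndarray:
    """Trim (B, T_label, S) labels to the prediction length if only slightly longer."""
    diff = labels.shape[1] - num_pred_frames
    if 0 < diff <= tolerance:
        labels = labels[:, :num_pred_frames, :]
    return labels  # larger mismatches are left for the caller to detect


def total_loss(
    loss_pit: float,
    loss_att: Optional[float],
    diar_weight: float = 1.0,
    attractor_weight: float = 1.0,
) -> float:
    if loss_att is None:
        # SA-EEND: only the permutation-invariant diarization loss
        return loss_pit
    # EEND-EDA: weighted sum of diarization and attractor losses
    return diar_weight * loss_pit + attractor_weight * loss_att


print(trim_labels(np.zeros((4, 501, 3)), 500).shape)  # (4, 500, 3)
print(total_loss(0.5, 0.2, attractor_weight=0.5))
```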
pit_loss(pred, label, lengths)
Calculate the permutation-invariant training (PIT) loss.
This method computes the PIT loss by evaluating the loss for every permutation of the labels' speaker (output) axis and keeping, for each sample, the permutation with the minimum loss. It returns the resulting loss together with the selected permutation indices, the list of permutations, and the correspondingly permuted labels.
- Parameters:
- pred (torch.Tensor) – The predicted logits with shape (Batch, Length, num_output).
- label (torch.Tensor) – The ground truth labels with shape (Batch, Length, num_output).
- lengths (torch.Tensor) – The lengths of the sequences, with shape (Batch,).
- Returns:
- loss: The PIT loss, i.e. the per-sample minimum losses summed over the batch and normalized by the total number of valid frames.
- min_idx: The index of the minimum-loss permutation for each sample in the batch.
- permute_list: A list of all permutations of the output (speaker) indices.
- label_permute: The labels permuted according to the minimum-loss permutation of each sample.
- Return type: Tuple[torch.Tensor, torch.Tensor, List[np.ndarray], torch.Tensor]
########## NOTE Credit to https://github.com/hitachi-speech/EEND for the implementation of this method.
##################### Examples
>>> model = ESPnetDiarizationModel(...)
>>> pred = torch.randn(2, 5, 3)  # Example prediction logits
>>> label = torch.randint(0, 2, (2, 5, 3)).float()  # Example binary labels
>>> lengths = torch.tensor([5, 5])  # Lengths of each sequence
>>> loss, min_idx, permute_list, label_permute = model.pit_loss(pred, label, lengths)
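The permutation search can be sketched compactly in NumPy. The sketch assumes logit predictions and a per-sample loss that averages BCE over speakers and sums it over valid frames. It also assumes normalization by the total number of valid frames. The names are hypothetical. The number of permutations grows as num_output!, so exhaustive PIT is practical only for a small number of speakers.

```python
import itertools

import numpy as np


def bce_with_logits(logits, targets):
    # Element-wise, numerically stable BCE on logits
    return np.logaddexp(0.0, logits) - targets * logits


def pit_loss_sketch(pred, label, lengths):
    """pred (logits) and label: (B, T, S); lengths: (B,) valid frame counts."""
    _, max_len, num_spk = label.shape
    lengths = np.asarray(lengths)
    mask = (np.arange(max_len)[None, :] < lengths[:, None])[:, :, None]
    perms = [np.array(p) for p in itertools.permutations(range(num_spk))]
    # (B, num_perms): per-utterance loss of each speaker permutation of the labels
    per_perm = np.stack(
        [(bce_with_logits(pred, label[:, :, p]) * mask).mean(axis=2).sum(axis=1) for p in perms],
        axis=1,
    )
    min_idx = per_perm.argmin(axis=1)
    # Best permutation per utterance, normalized by the total number of valid frames
    loss = per_perm.min(axis=1).sum() / lengths.sum()
    label_perm = np.stack([label[b][:, perms[i]] for b, i in enumerate(min_idx)])
    return loss, min_idx, perms, label_perm


label = np.array([[[1, 0], [1, 0], [0, 1], [1, 1]]], dtype=float)
pred = 10.0 * (2.0 * label[:, :, ::-1] - 1.0)  # confident, but with speakers swapped
loss, min_idx, perms, label_perm = pit_loss_sketch(pred, label, [4])
print(perms[min_idx[0]])  # [1 0]
```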
pit_loss_single_permute(pred, label, length)
Calculate the loss for a single, fixed permutation of the labels.
This method computes the element-wise binary cross-entropy between the predicted logits and the given labels. It applies a length mask so that padded frames do not contribute, then averages over outputs and sums over time, giving one loss value per sample.
- Parameters:
- pred (torch.Tensor) – The predicted logits of shape (Batch, Time, Output).
- label (torch.Tensor) – The ground truth labels of shape (Batch, Time, Output).
- length (torch.Tensor) – The lengths of each sequence in the batch.
- Returns: The loss for the given permutation, of shape (Batch, 1).
- Return type: torch.Tensor
##################### Examples
>>> model = ESPnetDiarizationModel(...)
>>> pred = torch.tensor([[[0.1, 0.2], [0.5, 0.6]],
...                      [[0.3, 0.4], [0.7, 0.8]]])
>>> label = torch.tensor([[[1., 0.], [0., 1.]],
...                       [[0., 1.], [1., 0.]]])
>>> length = torch.tensor([2, 2])
>>> loss = model.pit_loss_single_permute(pred, label, length)
>>> print(loss.shape)
torch.Size([2, 1])
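The NumPy sketch below applies the per-permutation loss to the same numbers as the example above and shows the effect of the length mask. Marking the second frame of the first sample as padding removes that frame's contribution. The reduction (mean over outputs, sum over valid frames) and the function name are assumptions of this sketch.

```python
import numpy as np


def single_permute_loss_sketch(pred, label, length):
    """pred (logits), label: (B, T, S); length: (B,). Returns (B, 1)."""
    _, max_len, _ = label.shape
    mask = (np.arange(max_len)[None, :] < np.asarray(length)[:, None])[:, :, None]
    elem = np.logaddexp(0.0, pred) - label * pred      # element-wise BCE on logits
    per_utt = (elem * mask).mean(axis=2).sum(axis=1)   # mean over outputs, sum over valid frames
    return per_utt[:, None]


pred = np.array([[[0.1, 0.2], [0.5, 0.6]], [[0.3, 0.4], [0.7, 0.8]]])
label = np.array([[[1, 0], [0, 1]], [[0, 1], [1, 0]]], dtype=float)
full = single_permute_loss_sketch(pred, label, [2, 2])
padded = single_permute_loss_sketch(pred, label, [1, 2])  # 2nd frame of sample 0 is padding
print(full.round(4))
```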