espnet2.uasr.segmenter.join_segmenter.JoinSegmenter
class espnet2.uasr.segmenter.join_segmenter.JoinSegmenter(cfg: Dict | None = None, subsample_rate: float = 0.25, mean_pool: str2bool = True, mean_join_pool: str2bool = False, remove_zeros: str2bool = False)
Bases: AbsSegmenter
JoinSegmenter is a segmenter that processes input tensors to produce segmented outputs based on a join strategy. It is designed to work within the ESPnet framework for unsupervised audio segmentation.
subsampling_rate
The rate at which to subsample the input.
- Type: float
mean_pool
Whether to use mean pooling on the logits.
- Type: bool
mean_join_pool
Whether to use mean pooling on the joined segments.
- Type: bool
remove_zeros
Whether to remove segments that are zero-valued.
- Type: bool
Parameters:
- cfg (Optional[Dict]) – Configuration dictionary for the segmenter. If provided, it must contain a ‘segmentation’ key with a subkey ‘type’ set to ‘JOIN’.
- subsample_rate (float) – The rate at which to subsample the input. Default is 0.25.
- mean_pool (str2bool) – Flag to enable mean pooling. Default is True.
- mean_join_pool (str2bool) – Flag to enable mean pooling on joined segments. Default is False.
- remove_zeros (str2bool) – Flag to remove zero-valued segments. Default is False.
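The `cfg` requirement above can be illustrated with a small dictionary. This is only a sketch: the `segmentation` and `type` keys come from the parameter description, while the `check_join_cfg` helper and the `"OTHER"` type value are hypothetical and not part of ESPnet. Any further keys the segmenter may read from `cfg` are omitted because the description does not name them.

```python
# Minimal configuration shape; only "segmentation" and "type" are named
# in the parameter description. Other keys are intentionally left out.
cfg = {"segmentation": {"type": "JOIN"}}


def check_join_cfg(cfg):
    """Hypothetical helper: True if cfg has the documented JOIN shape."""
    segmentation = cfg.get("segmentation")
    return isinstance(segmentation, dict) and segmentation.get("type") == "JOIN"


print(check_join_cfg(cfg))                                  # True
print(check_join_cfg({"segmentation": {"type": "OTHER"}}))  # False
```

Checking the dictionary before constructing the segmenter makes a misconfigured `type` visible early instead of during model setup.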
pre_segment(xs_pad: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor: Prepares the input tensor and padding mask for segmentation.
logit_segment(logits: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor: Processes the logits and padding mask to produce the segmented outputs.
######### Examples
>>> # Initialize the JoinSegmenter with default parameters
>>> segmenter = JoinSegmenter()
>>> # Pre-segment the inputs
>>> xs_pad, padding_mask = segmenter.pre_segment(xs_pad_tensor, padding_mask_tensor)
>>> # Get the segmented logits and the updated padding mask
>>> segmented_logits, new_padding_mask = segmenter.logit_segment(logits_tensor, padding_mask_tensor)
NOTE
This segmenter assumes that the input logits are the output of a model’s prediction layer.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
logit_segment(logits: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor
Processes logits and applies padding masks to segment the output.
This method takes the logits and a padding mask as input, then identifies the predicted classes based on the highest logits. It handles padding by marking the corresponding predictions as -1. The function also manages the removal of zero entries based on the configuration and adjusts the logits accordingly.
remove_zeros
Flag indicating whether to remove zero entries from the logits.
- Type: bool
Parameters:
- logits (torch.Tensor) – A tensor of shape (batch_size, time_length, channel_size) containing the model’s raw output logits.
- padding_mask (torch.Tensor) – A tensor of shape (batch_size, time_length) that indicates the padding positions in the logits.
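Returns: A tensor of shape (batch_size, new_time_length, channel_size) containing the processed logits after segmentation and padding adjustments. The examples on this page unpack a second value from the call, the padding mask for the shortened sequences.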
Return type: torch.Tensor
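The joining step described above can be sketched in plain Python for a single utterance. This is an illustration of the idea rather than ESPnet's tensor implementation: `join_frames` is a hypothetical helper, it works on nested lists instead of `torch.Tensor`, and it assumes each run of identical predictions is mean-pooled into one frame.

```python
from itertools import groupby


def join_frames(logits, padding_mask):
    """Merge consecutive frames that share the same argmax class.

    logits: list of per-frame score lists, shape (time, channels).
    padding_mask: list of bools, True where a frame is padding.
    Returns (joined_logits, classes); padded frames are dropped.
    """
    # Predicted class per frame; padded frames are marked with -1.
    preds = [
        -1 if is_pad else max(range(len(frame)), key=frame.__getitem__)
        for frame, is_pad in zip(logits, padding_mask)
    ]

    joined, classes = [], []
    position = 0
    for cls, run in groupby(preds):
        length = len(list(run))
        segment = logits[position:position + length]
        position += length
        if cls == -1:
            continue  # skip padding
        # Mean-pool the run: average each channel over its frames.
        joined.append([sum(col) / length for col in zip(*segment)])
        classes.append(cls)
    return joined, classes


logits = [
    [2.0, 0.0],  # class 0
    [4.0, 0.0],  # class 0, joined with the previous frame
    [0.0, 1.0],  # class 1
    [9.0, 9.0],  # padding
]
joined, classes = join_frames(logits, [False, False, False, True])
print(joined)   # [[3.0, 0.0], [0.0, 1.0]]
print(classes)  # [0, 1]
```

Four input frames collapse to two output frames here, which is why the returned time dimension (`new_time_length`) can be shorter than `time_length`.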
######### Examples
>>> import torch
>>> logits = torch.randn(2, 5, 3)  # (batch_size, time_length, channel_size)
>>> # Boolean mask: True marks padded positions
>>> padding_mask = torch.tensor([[False, False, False, True, True],
...                              [False, False, True, True, True]])
>>> segmenter = JoinSegmenter()
>>> new_logits, new_pad = segmenter.logit_segment(logits, padding_mask)
NOTE
This method is designed to work within the context of a JoinSegmenter instance and assumes that the instance has been properly initialized with configuration parameters.
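Because each utterance can shrink to a different length after joining, the batch has to be padded again to a common `new_time_length`. The sketch below shows one way this could work on nested lists. `repad_batch` is a hypothetical helper, and it assumes that "zero entries" means segments whose predicted class index is 0; the documentation above does not spell this out.

```python
def repad_batch(segments, classes, channel_size, remove_zeros=False):
    """Pad joined sequences to a common length and build the new mask.

    segments: per-utterance lists of joined frames (lists of floats).
    classes: per-utterance lists holding the class of each joined frame.
    """
    kept = []
    for frames, labels in zip(segments, classes):
        if remove_zeros:
            # Assumption: drop segments whose predicted class index is 0.
            frames = [f for f, c in zip(frames, labels) if c != 0]
        kept.append(frames)

    new_time_length = max(len(frames) for frames in kept)
    new_logits, new_pad = [], []
    for frames in kept:
        missing = new_time_length - len(frames)
        # Fill the tail with zero frames and mark those positions as padding.
        new_logits.append(frames + [[0.0] * channel_size for _ in range(missing)])
        new_pad.append([False] * len(frames) + [True] * missing)
    return new_logits, new_pad


segments = [[[3.0, 0.0], [0.0, 1.0]], [[0.0, 5.0]]]
classes = [[0, 1], [1]]
new_logits, new_pad = repad_batch(segments, classes, channel_size=2)
print(new_pad)  # [[False, False], [False, True]]
```

With `remove_zeros=True` the first utterance loses its class-0 segment, so both utterances end up with one frame and no padding is needed.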
pre_segment(xs_pad: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor
Preprocesses input tensors for segmentation by returning the original tensors.
This method serves as a preliminary step before applying the segmentation logic. It currently returns the input tensors without modification. This is useful for ensuring compatibility with subsequent methods in the segmentation process.
- Parameters:
- xs_pad (torch.Tensor) – A tensor representing the input features to be processed. This tensor can have any shape that is compatible with the segmentation model.
- padding_mask (torch.Tensor) – A boolean tensor indicating the padding in the input features, with True at padded positions. In the example for this method it covers the leading (batch, time) dimensions of xs_pad rather than the full feature shape.
- Returns: The original input tensor xs_pad and the padding_mask, without any modifications.
- Return type: torch.Tensor
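A padding mask of the kind described above is usually derived from the true sequence lengths. The following is a plain-Python sketch of that construction. `make_padding_mask` is a hypothetical helper, not an ESPnet function, and it returns nested lists that would still need converting to a boolean `torch.Tensor` before being passed to the segmenter.

```python
def make_padding_mask(lengths, max_len=None):
    """Return one boolean row per sequence; True marks a padded position."""
    if max_len is None:
        max_len = max(lengths)
    return [[t >= n for t in range(max_len)] for n in lengths]


mask = make_padding_mask([2, 3], max_len=5)
print(mask[0])  # [False, False, True, True, True]
print(mask[1])  # [False, False, False, True, True]
```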
######### Examples
>>> import torch
>>> segmenter = JoinSegmenter()
>>> xs_pad = torch.randn(2, 5, 10) # Example input tensor
>>> padding_mask = torch.tensor([[False, False, True, True, True],
... [False, False, False, True, True]])
>>> processed_xs, processed_mask = segmenter.pre_segment(xs_pad, padding_mask)
>>> assert torch.equal(xs_pad, processed_xs)
>>> assert torch.equal(padding_mask, processed_mask)