espnet2.asr.encoder.beats_encoder.BeatsConfig
class espnet2.asr.encoder.beats_encoder.BeatsConfig(cfg=None)
Bases: object
Configuration class for the BEATs encoder model.
This class defines the various hyperparameters and configuration options for the BEATs model used in audio pre-training with acoustic tokenizers. The default values are set in the constructor, but they can be updated using the update method.
input_patch_size
Patch size for patch embedding.
- Type: int
embed_dim
Dimension of patch embedding.
- Type: int
conv_bias
Whether to include bias in the convolutional encoder.
- Type: bool
encoder_layers
Number of encoder layers in the transformer.
- Type: int
encoder_embed_dim
Encoder embedding dimension.
- Type: int
encoder_ffn_embed_dim
Feed-forward network embedding dimension.
- Type: int
encoder_attention_heads
Number of attention heads in the encoder.
- Type: int
activation_fn
Activation function used in the model.
- Type: str
layer_wise_gradient_decay_ratio
Ratio for layer-wise gradient decay.
- Type: float
layer_norm_first
Whether to apply layer normalization first.
- Type: bool
deep_norm
Whether to apply deep normalization first.
- Type: bool
dropout
Dropout probability for the transformer.
- Type: float
attention_dropout
Dropout probability for attention weights.
- Type: float
activation_dropout
Dropout probability after activation in FFN.
- Type: float
encoder_layerdrop
Probability of dropping a transformer layer.
- Type: float
dropout_input
Dropout to apply to the input after feature extraction.
- Type: float
conv_pos
Number of filters for convolutional positional embeddings.
- Type: int
conv_pos_groups
Number of groups for convolutional positional embedding.
- Type: int
relative_position_embedding
Whether to apply relative position embedding.
- Type: bool
num_buckets
Number of buckets for relative position embedding.
- Type: int
max_distance
Maximum distance for relative position embedding.
- Type: int
gru_rel_pos
Whether to apply gated relative position embedding.
- Type: bool
finetuned_model
Indicates if the model is fine-tuned.
- Type: bool
predictor_dropout
Dropout probability for the predictor.
- Type: float
predictor_class
Target class number for the predictor.
- Type: int
- Parameters: cfg (dict, optional) – Configuration dictionary whose entries override the default values.
####### Examples
>>> # Create a default configuration
>>> config = BeatsConfig()
>>> # Create a configuration with custom settings
>>> custom_config = BeatsConfig(cfg={
...     'input_patch_size': 32,
...     'dropout': 0.2,
...     'finetuned_model': True,
... })
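Because the constructor accepts an arbitrary dictionary, it can help to sanity-check the values before passing them in. The following is a minimal sketch, not part of ESPnet: `check_beats_cfg` is a hypothetical helper, and the divisibility rule assumes the usual multi-head attention constraint that the embedding dimension splits evenly across heads.

```python
from typing import Any, Dict, List

# Fields documented above as dropout-style probabilities.
PROBABILITY_KEYS = (
    "dropout",
    "attention_dropout",
    "activation_dropout",
    "encoder_layerdrop",
    "predictor_dropout",
)


def check_beats_cfg(cfg: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in cfg (hypothetical helper)."""
    problems = []
    # Probabilities should lie in the closed interval [0, 1].
    for key in PROBABILITY_KEYS:
        if key in cfg and not 0.0 <= cfg[key] <= 1.0:
            problems.append(f"{key}={cfg[key]} is not a probability")
    # Assumption: the head count must evenly divide the embedding dimension.
    dim = cfg.get("encoder_embed_dim")
    heads = cfg.get("encoder_attention_heads")
    if dim is not None and heads is not None and dim % heads != 0:
        problems.append(
            f"encoder_embed_dim={dim} is not divisible by "
            f"encoder_attention_heads={heads}"
        )
    return problems


cfg = {"input_patch_size": 32, "dropout": 0.2, "finetuned_model": True}
print(check_beats_cfg(cfg))  # prints []
```

An empty list means the dictionary passed these two checks; it says nothing about keys the helper does not inspect.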
update(cfg: dict)
Update the configuration of the BeatsConfig instance.
This method updates the attributes of the BeatsConfig instance with values from the provided configuration dictionary. It modifies the instance’s internal state directly by updating its __dict__ attribute.
- Parameters: cfg (dict) – A dictionary of configuration parameters, where each key is an attribute name and each value is the new value to set.
####### Examples
>>> config = BeatsConfig()
>>> new_cfg = {
... 'input_patch_size': 32,
... 'dropout': 0.2,
... }
>>> config.update(new_cfg)
>>> print(config.input_patch_size)
32
>>> print(config.dropout)
0.2
NOTE
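Ensure that the keys in the cfg dictionary match the attribute names of the BeatsConfig class. Because update writes directly to the instance's __dict__, a misspelled or unknown key is not rejected: it is stored as a new attribute, and the intended setting keeps its previous value.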
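The effect of a mismatched key can be illustrated with a small stand-in class. This is a sketch under stated assumptions: `ConfigSketch` is a hypothetical class that only mimics the documented `__dict__`-based update, its default values are placeholders rather than the real BEATs defaults, and `strict_update` is an illustrative guard that is not part of ESPnet.

```python
class ConfigSketch:
    """Hypothetical stand-in mimicking the documented update behaviour."""

    def __init__(self, cfg=None):
        # Placeholder defaults, not the real BEATs values.
        self.input_patch_size = 16
        self.dropout = 0.1
        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        # Same mechanism the documentation describes: write into __dict__.
        self.__dict__.update(cfg)


def strict_update(config, cfg: dict) -> None:
    """Reject keys that are not already attributes, then update."""
    unknown = sorted(set(cfg) - set(vars(config)))
    if unknown:
        raise KeyError(f"unknown configuration keys: {unknown}")
    config.update(cfg)


config = ConfigSketch()
config.update({"dropuot": 0.2})    # misspelled key is accepted silently
print(config.dropout)              # still the placeholder default
print(hasattr(config, "dropuot"))  # the typo became a new attribute

fresh = ConfigSketch()
try:
    strict_update(fresh, {"dropuot": 0.2})
except KeyError as err:
    print("rejected:", err)
```

The same guard should work with any object whose attributes live in `__dict__`, provided it is called before a mistyped key has been added to that object.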