ESPnet3 Model Configuration (Training)
This page explains how model and task in train.yaml map to model construction for the train / collect_stats stages.
Two modes: task (ESPnet2) vs model._target_ (custom)
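As a rough mental model of the two modes, the sketch below classifies a parsed train.yaml dictionary and resolves a `_target_` path with the standard library. It is an illustration only: `select_mode` and `resolve_target` are hypothetical helpers, the precedence given to `model._target_` when both keys are present is an arbitrary choice made for the example, and ESPnet3's actual builder may behave differently.

```python
import importlib


def resolve_target(dotted_path):
    """Import `pkg.module.Name` and return the object it names."""
    module_name, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def select_mode(cfg):
    """Classify a parsed train.yaml dict (hypothetical helper)."""
    model_cfg = cfg.get("model") or {}
    if "_target_" in model_cfg:
        return "custom"  # assumption: _target_ wins if both are set
    if cfg.get("task"):
        return "task"
    raise ValueError("set either `task` or `model._target_`")


espnet2_style = {
    "task": "espnet3.systems.asr.task.ASRTask",
    "model": {"encoder": "transformer", "decoder": "transformer"},
}
# A stdlib class stands in for src.my_model.MyModel so the sketch runs anywhere.
custom_style = {"model": {"_target_": "collections.OrderedDict", "a": 1}}

print(select_mode(espnet2_style))  # task
print(select_mode(custom_style))   # custom

# Instantiate the custom target, passing the remaining keys as keyword arguments.
kwargs = {k: v for k, v in custom_style["model"].items() if k != "_target_"}
model = resolve_target(custom_style["model"]["_target_"])(**kwargs)
```

In a real recipe the `_target_` string would point at your own class, and the remaining keys under `model:` would become its constructor arguments.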
Use ESPnet2-style models (task)
If you want to reuse an ESPnet2-derived model stack, set task and use an ESPnet2-style model: block.

```yaml
task: espnet3.systems.asr.task.ASRTask
model:
  encoder: transformer
  decoder: transformer
  # ...ESPnet2-style config...
```

Tip: you can start from existing ESPnet2 configs under egs2/*/*/conf/*.yaml.

Typical keys in an ASR model: block, following ESPnet2 configs:

| Key | Purpose |
|---|---|
| encoder / encoder_conf | Encoder type and settings. |
| decoder / decoder_conf | Decoder type and settings. |
| model / model_conf | ASR model head and loss settings (CTC/attention, etc.). |
| frontend / frontend_conf | Feature extraction (e.g., STFT/FBANK). |
| specaug / specaug_conf | SpecAugment settings. |
| normalize / normalize_conf | Feature normalization (e.g., global MVN). |

Use custom/ESPnet3-only models (model._target_)
If you want an ESPnet3-specific or fully custom model, implement it under your recipe's src/ directory and point model._target_ to it:

```yaml
model:
  _target_: src.my_model.MyModel
  # custom args here
```

Training-time forward contract (common pattern)
For ASR-style training, the training wrapper typically expects your model to accept batch fields such as:
speech, speech_lengths, text, text_lengths

and return a tuple:

- loss: scalar tensor
- stats: dict of scalars (logging only)
- weight: scalar tensor used as batch size for logging

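To illustrate why weight travels with stats, here is a minimal sketch of how a logging loop could combine per-batch values into one epoch-level number. Plain floats stand in for tensors, `weighted_average` is a hypothetical helper, and whether the ESPnet3 training wrapper aggregates exactly this way is an assumption.

```python
def weighted_average(batch_outputs):
    """Average each stat over batches, weighting by the reported batch size."""
    totals, total_weight = {}, 0.0
    for stats, weight in batch_outputs:
        total_weight += weight
        for name, value in stats.items():
            totals[name] = totals.get(name, 0.0) + value * weight
    return {name: total / total_weight for name, total in totals.items()}


# (stats, weight) pairs as a model might return them; floats stand in for tensors.
outputs = [({"loss": 2.0}, 8), ({"loss": 1.0}, 24)]
epoch_stats = weighted_average(outputs)
print(epoch_stats)  # {'loss': 1.25}
```

Weighting by batch size keeps a small final batch from counting as much as a full one in the reported average.
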
Example:

```python
class MyCustomModel:
    def forward(self, speech, speech_lengths, text, text_lengths, **kwargs):
        loss = ...
        stats = {"loss": loss.detach()}
        weight = speech.new_tensor(speech.shape[0])
        return loss, stats, weight
```

Collect-stats support (collect_feats)
If you want to use collect_stats, your model should implement collect_feats(). See:
- Stage doc: doc/vuepress/src/espnet3/stages/collect-stats.md
- Config doc: doc/vuepress/src/espnet3/config/train_config.md
