ESPnet3 Tutorial: Fine-tuning on a custom dataset
Author: Masao Someki (msomeki@andrew.cmu.edu)
In this tutorial, we will show you how to fine-tune speech foundation models on a custom dataset using ESPnet3.
❗Important Notes❗
- We are using Colab for this demo. However, Colab limits total GPU runtime; if you use too much GPU time, you may temporarily lose GPU access.
- There are multiple in-class checkpoints ✅ throughout this tutorial. Your participation points are based on these tasks. Please try your best to follow all the steps! If you encounter issues, please notify the TAs as soon as possible so that we can make an adjustment for you.
- Please submit PDF files of your completed notebooks to Gradescope. You can print the notebook using File -> Print in the menu bar.
Acknowledgement
- This homework is adapted from William Chen's version in 11752 (williamchen@cmu.edu)
- An earlier version was prepared by Siddhant Arora for last year's 11692 ESPnet demo (siddhana@andrew.cmu.edu), and by Jiatong Shi for the CMU 11492/11692/18495 ESPnet tutorial.
1) Prerequisites
Environment setup
Assumes ESPnet3 and its dependencies are installed. If not, see:
- https://espnet.github.io/espnet/installation.html
!git clone https://github.com/Masao-Someki/espnet.git -b espnet3/recipe/ls_asr100_2
!cd espnet && pip install .
!pip install espnet-model-zoo  # for downloading pre-trained models and configs
!pip install datasets==2.16.0  # newer versions of datasets have issues with load_dataset for FLEURS
!apt install ffmpeg  # for audio file processing
!pip install ipywebrtc notebook  # install dependencies for ASR task
!cd espnet && pip install .[asr]
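(Optional) As a quick sanity check after installation, you can confirm the packages import and whether a GPU is visible. This is a minimal sketch; nothing here is required by the recipe.
# Optional sanity check: verify the installs above are importable
import espnet, datasets, torch
print("espnet:", espnet.__version__)
print("datasets:", datasets.__version__)
print("CUDA available:", torch.cuda.is_available())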
2) Imports
Runtime imports
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from espnet2.bin.s2t_inference import Speech2Text
3) Download FLEURS
Data ingestion
Set the FLEURS language config (e.g., en_us, ja_jp).
FLEURS_CONFIG = "en_us"
fleurs_hf = load_dataset("google/fleurs", FLEURS_CONFIG)
(Optional) Inspect one sample
Data sanity check
from IPython.display import Audio, display
sample = fleurs_hf["train"][0]
print(sample["transcription"])
display(Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"]))
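OWSM and the S2TPreprocessor configured below assume 16 kHz audio (fs: 16000). FLEURS is distributed at 16 kHz, but a quick assertion costs nothing. Optional sketch:
# Optional: confirm the audio matches the 16 kHz rate assumed by the config below
assert sample["audio"]["sampling_rate"] == 16000, "resample the audio before training"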
4) Workspace and tokenizer assets
Assets & tokens
WORK_DIR = Path(os.environ.get("WORK_DIR", "./work/owsm_v4_fleurs")).resolve()
WORK_DIR.mkdir(parents=True, exist_ok=True)
MODEL_TAG = "espnet/owsm_v4_base_102M"
OWSM_LANG = "eng" # ISO3 (e.g., eng, jpn)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
s2t = Speech2Text.from_pretrained(
    model_tag=MODEL_TAG,
    lang_sym=OWSM_LANG,
    device="cpu",
)
BPEMODEL = s2t.tokenizer.model
TOKEN_LIST = WORK_DIR / "token_list.txt"
TOKEN_LIST.write_text("\n".join(s2t.converter.token_list))
tokenizer = s2t.tokenizer
converter = s2t.converter
def tokenize(text):
    return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

def detokenize(ids):
    return tokenizer.tokens2text(converter.ids2tokens(ids))

del s2t
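To make sure the helpers behave as expected, a quick round-trip check (optional sketch; the sentence is arbitrary):
# Optional: tokenize/detokenize round trip should reproduce the input text
ids = tokenize("hello world")
print(ids)
print(detokenize(ids))  # expected: "hello world" (up to BPE normalization)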
5) Dataset
Custom dataset
A minimal dataset that returns the fields required by S2TPreprocessor.
class FLEURSDataset(Dataset):
    def __init__(
        self,
        *,
        data_lang: str,
        split: str,
        lang_sym: str,
        task_sym: str,
    ) -> None:
        self.lang_sym = str(lang_sym)
        self.task_sym = str(task_sym)
        self._ds = load_dataset(
            "google/fleurs",
            data_lang,
            split=str(split),
        )

    def __len__(self) -> int:
        return len(self._ds)

    def __getitem__(self, idx: int):
        ex = self._ds[int(idx)]
        transcription = str(ex["transcription"]).strip()
        audio = ex["audio"]["array"].astype(np.float32)
        # OWSM-style prompt: <lang><task><notimestamps> followed by the transcript
        text = f"{self.lang_sym}{self.task_sym}<notimestamps> {transcription}"
        return {
            "speech": audio,
            "text": text,
            "text_prev": "<na>",
            "text_ctc": transcription,
            "text_raw": tokenize(transcription),
        }
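Before wiring the dataset into the config, it is worth instantiating it once and eyeballing a sample. Optional sketch (downloads the train split if it is not cached; LANG_SYM/TASK_SYM are defined in the config section, so the literals are spelled out here):
# Optional: smoke-test the dataset before training
ds = FLEURSDataset(
    data_lang=FLEURS_CONFIG,
    split="train",
    lang_sym=f"<{OWSM_LANG}>",
    task_sym="<asr>",
)
item = ds[0]
print(item["text"])          # e.g., "<eng><asr><notimestamps> ..."
print(item["speech"].shape)  # 1-D float32 waveform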
6) Model wrapper
Custom model
class OWSMV4BaseFinetuneModel(nn.Module):
    def __init__(
        self,
        *,
        model_tag: str,
        lang_sym: str,
        device: str = "cpu",
    ) -> None:
        super().__init__()
        s2t = Speech2Text.from_pretrained(
            model_tag=model_tag,
            lang_sym=lang_sym,
            device=str(device),
        )
        # Keep only the underlying ESPnet S2T model; the Speech2Text wrapper is
        # used here just to download the pre-trained weights.
        self.s2t_model = s2t.s2t_model

    def forward(self, **batch):
        return self.s2t_model(**batch)

    def collect_feats(self, **batch):
        return self.s2t_model.collect_feats(**batch)
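If you want to confirm the wrapper loads cleanly, and see how many parameters you are fine-tuning, you can instantiate it once. Optional sketch (note this downloads the model again; the lang_sym literal matches LANG_SYM defined below):
# Optional: instantiate once and count trainable parameters (~102M for owsm_v4_base)
m = OWSMV4BaseFinetuneModel(model_tag=MODEL_TAG, lang_sym=f"<{OWSM_LANG}>")
print(sum(p.numel() for p in m.parameters() if p.requires_grad))
del m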
7) Training config (YAML)
Config
Edit max_epochs, batch_bins, or limit_train_batches for quick tests.
EXP_TAG = f"owsm_v4_base_fleurs_{FLEURS_CONFIG}"
EXP_DIR = (WORK_DIR / "exp" / EXP_TAG).as_posix()
STATS_DIR = (WORK_DIR / "exp" / "stats").as_posix()
LANG_SYM = f"<{OWSM_LANG}>"
TASK_SYM = "<asr>"
yaml_cfg = f"""
num_device: 1
num_nodes: 1
exp_tag: {EXP_TAG}
recipe_dir: .
data_dir: {WORK_DIR / 'data'}
exp_dir: {EXP_DIR}
stats_dir: {STATS_DIR}
dataset_dir: {WORK_DIR / 'hf_cache'}

dataset:
  _target_: espnet3.components.data.data_organizer.DataOrganizer
  train:
    - name: train
      dataset:
        _target_: __main__.FLEURSDataset
        split: train
        data_lang: {FLEURS_CONFIG}
        lang_sym: {LANG_SYM}
        task_sym: {TASK_SYM}
  valid:
    - name: validation
      dataset:
        _target_: __main__.FLEURSDataset
        split: validation
        data_lang: {FLEURS_CONFIG}
        lang_sym: {LANG_SYM}
        task_sym: {TASK_SYM}
  preprocessor:
    _target_: espnet2.train.preprocessor.S2TPreprocessor
    train: true
    token_type: bpe
    token_list: {TOKEN_LIST}
    bpemodel: {BPEMODEL}
    text_prev_name: text_prev
    text_ctc_name: text_ctc
    fs: 16000

dataloader:
  collate_fn:
    _target_: espnet2.train.collate_fn.CommonCollateFn
    int_pad_value: -1
  train:
    iter_factory:
      _target_: espnet2.iterators.sequence_iter_factory.SequenceIterFactory
      shuffle: true
      collate_fn: ${{dataloader.collate_fn}}
      num_workers: 0
    batches:
      type: numel
      shape_files:
        - ${{stats_dir}}/train/feats_shape
      batch_size: 2
      batch_bins: 2000000
  valid:
    iter_factory:
      _target_: espnet2.iterators.sequence_iter_factory.SequenceIterFactory
      shuffle: false
      collate_fn: ${{dataloader.collate_fn}}
      num_workers: 0
    batches:
      type: numel
      shape_files:
        - ${{stats_dir}}/valid/feats_shape
      batch_size: 2
      batch_bins: 2000000

model:
  _target_: __main__.OWSMV4BaseFinetuneModel
  model_tag: {MODEL_TAG}
  lang_sym: {LANG_SYM}
  device: {DEVICE}

optim:
  _target_: torch.optim.Adam
  lr: 1.0e-5

scheduler:
  _target_: espnet2.schedulers.warmup_lr.WarmupLR
  warmup_steps: 30000

best_model_criterion:
  - - valid/loss
    - 3
    - min

trainer:
  accelerator: auto
  devices: 1
  max_epochs: 1
  log_every_n_steps: 1
  limit_train_batches: 10

fit: {{}}
"""
cfg = OmegaConf.create(yaml_cfg)
print("Config ready. exp_dir:", cfg.exp_dir)8) Train
8) Train
collect_stats + train
from espnet3.systems.base.system import BaseSystem
from espnet3.utils.logging import configure_logging
from espnet3.utils.stages import run_stages
log = configure_logging()
system = BaseSystem(train_config=cfg)
run_stages(system, ["collect_stats", "train"], log=log)9) Inference (optional)
Run ASR on the test split using the fine-tuned checkpoint.
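The inference config below points at {EXP_DIR}/last.ckpt. Whether the trainer actually writes a last.ckpt depends on your checkpointing settings, so it is worth checking before building the model. Optional sketch:
# Optional: make sure the checkpoint referenced by the config exists
CKPT = Path(EXP_DIR) / "last.ckpt"
assert CKPT.exists(), f"no checkpoint at {CKPT}; check the trainer checkpoint settings"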
class OWSMV4BaseInferenceModel(nn.Module):
    def __init__(
        self,
        *,
        model_tag: str,
        lang_sym: str,
        checkpoint_path: str,
        device: str = "cpu",
    ) -> None:
        super().__init__()
        self.s2t = Speech2Text.from_pretrained(
            model_tag=model_tag,
            lang_sym=lang_sym,
            device=str(device),
        )
        # Load the fine-tuned weights, stripping the Lightning "s2t_model." prefix
        state = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
        self.s2t.s2t_model.load_state_dict(
            {k.replace("s2t_model.", ""): v for k, v in state.items() if k.startswith("s2t_model.")}
        )

    def forward(self, speech):
        return self.s2t(speech)

def output_fn(*, data, model_output, idx):
    uttid = data.get("uttid", str(idx))
    # best hypothesis; index 3 is the decoded text without special tokens
    hyp = model_output[0][3]
    ref = detokenize(data.get("text_raw", ""))
    return {"uttid": uttid, "hyp": hyp, "ref": ref}
infer_yaml_cfg = f"""
infer_dir: {EXP_DIR}/infer

dataset:
  _target_: espnet3.components.data.data_organizer.DataOrganizer
  test:
    - name: test
      dataset:
        _target_: __main__.FLEURSDataset
        split: test
        data_lang: {FLEURS_CONFIG}
        lang_sym: {LANG_SYM}
        task_sym: {TASK_SYM}

model:
  _target_: __main__.OWSMV4BaseInferenceModel
  model_tag: {MODEL_TAG}
  lang_sym: {LANG_SYM}
  checkpoint_path: {EXP_DIR}/last.ckpt
  device: {DEVICE}

input_key: speech
output_fn: __main__.output_fn
"""
infer_cfg = OmegaConf.create(infer_yaml_cfg)
print("Config ready. infer_dir:", infer_cfg.infer_dir)system = BaseSystem(infer_config=infer_cfg)
run_stages(system, ["infer"], log=log)10) Measure (optional)
Compute CER and WER on the test split using the fine-tuned checkpoint.
measure_yaml_cfg = f"""
infer_dir: {EXP_DIR}/infer

dataset:
  _target_: espnet3.components.data.data_organizer.DataOrganizer
  test:
    - name: test
      dataset:
        _target_: __main__.FLEURSDataset
        split: test
        data_lang: {FLEURS_CONFIG}
        lang_sym: {LANG_SYM}
        task_sym: {TASK_SYM}

metrics:
  - metric:
      _target_: espnet3.systems.asr.metrics.wer.WER
    clean_types:
  - metric:
      _target_: espnet3.systems.asr.metrics.cer.CER
    clean_types:
"""
measure_cfg = OmegaConf.create(measure_yaml_cfg)
print("Config ready.infer_dir:", measure_cfg.infer_dir)
system = BaseSystem(measure_config=measure_cfg)
run_stages(system, ["measure"], log=log)