What I Learned from the Whisper Fine-Tuning Event

Further boost Whisper’s performance on your own data.

bofeng huang
Jan 25, 2023
Header image: George Marks/Retrofile/Getty Images.

TL;DR

OpenAI’s Whisper promises to “approach human level robustness and accuracy on English speech recognition”. In this blog, we explain how to further improve its performance on other languages. This approach won first prize in the Whisper Fine-Tuning Event held by Hugging Face 🤗 and Lambda Labs, for both French and German. The models and demos are available on the Hugging Face Hub.

Preface

I am a machine learning engineer at Zaion, a French company that leads the European market for customer relationship solutions. One of Zaion’s goals is to provide accurate transcriptions of customer service conversations, so it is crucial for us to have a reliable speech recognition system that is robust to diverse real-world environments.

At Zaion Lab, we constantly keep an eye on the latest trends and novelties in speech recognition. This gave me the opportunity to participate in the Whisper Fine-Tuning Event held by Hugging Face 🤗 and Lambda Labs, which aimed to bring Whisper models to as many languages as possible. I took part in the French and German challenges and won first prize in both (leaderboard).

Introduction

In September 2022, OpenAI released a pre-trained automatic speech recognition (ASR) model called Whisper. Self-supervised learning models, such as wav2vec 2.0, are usually pre-trained on masked prediction tasks using unlabelled audio data, then fine-tuned on labelled data for various downstream tasks including ASR. In contrast, Whisper models are trained directly on a large amount of weakly labelled data collected from the web.

Specifically, this is 680,000 hours of multilingual and multitask data covering transcription in multiple languages, translation from those languages into English, and timestamp prediction. At this scale, the model shows strong robustness to accents, background noise and technical language.

Whisper is a sequence-to-sequence model, namely a Transformer-based encoder-decoder that maps a sequence of log-magnitude Mel spectrogram features to a sequence of byte-level BPE tokens. The feature extractor computes the log-Mel spectrogram features from the raw waveform, and the Transformer encoder encodes them. The decoder then autoregressively predicts the next token, conditioned on the previous tokens and the encoder hidden states. The figure below summarizes the architecture of the model.
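To make the input shapes concrete, here is a back-of-the-envelope sketch using Whisper's standard front-end settings: 16 kHz audio padded or truncated to 30-second windows, a hop of 160 samples (10 ms) between frames, and 80 Mel bins for the medium checkpoint. The factor of two at the end reflects the stride-2 convolution at the start of the encoder. The helper names are illustrative, not library functions.

```python
SAMPLING_RATE = 16_000   # Hz, the rate Whisper expects
CHUNK_SECONDS = 30       # audio is padded or truncated to 30 s windows
HOP_LENGTH = 160         # samples between consecutive STFT frames (10 ms)
N_MELS = 80              # Mel bins used by the medium checkpoint


def whisper_input_shape(num_mels: int = N_MELS) -> tuple:
    """Shape of one log-Mel input, (Mel bins, frames), for a 30 s window."""
    num_samples = SAMPLING_RATE * CHUNK_SECONDS   # 480,000 samples
    num_frames = num_samples // HOP_LENGTH        # 3,000 frames
    return (num_mels, num_frames)


def encoder_positions(num_frames: int, conv_stride: int = 2) -> int:
    """Sequence length after the encoder's stride-2 convolution stem."""
    return num_frames // conv_stride


mels, frames = whisper_input_shape()
print(mels, frames, encoder_positions(frames))  # 80 3000 1500
```

Whatever the clip length, the encoder therefore always sees 3,000 frames, which is why longer clips end up truncated.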

The architecture of the Whisper model. Figure source: OpenAI Whisper Blog.

In this blog, we show how to fine-tune the medium Whisper checkpoint on French. This checkpoint has 24 encoder layers and 24 decoder layers, for 769 million parameters in total. The full code can be found here.

Prepare Data and Model

Load Model

Let’s first load Whisper’s pre-trained medium checkpoint:

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")

The Whisper model defines some generation arguments, such as forced_decoder_ids and suppress_tokens, which GenerationConfig uses for the generation task. However, we override them for training in order to let the model learn to predict these tokens by itself.

We also disable the use_cache feature of the Whisper decoder. The cache stores the keys and values already computed by the self-attention and cross-attention blocks, so that they can be reused to speed up each subsequent decoding step. However, it is incompatible with gradient checkpointing, which we apply later to reduce the memory footprint.
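To get a feel for what the cache saves, the toy count below compares how many self-attention key/value projections the decoder performs over a full generation with and without caching. It is a hypothetical helper that counts per layer and ignores cross-attention (whose keys and values, with a cache, are computed once from the encoder states).

```python
def kv_projections(num_steps: int, use_cache: bool) -> int:
    """Count decoder self-attention key/value projections over a full generation.

    Without a cache, step t recomputes keys/values for all t tokens so far;
    with a cache, each step only projects the newly generated token.
    """
    if use_cache:
        return num_steps
    return num_steps * (num_steps + 1) // 2


print(kv_projections(225, use_cache=True), kv_projections(225, use_cache=False))  # 225 25425
```

The cache turns quadratic recomputation into linear work at inference time, but it is not needed during training, where the whole target sequence is processed in a single forward pass.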

# included in the training
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# to use gradient checkpointing
model.config.use_cache = False

Load Datasets

We use the 🤗 Datasets library to download and prepare the datasets. We combine the training and validation splits of Common Voice 11.0 and Multilingual LibriSpeech to form a larger training set, and use only the test split of Common Voice 11.0 for evaluation.

In general, you should collect as much training data as you can. Other speech recognition datasets are available on the Hugging Face Hub, such as Voxpopuli and Fleurs. If you want to load your own local corpus, take a look at this page.

Audio in Common Voice is sampled at 48 kHz, while Multilingual LibriSpeech uses 16 kHz. We resample all audio to 16 kHz, both to unify the sampling rate across datasets and because 16 kHz is the sampling rate of Whisper’s 680,000-hour pre-training corpus. Resampling is easily done on the fly with the Datasets cast_column method and the Audio feature.
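Resampling between two rates is commonly done with a polyphase filter that upsamples by one integer factor and downsamples by another (SciPy's scipy.signal.resample_poly takes exactly such up and down factors). The plain-Python sketch below is not the Datasets implementation; it only derives those factors from the greatest common divisor and estimates the output length. For 48 kHz to 16 kHz it reduces to keeping one sample in three after low-pass filtering.

```python
from math import gcd


def resampling_factors(orig_sr: int, target_sr: int) -> tuple:
    """Return (up, down) so that target_sr = orig_sr * up / down in lowest terms."""
    g = gcd(orig_sr, target_sr)
    return target_sr // g, orig_sr // g


def resampled_length(num_samples: int, orig_sr: int, target_sr: int) -> int:
    """Number of output samples after resampling (rounded up)."""
    up, down = resampling_factors(orig_sr, target_sr)
    return -(-num_samples * up // down)  # ceiling division


print(resampling_factors(48_000, 16_000))  # (1, 3)
print(resampling_factors(44_100, 16_000))  # (160, 441)
print(resampled_length(48_000 * 5, 48_000, 16_000))  # 80000
```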

To mix different datasets, we also need all of them to have the same data fields. Here we keep only the audio and sentence columns of the two datasets, renaming columns where necessary.

from datasets import Audio, DatasetDict, concatenate_datasets, load_dataset

AUDIO_COLUMN_NAME = "audio"
TEXT_COLUMN_NAME = "sentence"


def normalize_dataset(ds, audio_column_name=None, text_column_name=None):
    # rename the audio/text columns to the shared names if needed
    if audio_column_name is not None and audio_column_name != AUDIO_COLUMN_NAME:
        ds = ds.rename_column(audio_column_name, AUDIO_COLUMN_NAME)
    if text_column_name is not None and text_column_name != TEXT_COLUMN_NAME:
        ds = ds.rename_column(text_column_name, TEXT_COLUMN_NAME)
    # resample to the same sampling rate
    ds = ds.cast_column(AUDIO_COLUMN_NAME, Audio(sampling_rate=16_000))
    # normalise columns to ["audio", "sentence"]
    ds = ds.remove_columns([col for col in ds.features if col not in (AUDIO_COLUMN_NAME, TEXT_COLUMN_NAME)])
    return ds


raw_datasets = DatasetDict()

ds_train_mcv = load_dataset("mozilla-foundation/common_voice_11_0", "fr", split="train+validation", use_auth_token=True)
ds_train_mcv = normalize_dataset(ds_train_mcv)

ds_train_mls = load_dataset("facebook/multilingual_librispeech", "french", split="train+validation")
ds_train_mls = normalize_dataset(ds_train_mls, text_column_name="text")

raw_datasets["train"] = concatenate_datasets([ds_train_mcv, ds_train_mls])
# NB: shuffle concatenated dataset
raw_datasets["train"] = raw_datasets["train"].shuffle(seed=10)

raw_datasets["eval"] = load_dataset("mozilla-foundation/common_voice_11_0", "fr", split="test", use_auth_token=True)
raw_datasets["eval"] = normalize_dataset(raw_datasets["eval"])

If disk space is limited, you can instead load the datasets on the fly in streaming mode.
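In streaming mode a dataset is an iterator, so it cannot be shuffled globally the way the concatenated dataset above is. Streaming datasets instead rely on an approximate, buffer-based shuffle. The pure-Python generator below is a minimal sketch of that idea (the function name and buffer size are illustrative, not the Datasets API): it fills a fixed-size buffer and, for each incoming example, yields a random buffered example and puts the new one in its place.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int = 10) -> Iterator[T]:
    """Approximately shuffle a stream using a fixed-size buffer.

    Memory use is bounded by buffer_size; a larger buffer gives a better shuffle.
    """
    rng = random.Random(seed)
    buffer: List[T] = []
    for example in stream:
        if len(buffer) < buffer_size:
            buffer.append(example)  # fill the buffer first
            continue
        # emit a random buffered example and replace it with the incoming one
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = example
    # drain what is left once the stream is exhausted
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(10), buffer_size=4))
print(sorted(shuffled) == list(range(10)))  # True
```

With 🤗 Datasets itself, the equivalent is passing streaming=True to load_dataset and calling shuffle with a seed and a buffer_size on the resulting iterable dataset.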

Data Augmentation

We noticed that the audio samples in the Multilingual LibriSpeech dataset are quite clean. To keep the model robust in noisy settings and help it generalize across speakers, we perform data augmentation with the Audiomentations library. Several augmentations are applied to the audio samples: TimeStretch, Gain, PitchShift, and one of AddBackgroundNoise or AddGaussianNoise.
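Two quantities drive these augmentations: gains are expressed in decibels, which scale the amplitude by 10^(dB/20), and background noise is added at a target signal-to-noise ratio (SNR). The plain-Python sketch below shows only that arithmetic; it is a simplified illustration, not Audiomentations' implementation, which also handles loading noise files, looping them and sampling a random SNR.

```python
import math
from typing import List


def db_to_amplitude(gain_db: float) -> float:
    """Convert a gain in dB to a linear amplitude factor."""
    return 10 ** (gain_db / 20)


def rms(x: List[float]) -> float:
    """Root-mean-square level of a waveform."""
    return math.sqrt(sum(v * v for v in x) / len(x))


def mix_at_snr(signal: List[float], noise: List[float], snr_db: float) -> List[float]:
    """Add noise scaled so that 20*log10(rms(signal) / rms(scaled noise)) == snr_db.

    Assumes both lists have the same length and the noise is not silent.
    """
    target_noise_rms = rms(signal) / db_to_amplitude(snr_db)
    scale = target_noise_rms / rms(noise)
    return [s + scale * n for s, n in zip(signal, noise)]


speech = [math.sin(2 * math.pi * 440 * t / 16_000) for t in range(16_000)]
noise = [(-1) ** t * 0.3 for t in range(16_000)]  # toy deterministic "noise"
noisy = mix_at_snr(speech, noise, snr_db=5.0)
print(round(db_to_amplitude(6), 3))  # 1.995
```

So a +6 dB gain roughly doubles the amplitude, and the 1 to 5 dB SNR range used below corresponds to fairly loud background noise.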

The augmentation is defined as shown below:

from audiomentations import (
    AddBackgroundNoise,
    AddGaussianNoise,
    Compose,
    Gain,
    OneOf,
    PitchShift,
    PolarityInversion,
    TimeStretch,
)

# path to a local copy of the MUSAN noise corpus
musan_dir = "./musan"

# define augmentation
augmentation = Compose(
    [
        TimeStretch(min_rate=0.9, max_rate=1.1, p=0.2, leave_length_unchanged=False),
        Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.1),
        PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
        OneOf(
            [
                AddBackgroundNoise(sounds_path=musan_dir, min_snr_in_db=1.0, max_snr_in_db=5.0, noise_transform=PolarityInversion(), p=1.0),
                AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=1.0),
            ],
            p=0.2,
        ),
    ]
)


def augment_dataset(batch):
    # load and (possibly) resample audio data to 16kHz
    sample = batch[AUDIO_COLUMN_NAME]

    # apply augmentation
    augmented_waveform = augmentation(sample["array"], sample_rate=sample["sampling_rate"])
    batch[AUDIO_COLUMN_NAME]["array"] = augmented_waveform
    return batch
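Because each transform in Compose fires independently with its own probability p, and OneOf picks a single noise transform with p=0.2, a sizeable share of samples passes through unchanged. A quick calculation with the probabilities above, assuming the steps are independent, shows what fraction of the augmented copy actually differs from the original audio:

```python
from math import prod

# probabilities of each step in the Compose pipeline above
step_probabilities = {
    "TimeStretch": 0.2,
    "Gain": 0.1,
    "PitchShift": 0.2,
    "OneOf(noise)": 0.2,
}

p_unchanged = prod(1 - p for p in step_probabilities.values())
print(f"untouched: {p_unchanged:.4f}, augmented: {1 - p_unchanged:.4f}")
# untouched: 0.4608, augmented: 0.5392
```

Roughly 46% of the "augmented" copy is therefore identical to the original audio, which is worth keeping in mind when tuning the p values.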

Then we apply the augmentation to all the training examples using the map method:

# number of CPU processes used by map/filter (example value, adjust to your machine)
preprocessing_num_workers = 4

# augment training data
augmented_raw_training_dataset = raw_datasets["train"].map(
    augment_dataset, num_proc=preprocessing_num_workers, desc="augment train dataset"
)

# combine the original and augmented copies, then shuffle
raw_datasets["train"] = concatenate_datasets([raw_datasets["train"], augmented_raw_training_dataset])
raw_datasets["train"] = raw_datasets["train"].shuffle(seed=10)

Note: Data augmentation is applied only to the training set. We also keep the original, unaugmented training set and concatenate it with the augmented copy, which doubles the number of training examples.

Normalize Text

While diversity in audio quality can help train a model to be robust, diversity in transcript quality is not similarly beneficial.

Here, the difference lies in the transcription format: Common Voice transcriptions are cased and punctuated, whereas Multilingual LibriSpeech transcriptions are not. When using the two together, we should lowercase all transcriptions and remove all punctuation. This also simplifies the task, since the model no longer needs to distinguish uppercase from lowercase characters or predict punctuation marks.

However, if you need readable transcriptions with casing and punctuation, it is better to keep them and train only on cased and punctuated datasets such as Common Voice and Fleurs.

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()
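To see what this normalization does to French text, here is a simplified pure-Python approximation of the steps BasicTextNormalizer performs: lowercasing, dropping bracketed or parenthesized spans, turning marks, symbols and punctuation into spaces, and collapsing whitespace. It is an illustrative sketch (the function name is hypothetical, and the final strip() is our addition); use the transformers class itself in practice.

```python
import re
import unicodedata


def basic_normalize(text: str) -> str:
    """Simplified approximation of Whisper's BasicTextNormalizer (illustrative only)."""
    text = text.lower()
    text = re.sub(r"[<\[][^>\]]*[>\]]", "", text)  # drop <...> and [...] spans
    text = re.sub(r"\(([^)]+?)\)", "", text)        # drop (...) spans
    # replace marks, symbols and punctuation with spaces; NFKC keeps accented letters whole
    text = "".join(
        " " if unicodedata.category(ch)[0] in "MSP" else ch
        for ch in unicodedata.normalize("NFKC", text)
    )
    return re.sub(r"\s+", " ", text).strip()  # collapse whitespace


print(basic_normalize("Bonjour, l'ami ! (rires)"))  # bonjour l ami
print(basic_normalize("Où est-elle ?"))            # où est elle
```

Note that apostrophes and hyphens become word boundaries, so "l'ami" turns into two words; since predictions and references are normalized the same way, this does not bias the WER.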

Note: The model is always evaluated on normalized transcriptions, i.e. lowercased and without punctuation.

You can find the normalization used for English and other languages in Appendix C of the Whisper paper.

Preprocessing

As explained in the introduction, the Whisper model takes a log-Mel spectrogram as input and outputs BPE tokens, so we need to convert our data into that format. Two utility classes handle this: WhisperFeatureExtractor for the audio inputs, and WhisperTokenizer for the transcriptions and model predictions. The transformers library wraps both classes into a single WhisperProcessor class, which can be loaded as shown below:

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="french", task="transcribe")

We only need to specify the target language and the task, so that WhisperTokenizer prefixes the corresponding language and task tokens when encoding the transcriptions into label ids.
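Concretely, with language="french", task="transcribe" and timestamps disabled (the tokenizer's default), each label sequence starts with Whisper's special prefix tokens and ends with the end-of-text token. The sketch below only assembles token strings to show that layout; the helper is hypothetical, the word pieces stand in for real BPE tokens, and the actual tokenizer returns integer ids.

```python
from typing import List

# Special tokens of Whisper's decoder prompt (shown as strings, not ids)
SOT = "<|startoftranscript|>"
EOT = "<|endoftext|>"
NO_TIMESTAMPS = "<|notimestamps|>"


def label_layout(text_tokens: List[str], language: str = "fr", task: str = "transcribe") -> List[str]:
    """Assemble the label layout for one transcription (illustrative only)."""
    prefix = [SOT, f"<|{language}|>", f"<|{task}|>", NO_TIMESTAMPS]
    return prefix + text_tokens + [EOT]


print(label_layout(["bon", "jour"]))
# ['<|startoftranscript|>', '<|fr|>', '<|transcribe|>', '<|notimestamps|>', 'bon', 'jour', '<|endoftext|>']
```

Because the language and task tokens are part of the labels, the model learns to predict them itself, consistent with clearing forced_decoder_ids earlier.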

Let’s see what is in our data preparation function:

do_normalize_text = True


def prepare_dataset(batch):
    # load (and resample) the audio
    audio = batch[AUDIO_COLUMN_NAME]
    # compute log-Mel input features from input audio array
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # compute input length of audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    # process targets
    input_str = normalizer(batch[TEXT_COLUMN_NAME]).strip() if do_normalize_text else batch[TEXT_COLUMN_NAME]
    # encode target text to label ids
    batch["labels"] = processor.tokenizer(input_str).input_ids

    return batch

Then we apply the data preparation function to all the examples in the dataset using the map method:

vectorized_datasets = raw_datasets.map(
    prepare_dataset,
    num_proc=preprocessing_num_workers,
    remove_columns=next(iter(raw_datasets.values())).column_names,
    desc="preprocess dataset",
)

Remove Long Audio

In the previous step, WhisperFeatureExtractor truncated any audio longer than 30 s. The transcription, however, is not truncated, so the label may contain text that no longer appears in the audio, which severely destabilizes training. Here we define a function to filter out any audio longer than 30 s:

max_input_length = 30  # seconds
min_input_length = 0   # seconds


def is_audio_in_length_range(length):
    # keep only non-empty clips shorter than 30 s
    return min_input_length < length < max_input_length

We then apply our filter function to all the examples using the filter method:

vectorized_datasets = vectorized_datasets.filter(
    is_audio_in_length_range, num_proc=preprocessing_num_workers, input_columns=["input_length"]
)

Remove Long Text

The Whisper decoder uses learned positional embeddings with a maximum length of 448 tokens, so it cannot decode a transcription longer than 448 label ids. Here we define a filter function on the label ids:

# size of the decoder's learned positional embedding table (448 for Whisper)
max_label_length = model.config.max_target_positions


def is_labels_in_length_range(labels):
    return len(labels) < max_label_length

Then apply it to all the examples through the filter method:

vectorized_datasets = vectorized_datasets.filter(
    is_labels_in_length_range, num_proc=preprocessing_num_workers, input_columns=["labels"]
)

Training and Evaluation

Data Collator

The data collator takes a list of preprocessed samples and collates them into a batch of PyTorch tensors. All the audio features in a batch must have the same length, and the same applies to the labels.

The audio features are already padded or truncated to a fixed size by WhisperFeatureExtractor, so we only need to convert them to PyTorch tensors with the feature extractor's pad method.

The label ids, on the other hand, are not padded. We first pad them to the maximum length in the batch with the tokenizer's pad method, then replace the padding tokens with -100 so that they are ignored when computing the loss.
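In plain Python, the label handling amounts to the following sketch (the function names are hypothetical): pad every sequence to the longest one in the batch, mark padded positions with -100 so the cross-entropy loss skips them, and later map -100 back to the pad id before decoding, as compute_metrics does. The real collator pads with the tokenizer's pad id and then masks padded positions through the attention mask; the result is the same.

```python
from typing import List

IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss ignores this target value by default


def pad_labels(batch: List[List[int]]) -> List[List[int]]:
    """Right-pad label sequences to the batch maximum, using -100 for padding."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [IGNORE_INDEX] * (max_len - len(seq)) for seq in batch]


def unmask_labels(batch: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Undo the masking so sequences can be decoded (replace -100 by the pad id)."""
    return [[pad_token_id if tok == IGNORE_INDEX else tok for tok in seq] for seq in batch]


padded = pad_labels([[7, 8, 9], [5]])
print(padded)                                 # [[7, 8, 9], [5, -100, -100]]
print(unmask_labels(padded, pad_token_id=0))  # [[7, 8, 9], [5, 0, 0]]
```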

Let’s define our data collator as follows:

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        # convert the (already fixed-size) log-Mel features to tensors
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad label ids to the max length in the batch
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # the tokenizer prepends <|startoftranscript|> (the decoder start token);
        # cut it here, since the model prepends it again when shifting the labels right
        decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

Then we can initialize the data collator we’ve just defined:

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

Evaluation Metrics

We use the word error rate (WER) to evaluate model performance. The WER metric can be loaded with 🤗 Evaluate:

import evaluate

metric = evaluate.load("wer")
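WER is the word-level edit distance (substitutions S, deletions D and insertions I) divided by the number of reference words N, i.e. (S + D + I) / N, with errors and reference words summed over the whole evaluation set. A minimal pure-Python sketch of that computation, handy for sanity-checking the metric on a few sentences (the helper names are illustrative):

```python
from typing import List


def word_edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Levenshtein distance between two word sequences (S + D + I)."""
    prev = list(range(len(hyp) + 1))  # distances for an empty reference prefix
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[-1]


def corpus_wer(predictions: List[str], references: List[str]) -> float:
    """Total word errors divided by total reference words."""
    errors = sum(word_edit_distance(r.split(), p.split()) for p, r in zip(predictions, references))
    words = sum(len(r.split()) for r in references)
    return errors / words


print(corpus_wer(["le chat dort"], ["le chien dort"]))  # 0.3333333333333333
```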

We then need a function that takes the reference label ids and the model predictions and returns the WER. In this function, we replace -100 with the pad_token_id (undoing the step in the data collator that ignores padded tokens) so that the label ids can be decoded into strings correctly.

# evaluate with the 'normalized' WER
do_normalize_eval = True
tokenizer = processor.tokenizer


def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id so the labels can be decoded
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # decode ids to strings, dropping special tokens (language, task, etc.)
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        # perhaps already normalised
        label_str = [normalizer(label) for label in label_str]
        # filtering step to only evaluate the samples that correspond to non-zero references
        pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
        label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]

    # WER as a fraction (multiply by 100 for a percentage)
    wer = metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Training Configuration

In this step we define all the training parameters. For details on the other training arguments, refer to the Seq2SeqTrainingArguments docs.

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./outputs/whisper_medium_ft",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    warmup_steps=800,
    max_steps=8000,
    learning_rate=6.25e-6,
    weight_decay=0.01,
    gradient_checkpointing=True,
    fp16=True,
    predict_with_generate=True,
    generation_max_length=225,
    logging_steps=25,
    report_to=["tensorboard"],
    evaluation_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)
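With these settings, and assuming the Trainer's default linear scheduler (lr_scheduler_type is not set above), the learning rate ramps up linearly to 6.25e-6 over the first 800 steps and then decays linearly to zero at step 8,000. The small sketch below (a hypothetical helper mirroring that schedule) computes the rate at a few steps, plus the number of training examples each device sees over the run.

```python
def linear_warmup_decay(step: int, base_lr: float = 6.25e-6, warmup_steps: int = 800, max_steps: int = 8000) -> float:
    """Learning rate at a given step: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))


for step in (0, 400, 800, 4400, 8000):
    print(step, linear_warmup_decay(step))

# examples seen per device over the whole run (batch size x accumulation x steps)
print(64 * 1 * 8000)  # 512000
```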

Training

Finally, we initialize the Trainer with the model, datasets, data collator, training arguments and metric computation function.

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=vectorized_datasets["train"],
    eval_dataset=vectorized_datasets["eval"],
    tokenizer=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

Let’s launch the training!

trainer.train()

Don’t forget to save your model and processor after the training is finished:

model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)

Summary

In this blog, we walked step by step through fine-tuning Whisper for ASR on French data. On Common Voice, fine-tuning reduced the WER of the Whisper medium checkpoint from 16.00% to 9.03%, and that of the large checkpoint from 13.90% to 8.15%. You can find here a demo for French ASR using fine-tuned Whisper models.
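In relative terms, these results correspond to WER reductions of about 43.6% for the medium checkpoint and 41.4% for the large one, as this quick calculation shows:

```python
def relative_reduction(before: float, after: float) -> float:
    """Relative error reduction, as a percentage of the starting WER."""
    return 100 * (before - after) / before


print(f"medium: {relative_reduction(16.00, 9.03):.1f}%")  # medium: 43.6%
print(f"large: {relative_reduction(13.90, 8.15):.1f}%")   # large: 41.4%
```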

You can also fine-tune Whisper on other languages: you just need to collect and clean datasets in that language, then specify the corresponding language code when loading WhisperProcessor.

References

  1. Robust Speech Recognition via Large-Scale Weak Supervision
  2. Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers
  3. Whisper Fine-Tuning Event

Thanks to my dear colleagues Mohamed Bouaziz, Imed Laaridh, Lorraine Vanel, Yingzhi Wang and Alya Yacoubi for their reviews and comments!
