Speaker Diarisation Using Transfer Learning

Speaker diarisation is the process of labelling each part of a conversation with the speaker who produced it. It is used in smart home devices such as the Google Home to detect who is speaking (as opposed to what is being said, which is a different problem). Speaker diarisation is still an open problem, but rapid progress is being made using modern deep learning techniques. At UBC Launch Pad this year, we’ve been building a Python library for speaker diarisation, which we’ve called Minutes.

The Problem

There are two major reasons why speaker diarisation is a difficult problem. First, the system must be able to handle large variation in audio quality, since you are unlikely to control the environmental factors in the data ingestion process. Second, a model may need to predict on classes (people) for which only a small sample of training data is available, because it may be inconvenient to collect a large training sample for each new class. Whereas a typical object recognition model may have seen the same object thousands of times during training, your Google Home must learn to distinguish your voice from others with only a small sample of your voice. Transfer learning is a technique that helps with this latter problem.

Consider a company meeting as an example. Each of the participants has pre-recorded a small amount of speech using a client. Later, an entire meeting is recorded and transmitted to a server. Server-side, this conversation is split up by speaker and transcribed to text, using a model trained on the small speech samples. The problem for the server is to quickly and economically produce this new model and predict on the new set of classes, all without relearning everything there is to know about voices.

Our Dataset

Due to the recent advances in audio transcription, several large datasets have become available. The LibriSpeech ASR Corpus is a large corpus of read English speech. For this project, we simply broke audio files into 1-second (48,000-sample) intervals and labelled each interval with a speaker ID to generate labelled training data. This observation width is a hyperparameter that we call the samples per observation. Below is an image of 10 of these observations concatenated together.
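
As a rough illustration of the splitting step, the sketch below cuts a waveform into fixed-width observations and pairs each one with its speaker ID. The function name `split_into_observations` and the choice to drop the trailing partial window are assumptions made for this example, not necessarily what Minutes does.

```python
import numpy as np


def split_into_observations(waveform, speaker_id, samples_per_observation=48000):
    """Cut a 1-D waveform into equal-width observations labelled by speaker.

    Trailing samples that do not fill a whole observation are dropped
    (an assumption for this sketch).
    """
    waveform = np.asarray(waveform)
    n_obs = len(waveform) // samples_per_observation
    trimmed = waveform[:n_obs * samples_per_observation]
    observations = trimmed.reshape(n_obs, samples_per_observation)
    labels = np.full(n_obs, speaker_id)
    return observations, labels


# Synthetic stand-in for a decoded audio file: 2.5 observations' worth of samples.
audio = np.random.default_rng(0).standard_normal(120000)
X, y = split_into_observations(audio, speaker_id=7)
print(X.shape, y.shape)
```

Changing `samples_per_observation` is all it takes to experiment with a different observation width.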

After splitting the corpus, we converted each observation into a spectrogram. Since image recognition is a well-developed area of machine learning, using spectrograms let us draw on a large body of existing research and techniques built around convolutional neural networks (CNNs).
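
One common way to get a spectrogram is a short-time Fourier transform. The NumPy sketch below is only an illustration: the function name, the Hann window, the log scaling and the hop length are assumptions, and `n_fft=2048` is picked simply because it yields 1025 frequency bins, the height of the input the network in the next section expects.

```python
import numpy as np


def log_spectrogram(observation, n_fft=2048, hop=512):
    """Return a log-magnitude spectrogram of shape (n_fft // 2 + 1, n_frames)."""
    observation = np.asarray(observation, dtype=float)
    window = np.hanning(n_fft)
    n_frames = 1 + (len(observation) - n_fft) // hop
    # Slice the waveform into overlapping, windowed frames.
    frames = np.stack([
        observation[i * hop:i * hop + n_fft] * window
        for i in range(n_frames)
    ])
    magnitude = np.abs(np.fft.rfft(frames, axis=1))  # (n_frames, n_fft // 2 + 1)
    return np.log1p(magnitude).T


# A pure 1 kHz tone sampled at 48 kHz; its energy should concentrate near
# bin 1000 / (48000 / 2048), i.e. around bin 43.
sample_rate = 48000
t = np.arange(sample_rate) / sample_rate
tone = np.sin(2 * np.pi * 1000 * t)
spec = log_spectrogram(tone)
print(spec.shape)
```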

Transfer Learning

It’s easy to generate a neural network that predicts quite well on a small set of classes using this dataset. Consider this small implementation below using the Keras library (it achieves 97% validation accuracy on 5 classes, after training for 15–25 epochs).

from keras.models import Sequential
from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D

# X_train, y_train, X_val and y_val are assumed to be loaded already:
# spectrograms of shape (1025, 32, 3) and one-hot speaker labels.
model = Sequential()
# Add several convolutional layers with dropout.
model.add(Conv2D(32, (32, 4), strides=(16, 4),
                 input_shape=(1025, 32, 3), activation='relu'))
model.add(Dropout(0.5))
model.add(Conv2D(64, (8, 5), strides=(4, 2),
                 activation='relu'))
model.add(Dropout(0.2))
model.add(Conv2D(128, (1, 1), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
# Flatten and add dense layers.
model.add(Flatten())
model.add(Dense(128, activation='relu'))
# One output unit per speaker (labels are one-hot encoded).
model.add(Dense(y_train[0].size, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])
model.fit(X_train, y_train, validation_data=(X_val, y_val),
          epochs=50, batch_size=400, verbose=2, shuffle='batch')

But of course, this model is not able to predict on classes that it has not yet seen.

Now we’ll introduce transfer learning, a technique for reusing the base model generated above on brand new datasets, which speeds up training of subsequent models. Depending on the hardware used and the size of the training dataset, a great deal of time may go into training anything more complex than the model above. Our goal is to maximize reuse of the base model. Keras makes this easy: simply freeze the layers you do not want to retrain, resize the final layer, and pass in a small amount of new data.

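from keras.layers import Dense
from keras.models import load_model

model = load_model('myCNN.h5')
# Layer freezing: keep every layer except the old output layer fixed.
for k in model.layers[:-1]:
    k.trainable = False
# load_data is a project helper (not defined here), not part of Keras.
(X_train, y_train), (X_val, y_val) = load_data(
    'features.npy', 'labels.npy')
# Pop and resize final layer(s). Sequential.pop() removes the last layer;
# calling pop() on the model.layers list does not reliably do so.
model.pop()
# An explicit name avoids clashing with layer names in the loaded model.
model.add(Dense(y_train[0].size, activation='softmax',
                name='transfer_output'))
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])
model.fit(X_train, y_train, validation_data=(X_val, y_val),
          epochs=15, batch_size=128, verbose=2)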

A nice proxy for the amount of the base model we are able to reuse is something we’re calling base model utilization. You can compute this value by printing the Keras model summary and counting the parameters.

utilization = non-trainable params / total params

Increasing the utilization has a direct impact on the cost of creating the subsequent transfer model, because fewer parameters must be optimized during training.
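
To make the ratio concrete, here is a back-of-the-envelope sketch that counts parameters for the small network shown earlier and reports the share that stays frozen. It reads utilization as the reused (frozen) fraction of the model, assumes Keras's default 'valid' padding and a new 5-class output layer, and uses helper names invented for this example.

```python
def conv_out(size, kernel, stride=1):
    """Output length of a 'valid' convolution or pooling along one axis."""
    return (size - kernel) // stride + 1


def conv_params(kh, kw, c_in, c_out):
    return kh * kw * c_in * c_out + c_out  # weights + biases


def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases


# Spatial size of the feature map after each layer, for a (1025, 32, 3) input.
h, w = 1025, 32
h, w = conv_out(h, 32, 16), conv_out(w, 4, 4)  # Conv2D(32, (32, 4), strides=(16, 4))
h, w = conv_out(h, 8, 4), conv_out(w, 5, 2)    # Conv2D(64, (8, 5), strides=(4, 2))
h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)    # MaxPooling2D((2, 2)); the 1x1 conv keeps h, w
flat = h * w * 128

n_classes = 5  # assumed size of the new label set
conv_total = (conv_params(32, 4, 3, 32) + conv_params(8, 5, 32, 64)
              + conv_params(1, 1, 64, 128))
hidden = dense_params(flat, 128)
output = dense_params(128, n_classes)
total = conv_total + hidden + output


def utilization(trainable):
    """Fraction of all parameters that are reused (frozen) rather than retrained."""
    return (total - trainable) / total


print("retrain output layer only: %.3f" % utilization(output))
print("retrain both dense layers: %.3f" % utilization(hidden + output))
```

Under these assumptions, retraining only the output layer leaves over 99% of the parameters frozen, while retraining both dense layers drops that to roughly 47%.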

Notice that the larger the utilization, the faster we’re able to achieve 85%+ accuracy (15 epochs with 4% utilization, 10 epochs with 21% utilization, and 5 epochs with 47% utilization). Image recognition CNNs often end in dense layers with many parameters; popping and retraining these large layers seriously reduces the utilization. Some simple tricks we found to reduce this effect were:

  • Increasing pooling early on, reducing the size of the later parameter matrices.
  • Adding convolution layers that reduce dimensionality downstream.
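
The first trick can be checked with simple arithmetic. The sketch below compares the size of the first dense layer when it is fed a feature map with and without a 2x2 pooling step in front of it. The feature-map shape (14, 2, 128) is what the example network above produces just before pooling, assuming Keras's default 'valid' padding, and the helper name is made up for this illustration.

```python
def dense_params_after_pool(h, w, channels, units, pool=(1, 1)):
    """Parameters in a dense layer fed by a flattened, optionally pooled feature map."""
    pooled_h, pooled_w = h // pool[0], w // pool[1]
    flat = pooled_h * pooled_w * channels
    return flat * units + units  # weights + biases


no_pool = dense_params_after_pool(14, 2, 128, units=128)
with_pool = dense_params_after_pool(14, 2, 128, units=128, pool=(2, 2))
print(no_pool, with_pool)
```

Pooling by 2x2 shrinks this layer roughly fourfold, so if the dense layers have to be retrained, a smaller share of the model is left unfrozen.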

It’s important to recognize that with transfer learning we’re predicting on a brand new set of classes while seriously reducing the time it takes to train such a model.

YouTube Accuracy

We also collected conversations from YouTube and parsed the time-stamped (and author-stamped) transcripts to generate new supervised training datasets from “the wild”. Using the transfer learning approach described above, we were only able to achieve 60% accuracy on conversations with 3–4 participants, such as this SciShow talk. We have two hypotheses for the low accuracy: either the LibriSpeech corpus provided samples that were “too clean” in comparison to data in the wild, or the YouTube timestamps were imperfectly labelled.
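
For illustration, labelling fixed-width windows from a time-stamped transcript might look like the sketch below. The transcript format (a list of `(start_seconds, end_seconds, speaker)` tuples), the function name, and the rule of keeping only windows that fall entirely inside one speaker's turn are all assumptions for this example; it is not the parser we used.

```python
def label_windows(segments, total_seconds, window_seconds=1.0):
    """Assign a speaker to each fixed-width window fully covered by one segment.

    Returns a list of (window_start, speaker) pairs. Windows that straddle a
    speaker change or a gap are skipped to avoid ambiguous labels.
    """
    labelled = []
    n_windows = int(total_seconds // window_seconds)
    for i in range(n_windows):
        start = i * window_seconds
        end = start + window_seconds
        for seg_start, seg_end, speaker in segments:
            if seg_start <= start and end <= seg_end:
                labelled.append((start, speaker))
                break
    return labelled


# Hypothetical transcript: Alice speaks for 2.5 s, then Bob after a short pause.
transcript = [(0.0, 2.5, "alice"), (3.0, 6.0, "bob")]
print(label_windows(transcript, total_seconds=6.0))
```

Skipping straddling windows trades some data for cleaner labels, which matters if imprecise timestamps are one cause of the low accuracy.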

For the time being, transfer learning may be ideal in binary classification cases or cases where the data is relatively clean, such as a phone call with two participants (instead of a conference call or a noisy meeting room discussion). Also, simple data augmentation techniques may make the model more robust to varying recording quality.
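
As a minimal sketch of what such augmentation could look like, the functions below add white noise at a chosen signal-to-noise ratio and apply a random gain to a waveform before it is turned into a spectrogram. The function names, the noise model and the gain range are assumptions chosen for illustration; we have not evaluated them on this task.

```python
import numpy as np


def add_noise(observation, snr_db, rng):
    """Mix white Gaussian noise into a waveform at the requested signal-to-noise ratio."""
    observation = np.asarray(observation, dtype=float)
    signal_power = np.mean(observation ** 2)
    noise_power = signal_power / (10 ** (snr_db / 10))
    noise = rng.standard_normal(observation.shape) * np.sqrt(noise_power)
    return observation + noise


def random_gain(observation, rng, low_db=-6.0, high_db=6.0):
    """Scale a waveform by a random gain drawn uniformly in decibels."""
    gain_db = rng.uniform(low_db, high_db)
    return np.asarray(observation, dtype=float) * 10 ** (gain_db / 20)


rng = np.random.default_rng(0)
# Synthetic stand-in for one clean 1-second observation at 48 kHz.
clean = np.sin(2 * np.pi * 220 * np.arange(48000) / 48000)
augmented = random_gain(add_noise(clean, snr_db=20, rng=rng), rng)
print(augmented.shape)
```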

What’s next?

Minutes is intended as a library for speaker diarisation. Our hope is that transfer learning can improve the speed with which new speaker diarisation models are learned, by reusing the base model as much as possible. Keep an eye out for announcements regarding the Minutes Python library and our dataset on our Facebook page. We’d like to thank the authors here for nudging us in the direction of transfer learning on this project.