Converting the YAMNet audio detection model to TensorFlow Lite
For machine learning in the field of acoustics, most state-of-the-art models rely on recognising patterns in spectrograms using convolutional neural networks (CNNs). At Rainforest Connection, we use CNNs to detect the sound of chainsaws, vehicles and other potentially illegal activities in the rainforest and other protected areas. While inferences are currently performed on servers in the cloud, we are excited at the prospect of inference at the edge on sensors in the field for faster event detection. To achieve this, we need efficient audio detection models capable of running on IoT or mobile devices.
YAMNet is an acoustic detection model that classifies 521 different sounds, trained by Dan Ellis on the labelled audio from more than 2 million YouTube videos (AudioSet). The YAMNet model is included in the research section of the TensorFlow models repo. It’s an awesome project.
I’m going to walk through converting the YAMNet model to a TensorFlow Lite model that can be run on mobile devices (e.g. deployed to Android or iOS as a Firebase ML Custom Model). 🚀
1. Setup
Download the TensorFlow models repository from GitHub and make the yamnet folder your working directory. (I moved it out and deleted the rest of the repo as it’s not needed for the rest of this experiment.)
Get your favourite Python environment set up with TensorFlow 2.2. You can use virtual environments or Anaconda if you prefer, but I find it easier to use Docker:
docker run -it --rm -v ${PWD}:/app -w /app tensorflow/tensorflow:2.2.0
Next, you need to download the model weights for YAMNet:
curl -o yamnet.h5 https://storage.googleapis.com/audioset/yamnet.h5
Then you can run YAMNet’s test script: python yamnet_test.py. If you see Ran 4 tests… and OK, then you have everything you need!
Next, let’s see the outputs when running the model on an audio file. The inference script will resample the file if it does not match the 16kHz sample rate of the YAMNet model. Make sure you have soundfile and resampy installed. If you set up your environment with Docker then simply:
apt-get install -y libsndfile1
pip install resampy soundfile
Test the model on an audio file:
python inference.py examples/baby_5000ms.wav
Expected output:
examples/baby_5000ms.wav :
Crying, sobbing: 0.555
Baby cry, infant cry: 0.485
Babbling : 0.400
Child speech, kid speaking: 0.100
Inside, small room: 0.055
We have a working YAMNet model on TensorFlow; the remainder of this article is dedicated to getting the same results on TensorFlow Lite. 👩🏻‍💻
2. First try at TF Lite converter
As described in the guide Get Started with TensorFlow Lite, there are a few different ways to convert a model to the TFLite format:
- SavedModel directories
- Keras models
- Concrete functions
Looking at the yamnet.py file where the model is defined, we see that it is a Keras model. In my first try at converting, I simply passed the YAMNet model into the TFLiteConverter.from_keras_model function as follows.
When you run it with python convert_fail1.py, it errors:
ValueError: None is only supported in the 1st dimension. Tensor 'input_1' has invalid shape '[1, None]'.
Lesson #1: TFLite doesn’t support dynamically-sized (None) dimensions. So we won’t be able to pass an arbitrary-length audio file to a TFLite model, as the shape of the input needs to be defined during compilation/conversion. (That’s not completely true, as TFLite has added support for dynamically-sized inputs in the new experimental converter, but I was unable to get it working on a mobile device.) Hence we need to pick the length of the audio clip. Looking at the params.py file, the default sample rate is 16,000 and the window length is 0.96 seconds. So originally I incorrectly guessed the input shape to be [1, 15360]. To explain why it’s incorrect, we can look at the YAMNet documentation and consider the spectrogram conversion…
“A spectrogram is computed using magnitudes of the Short-Time Fourier Transform with a window size of 25 ms, a window hop of 10 ms, and a periodic Hann window.”
I understood this as: the raw audio waveform is cut into sliding 0.025 s windows, and we need enough windows to cover 0.96 s. It turns out the minimum is 0.975 seconds, or 15,600 samples. The arithmetic works out once you account for the hop: the 0.96 s patch corresponds to 96 STFT frames spaced 0.010 s apart, and producing 96 frames consumes 95 hops of 0.010 s plus one full 0.025 s window of waveform, i.e. 0.95 + 0.025 = 0.975 s.
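As a quick sanity check, here is the arithmetic in plain Python, using the default parameter values from params.py:

```python
# Back-of-the-envelope check of YAMNet's minimum input length.
sample_rate = 16000                       # Hz (params.py default)
window = int(round(0.025 * sample_rate))  # 400 samples per STFT frame
hop = int(round(0.010 * sample_rate))     # 160 samples between frames
patch_frames = 96                         # 0.96 s patch / 0.010 s hop

# 96 frames need 95 hops plus one full window of waveform:
min_samples = (patch_frames - 1) * hop + window
print(min_samples, min_samples / sample_rate)  # 15600 0.975
```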
To specify this, we copy the yamnet_frames_model function from yamnet.py and modify it so that the input shape is fixed at [1, 15600].
Now when we run python convert_fail2.py we get:
error: failed while converting: 'main': Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): RFFT,ComplexAbs.
Lesson #2: The RFFT and ComplexAbs operators are not supported in tflite. These are not actually used in the CNN, but are part of the spectrogram generation.
Next up, I spent a lot of time trying to get things working with SELECT_TF_OPS by adding this line:
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
The conversion completed successfully and I thought we were done, but I wasn’t able to get the converted TFLite file to load on Android or iOS. (Nor does it load in Python, because TensorFlow doesn’t support inference with SELECT_TF_OPS, and hence local testing was impossible — I suggest avoiding the select ops unless you know what you are doing.) Several hours of frustrated debugging followed!
3. Generating spectrograms the TFLite way
After several cups of tea, I discovered the solution is to avoid using the unsupported ops. For this I have to thank a long conversation on a GitHub issue which showed the way forward: writing your own TFLite-compatible implementation of tf.signal.stft. Over at the Magenta project (Make Music and Art using Machine Learning), they have done exactly that.
In the YAMNet source code, on line 35 of features.py (inside the function waveform_to_log_mel_spectrogram) you will find the call to tf.signal.stft that we need to replace.
magnitude_spectrogram = tf.abs(tf.signal.stft(
signals=waveform,
frame_length=window_length_samples,
frame_step=hop_length_samples,
fft_length=fft_length))
In a moment we will replace it with a modified version of the _stft_magnitude_tflite function from the Magenta project. Let’s copy only the functions we need into a new file, features_tflite.py. It’s quite long, so please download it here and place the file in your working directory. Then, in our next attempt, we use the waveform_to_log_mel_spectrogram function from our new file.
python convert.py
Voilà! It converts successfully and the yamnet.tflite file is saved. 😍🎉🕺🏻
Lesson #3: Make your own waveform-to-spectrogram function based on a STFT TFLite-compatible function. Thank you Magenta folks! (Funny aside: after I wrote this article I found out that the STFT functions were originally written by Dan Ellis himself!)
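To make Lesson #3 concrete, here is the core of the trick in plain NumPy (my own illustrative sketch, not the Magenta code itself): frame the waveform, apply a periodic Hann window, and multiply by precomputed real and imaginary DFT basis matrices, so no FFT or complex ops are needed.

```python
import numpy as np

def stft_magnitude_matmul(signal, frame_length=400, frame_step=160,
                          fft_length=512):
    """Magnitude spectrogram without tf.signal.stft.

    Mirrors the idea behind Magenta's _stft_magnitude_tflite: instead of
    RFFT/ComplexAbs (unsupported in TFLite), frame the waveform, apply a
    periodic Hann window, and multiply by precomputed DFT basis matrices.
    Framing, matrix multiplies and square roots are all TFLite-friendly.
    """
    # Frame the signal into overlapping windows: [num_frames, frame_length].
    num_frames = 1 + (len(signal) - frame_length) // frame_step
    frames = np.stack([signal[i * frame_step:i * frame_step + frame_length]
                       for i in range(num_frames)])
    # Periodic Hann window, as the YAMNet docs specify.
    window = 0.5 - 0.5 * np.cos(
        2 * np.pi * np.arange(frame_length) / frame_length)
    frames = frames * window
    # Real/imaginary DFT bases for the fft_length // 2 + 1 bins.
    bins = fft_length // 2 + 1
    t = np.arange(frame_length)
    angles = 2 * np.pi * np.arange(bins)[:, None] * t / fft_length
    real = frames @ np.cos(angles).T    # [num_frames, bins]
    imag = frames @ -np.sin(angles).T
    return np.sqrt(real ** 2 + imag ** 2)
```

The actual Magenta implementation does the same computation with TensorFlow ops so that it stays inside the graph; the NumPy version above is just for intuition.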
Inspect yamnet.tflite in Netron (an awesome tool for visualising ML models) and check that the inputs and outputs look reasonable.
- Input: [1,15600] (raw audio samples of 0.975 seconds at 16kHz sample rate)
- Output 0: [1,521] (predictions for the 521 classes)
- Output 1: [96,64] (log mel spectrogram — you might not need this)
4. Testing
Recall the inference.py script we used to test the model in part 1; let’s now make a version that loads the model from the saved TFLite file.
If you test it on the same audio file we used in part 1:
python inference_tflite.py examples/baby_5000ms.wav
Then you will get an error:
Error when checking input: expected input_1 to have shape (15600,) but got array with shape (80000,)
Remember that earlier we set a fixed shape for the input of the TFLite model. We will need an audio file that is exactly 0.975 seconds in duration. Try again with a correct length audio file:
baby_975ms.wav :
Crying, sobbing: 0.930
Baby cry, infant cry: 0.862
Whimper : 0.157
Babbling : 0.112
Inside, small room: 0.030
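If you don’t have a clip of exactly the right length, you can also trim or zero-pad the waveform to 15,600 samples before handing it to the interpreter. A small helper (the function name is my own):

```python
import numpy as np

def fit_to_input(waveform, target_len=15600):
    """Trim or zero-pad a 1-D waveform to exactly target_len samples."""
    if len(waveform) >= target_len:
        return waveform[:target_len]
    return np.pad(waveform, (0, target_len - len(waveform)))
```

Bear in mind that trimming or silence-padding changes the audio content, so scores on adjusted clips won’t match a full-length server-side inference exactly.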
Run the same baby_975ms.wav file through the original model with inference.py and you should get exactly the same results, so we can be confident that the conversion is working correctly. Furthermore, we now have a test wav file and its expected outputs that we can use to verify the model is running correctly on a mobile device.
What next?
You now have a TFLite version of YAMNet that you can add to an Android or iOS application. I found the easiest way to run it was using Firebase ML Custom Models. The following tutorials are a good place to start.
- Using TensorFlow Lite and ML Kit to build custom machine learning models for Android
- Firebase ML Kit Custom Models for iOS developers — Part 1: Understanding TensorFlow Lite
- Firebase ML Kit Custom Models for iOS developers — Part 2: Implementing Tic Tac Toe
The complete project is available on GitHub. This has been a fun side-project for learning more about ML on mobile — I welcome your comments and suggestions. 😊