Wavelet Transforms in Python with Google JAX
A simple data compression example
Wavelet transforms are one of the key tools for signal analysis. They are used extensively in science and engineering. Specific applications include data compression, gait analysis, signal and image denoising, and digital communications. This article focuses on a simple lossy data compression application that uses the DWT (Discrete Wavelet Transform) support provided in the CR-Sparse library.
For a good introduction to wavelet transforms, please see:
- Wikipedia, Wavelet transform
- THE WAVELET TUTORIAL, PART I
- C. Valens, A really friendly guide to wavelets
- Vidakovic and Mueller, Wavelets for kids, A tutorial introduction
Wavelets in Python
There are several packages in Python which have support for wavelet transforms. Let me list a few:
- PyWavelets is one of the most comprehensive implementations of wavelet support in Python, covering both discrete and continuous wavelets.
- pytorch-wavelets provides support for 2D discrete wavelet transforms and 2D dual-tree complex wavelet transforms.
- scipy provides some basic support for the continuous wavelet transform.
PyWavelets is probably the most mature library available. Its coverage and performance are great. However, major parts of the library are written in C, so the implementation cannot simply be retargeted to GPU hardware. That is one of the reasons people have come up with newer implementations, e.g., on top of PyTorch, which provides the necessary GPU support.
Google JAX
The CR-Sparse library now includes support for computing discrete and continuous wavelet transforms using the Google JAX library.
JAX provides high-performance numerical computing by taking advantage of XLA, a domain-specific compiler for linear algebra. JAX offers a NumPy-like API and a JIT compiler, so code written with JAX can easily be just-in-time compiled (using XLA) to machine code for a specific hardware architecture.
Thus, you can write pure Python code on top of the JAX API and build sophisticated numerical algorithms that get compiled efficiently for different hardware architectures, including GPUs.
JAX has tools like pmap, which makes parallel evaluation of code across multiple devices straightforward. For large datasets, JAX can often outperform NumPy even on a CPU.
However, taking advantage of JAX does require some work. We have to write our numerical algorithms in a way that allows them to be JIT-compiled. In particular, the code must follow functional programming principles. For example, JAX arrays are immutable (while NumPy arrays are not), so any change to an array actually creates a new array at the Python level. The XLA compiler is smart enough to reuse memory where possible. In a way, rewriting numerical algorithms in a functional manner is quite a rewarding experience: it lets you focus on the essential mathematics, avoids unnecessary global state manipulation, and keeps the implementation clean and simple.
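To make the functional style concrete, here is a minimal sketch using only core JAX (`jax.jit` and `jax.numpy`). The function `scale_and_shift` is a hypothetical example, not part of CR-Sparse. It shows that in-place assignment on a JAX array raises a `TypeError`, and that a pure function can be JIT-compiled and returns a new array while leaving its input untouched.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)

# In-place item assignment is rejected: JAX arrays are immutable.
immutable = False
try:
    x[0] = 10.0
except TypeError:
    immutable = True
print("immutable:", immutable)

# A pure function (output depends only on inputs, no side effects) can be
# traced once and compiled by XLA with jax.jit.
# scale_and_shift is a hypothetical illustration, not a CR-Sparse function.
@jax.jit
def scale_and_shift(v, a, b):
    return a * v + b   # builds a new array; v is not modified

y = scale_and_shift(x, 2.0, 1.0)
print(y)  # [1. 3. 5. 7. 9.]
print(x)  # [0. 1. 2. 3. 4.] (unchanged)
```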
Wavelet Support in CR-Sparse
CR-Sparse focuses on functional models and algorithms for sparse signal processing, i.e., exploiting the sparsity of signal representations in signal processing problems. Wavelet transforms are a key tool for constructing sparse representations of common signals. Thus, they form an important part of the CR-Sparse library. The implementation is pure Python, written using the functional programming principles required by JAX, and it is just-in-time compiled for CPU/GPU/TPU architectures seamlessly, giving excellent performance. The API of the wavelet module is inspired by, and similar to, that of PyWavelets. In addition, the wavelet functionality has been wrapped as 1D and 2D linear operators similar to PyLops.
Please refer to my previous article Implementing Linear Operators in Python with Google JAX for more information about linear operator design.
Simple data compression with wavelets
Decomposition and Reconstruction
Wavelet transforms are invertible.
- We can decompose a signal using a wavelet to obtain the wavelet coefficients with an algorithm called the discrete wavelet transform (DWT). The signal is decomposed into two sets of coefficients: the approximation coefficients (the low-pass component of the signal) and the detail coefficients (the high-frequency component).
- We can reconstruct the signal back from the wavelet coefficients using an algorithm called inverse discrete wavelet transform (IDWT).
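A minimal sketch of one DWT/IDWT level, assuming the orthonormal Haar wavelet and an even-length signal. The helpers `haar_dwt` and `haar_idwt` are hypothetical illustrations written directly in `jax.numpy`, not the CR-Sparse API, and the sign convention of the detail coefficients may differ from other libraries. The approximation coefficients are scaled pairwise sums (low pass), the detail coefficients are scaled pairwise differences (high pass), and the inverse recovers the signal up to floating-point rounding.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)

def haar_dwt(x):
    """Hypothetical helper: one level of the orthonormal Haar DWT (len(x) even)."""
    even, odd = x[0::2], x[1::2]
    a = (even + odd) / SQRT2   # approximation: scaled pairwise sums (low pass)
    d = (even - odd) / SQRT2   # detail: scaled pairwise differences (high pass)
    return a, d

def haar_idwt(a, d):
    """Hypothetical helper: inverse of haar_dwt, rebuild and interleave samples."""
    even = (a + d) / SQRT2
    odd = (a - d) / SQRT2
    return jnp.stack([even, odd], axis=-1).reshape(-1)

x = jnp.array([4.0, 6.0, 10.0, 12.0, 8.0, 6.0, 5.0, 5.0])
a, d = haar_dwt(x)          # 4 approximation + 4 detail coefficients
x_rec = haar_idwt(a, d)     # perfect reconstruction (up to rounding)
print(jnp.max(jnp.abs(x - x_rec)))
```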
Multi-level decomposition
- Usually, the wavelet decomposition is done multiple times.
- We start with the signal, compute approximation and detail coefficients, then apply the DWT again on the approximation coefficients.
- We repeat this process multiple times to achieve a multi-level decomposition of the signal.
The following example shows a 4-level decomposition:
X => [A1 D1] => [A2 D2 D1] => [A3 D3 D2 D1] => [A4 D4 D3 D2 D1]
A1, D1 = DWT(X)
A2, D2 = DWT(A1)
A3, D3 = DWT(A2)
A4, D4 = DWT(A3)
- Signal X is split into approximation and detail coefficients A1 and D1 by applying the DWT. If the signal has N samples, the decomposition will have N/2 approximation coefficients and N/2 detail coefficients (strictly speaking, this holds when the DWT uses periodization extension; other extension modes lead to slightly more coefficients).
- The approximation coefficients A1 are split again into approximation and detail coefficients A2 and D2 by applying the DWT.
- We repeat this process 2 more times.
- The 4 level decomposition of X is obtained by concatenating the coefficients in A4, D4, D3, D2, D1.
- If the signal X has N samples, then the wavelet decomposition will also consist of N coefficients (again assuming periodization extension and N divisible by 2^4 for four levels).
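The multi-level decomposition can be sketched with a hypothetical Haar helper (plain `jax.numpy`, not the CR-Sparse implementation) that repeatedly splits the current approximation and returns the coefficients in the order [A4, D4, D3, D2, D1]. It assumes the signal length is divisible by 2^level, so no extension mode is needed, and it illustrates that N samples yield N coefficients.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)

def haar_dwt(x):
    # Hypothetical helper: one level of the orthonormal Haar DWT (len(x) even).
    return (x[0::2] + x[1::2]) / SQRT2, (x[0::2] - x[1::2]) / SQRT2

def haar_wavedec(x, level):
    """Hypothetical multi-level decomposition returning [A_L, D_L, ..., D_1]."""
    details = []
    a = x
    for _ in range(level):
        a, d = haar_dwt(a)   # split the current approximation again
        details.append(d)    # collected in the order D1, D2, ..., D_L
    return [a] + details[::-1]

N = 64
x = jnp.sin(jnp.linspace(0.0, 4.0 * jnp.pi, N))
coeffs = haar_wavedec(x, level=4)   # [A4, D4, D3, D2, D1]
print([c.size for c in coeffs])     # [4, 4, 8, 16, 32]
alpha = jnp.concatenate(coeffs)     # flat vector of wavelet coefficients
print(alpha.size)                   # 64, same as N
```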
The reconstruction of the signal proceeds as follows:
A3 = IDWT(A4, D4)
A2 = IDWT(A3, D3)
A1 = IDWT(A2, D2)
X = IDWT(A1, D1)
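The reconstruction loop can be sketched in the same way, again assuming hypothetical orthonormal Haar helpers rather than the CR-Sparse implementation: starting from A4, each step combines the current approximation with the next set of detail coefficients until X is recovered.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)

def haar_dwt(x):
    # Hypothetical helper: one orthonormal Haar DWT level.
    return (x[0::2] + x[1::2]) / SQRT2, (x[0::2] - x[1::2]) / SQRT2

def haar_idwt(a, d):
    # Hypothetical helper: inverse of haar_dwt (interleave even/odd samples).
    even, odd = (a + d) / SQRT2, (a - d) / SQRT2
    return jnp.stack([even, odd], axis=-1).reshape(-1)

def haar_waverec(coeffs):
    """Invert [A_L, D_L, ..., D_1] via A_{k-1} = IDWT(A_k, D_k), down to X."""
    a = coeffs[0]
    for d in coeffs[1:]:      # D_L first, D_1 last
        a = haar_idwt(a, d)
    return a

# Forward steps: A1, D1 = DWT(X); A2, D2 = DWT(A1); ...
x = jnp.cos(jnp.linspace(0.0, 6.0, 32))
a1, d1 = haar_dwt(x)
a2, d2 = haar_dwt(a1)
a3, d3 = haar_dwt(a2)
a4, d4 = haar_dwt(a3)

x_rec = haar_waverec([a4, d4, d3, d2, d1])
print(jnp.allclose(x, x_rec, atol=1e-5))  # True
```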
Simple data compression
- For many signals, most of the energy is concentrated in the approximation coefficients, so the signal can still be reconstructed faithfully, with a high signal-to-noise ratio (SNR), if we drop some of the detail coefficients.
- If we drop D1 coefficients, we achieve 50% compression. If we drop both D1 and D2 coefficients, we can achieve 75% compression.
An important consideration is to measure the signal-to-noise ratio after reconstructing the signal from the remaining coefficients. If the compression technique is good, the SNR will be high.
This is a very simplistic treatment of the compression problem but it will suffice for the purposes of this article.
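A small experiment makes the compression idea concrete. This sketch assumes hypothetical orthonormal Haar helpers written in plain `jax.numpy`, a smooth test sinusoid, and the common definition SNR = 20 log10(||x|| / ||x - x_rec||) in dB; CR-Sparse's `metrics.signal_noise_ratio` may follow its own conventions. Dropped coefficients are simply replaced with zeros before reconstruction.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)

def haar_dwt(x):
    # Hypothetical helper: one orthonormal Haar DWT level.
    return (x[0::2] + x[1::2]) / SQRT2, (x[0::2] - x[1::2]) / SQRT2

def haar_idwt(a, d):
    # Hypothetical helper: inverse of haar_dwt.
    even, odd = (a + d) / SQRT2, (a - d) / SQRT2
    return jnp.stack([even, odd], axis=-1).reshape(-1)

def snr_db(x, x_rec):
    """Assumed SNR definition in dB: 20 * log10(||x|| / ||x - x_rec||)."""
    return 20.0 * jnp.log10(jnp.linalg.norm(x) / jnp.linalg.norm(x - x_rec))

N = 1024
t = jnp.arange(N) / N
x = jnp.sin(2 * jnp.pi * 5 * t)

# Two-level decomposition: X -> [A1 D1] -> [A2 D2 D1]
a1, d1 = haar_dwt(x)
a2, d2 = haar_dwt(a1)

# Drop D1 (keep 50% of the coefficients): D1 is treated as all zeros.
x_50 = haar_idwt(a1, jnp.zeros_like(d1))
# Drop D1 and D2 (keep 25%): rebuild A1 without D2, then X without D1.
x_25 = haar_idwt(haar_idwt(a2, jnp.zeros_like(d2)), jnp.zeros_like(d1))

snr_50 = float(snr_db(x, x_50))
snr_25 = float(snr_db(x, x_25))
print(f"keep 50%: {snr_50:.2f} dB, keep 25%: {snr_25:.2f} dB")
```

For a slowly varying signal like this one, dropping D1 alone keeps the SNR high, while dropping D2 as well lowers it, because more signal energy is discarded.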
We now show the sample code for using 1D and 2D wavelet transform for signal and image compression and reconstruction. The detailed example code is available in the examples gallery (in the CR-Sparse documentation) here.
This example runs against the latest version in the repository which can be installed with:
python -m pip install git+https://github.com/carnotresearch/cr-sparse.git
1D signal decomposition, compression, reconstruction
First, the necessary imports. Besides CR-Sparse, we will need the JAX, matplotlib, and scikit-image libraries.
import jax.numpy as jnp
import matplotlib.pyplot as plt
import cr.sparse as crs
from cr.sparse import lop
from cr.sparse import metrics
import skimage.data
from cr.sparse.dsp import time_values
We will construct a signal consisting of multiple sinusoids at different frequencies and amplitudes for this example.
fs = 1000.
T = 2
t = time_values(fs, T)
n = t.size
x = jnp.zeros(n)
freqs = [25, 7, 9]
amps = [1, -3, .8]
for (f, amp) in zip(freqs, amps):
    sinusoid = amp * jnp.sin(2 * jnp.pi * f * t)
    x = x + sinusoid
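Since the article favors a functional style, the same signal can also be built without a Python loop by broadcasting over the frequencies and amplitudes. This sketch rests on one assumption: that `time_values(fs, T)` returns evenly spaced sample times 0, 1/fs, 2/fs, ... covering T seconds, approximated here with `jnp.arange`; the exact convention in CR-Sparse may differ.

```python
import jax.numpy as jnp

fs = 1000.0
T = 2
# Assumption: a stand-in for cr.sparse.dsp.time_values with samples at k / fs.
t = jnp.arange(int(fs * T)) / fs

freqs = jnp.array([25.0, 7.0, 9.0])
amps = jnp.array([1.0, -3.0, 0.8])

# Broadcasting: (3, 1) frequencies/amplitudes against (n,) times gives (3, n);
# summing over axis 0 adds the three sinusoids into one signal of length n.
x = jnp.sum(amps[:, None] * jnp.sin(2 * jnp.pi * freqs[:, None] * t), axis=0)
print(x.shape)  # (2000,)
```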
The CR-Sparse linear operator module (lop) includes a 1D wavelet transform operator. We will construct this operator, providing the size of the signal, the wavelet type, and the number of decomposition levels as parameters.
DWT_op = lop.dwt(n, wavelet='dmey', level=5)
Wavelet coefficients are computed by applying the linear operator to the data. Read here to learn how linear operators work in CR-Sparse.
alpha = DWT_op.times(x)
Interestingly, most of the detail coefficients are negligible. For an orthogonal wavelet, the squared magnitude of a wavelet coefficient indicates the portion of signal energy carried by that coefficient. Hence, dropping these small coefficients shouldn’t lead to a high reconstruction error.
Let’s keep only the first 10% of the coefficients, which hold the approximation and the coarsest-level details, and drop the rest (compression):
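The energy argument can be checked numerically for an orthonormal wavelet. This sketch assumes a one-level orthonormal Haar transform written directly in `jax.numpy` (not `lop.dwt` with `dmey`): the signal energy equals the energy of the coefficients (Parseval), and zeroing D1 produces a squared reconstruction error equal to the energy of the dropped D1 coefficients.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)

def haar_dwt(x):
    # Hypothetical helper: one orthonormal Haar DWT level.
    return (x[0::2] + x[1::2]) / SQRT2, (x[0::2] - x[1::2]) / SQRT2

def haar_idwt(a, d):
    # Hypothetical helper: inverse of haar_dwt.
    even, odd = (a + d) / SQRT2, (a - d) / SQRT2
    return jnp.stack([even, odd], axis=-1).reshape(-1)

t = jnp.arange(512) / 512
x = jnp.sin(2 * jnp.pi * 3 * t) + 0.5 * jnp.cos(2 * jnp.pi * 11 * t)
a, d = haar_dwt(x)

# Parseval: an orthonormal transform preserves energy (sum of squares).
energy_x = jnp.sum(x ** 2)
energy_coefs = jnp.sum(a ** 2) + jnp.sum(d ** 2)

# Dropping D1: the squared reconstruction error equals the energy of D1.
x_rec = haar_idwt(a, jnp.zeros_like(d))
err_energy = jnp.sum((x - x_rec) ** 2)
print(energy_x, energy_coefs)        # equal up to rounding
print(err_energy, jnp.sum(d ** 2))   # equal up to rounding
print(jnp.sum(d ** 2) / energy_x)    # small fraction of the total energy
```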
cutoff = n // 10
alpha2 = alpha.at[cutoff:].set(0)
For our purposes, we simply set those coefficients to 0. In a digital communication setup, those coefficients wouldn’t be transmitted and would be assumed to be zero by the receiver. The array-update syntax may look a little unusual: since arrays are immutable in JAX, JAX provides functional variants for constructing a new array from an old one by updating parts of it. See here for details.
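For readers new to this syntax, here is a short sketch of JAX's indexed-update methods on a toy coefficient vector (the values and `cutoff` are illustrative). `.at[...].set(...)` returns a new array and leaves the original unchanged; `jnp.where` with a mask is an equivalent functional alternative, and variants such as `.add()` follow the same pattern.

```python
import jax.numpy as jnp

alpha = jnp.arange(10.0)
cutoff = 3

# .at[idx].set(v) returns a NEW array; alpha itself is left untouched.
alpha2 = alpha.at[cutoff:].set(0)
print(alpha2)   # [0. 1. 2. 0. 0. 0. 0. 0. 0. 0.]
print(alpha)    # [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]

# Equivalent functional alternative: keep entries where a 0/1 mask is true.
mask = jnp.arange(alpha.size) < cutoff
alpha3 = jnp.where(mask, alpha, 0.0)

# Other update variants follow the same pattern, e.g. .add() and .multiply().
alpha4 = alpha.at[0].add(100.0)
```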
We now reconstruct the original signal from the remaining coefficients by applying the adjoint of the DWT linear operator (which happens to be its inverse).
x_rec = DWT_op.trans(alpha2)
snr = metrics.signal_noise_ratio(x, x_rec)
print(snr)
SNR is 36.56 dB.
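An SNR of 36.56 dB means that the energy of the reconstruction error is roughly 4,500 times smaller than the signal energy; plotting the original and reconstructed signals together shows that the reconstruction error is negligible.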
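Why is the adjoint also the inverse? For an orthonormal wavelet transform, the transform matrix W satisfies W^T W = I. This sketch assumes a one-level Haar transform (a hypothetical stand-in, not `lop.dwt`), builds its matrix by transforming the standard basis vectors with `jax.vmap`, and checks this identity; the forward and adjoint products mirror the `times` and `trans` methods used above.

```python
import jax
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)

def haar_dwt_flat(x):
    # Hypothetical helper: one Haar level, returned as the concatenation [A1, D1].
    return jnp.concatenate([(x[0::2] + x[1::2]) / SQRT2,
                            (x[0::2] - x[1::2]) / SQRT2])

n = 8
# vmap over the rows of the identity gives W e_k as row k, i.e. W^T; transpose it.
W = jax.vmap(haar_dwt_flat)(jnp.eye(n)).T
# Orthonormal: W^T W = I, so the adjoint W^T is also the inverse.
print(jnp.allclose(W.T @ W, jnp.eye(n), atol=1e-6))  # True

x = jnp.arange(n, dtype=jnp.float32)
alpha = W @ x          # like "times": forward transform
x_back = W.T @ alpha   # like "trans": adjoint, which equals the inverse here
```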
2D image decomposition, compression, reconstruction
Let’s try our luck on a 2D image now. We will take the sample grass image from the scikit-image library for this demo.
image = skimage.data.grass()
2D DWT is a straightforward extension of 1D DWT.
- Given an image X of size N x N, compute the DWT of each column. We get two new images, CA and CD, of size N/2 x N each (i.e., half the number of rows).
- Apply DWT on each row of CA to obtain CAA and CAD images (of size N/2 x N/2 each).
- Apply DWT on each row of CD to obtain CDA and CDD images.
- This way, we split X into four sub-images: CAA, CAD, CDA, and CDD.
- We can combine these sub-images to form a single coefficients image.
- We repeat the 2D DWT decomposition on the CAA part recursively to compute the multilevel decomposition.
The 2D IDWT takes [CAA, CAD, CDA, CDD] as input and returns X as output (by first applying IDWT on rows and then IDWT on columns).
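A single level of the 2D procedure can be sketched with the orthonormal Haar wavelet, assuming an image whose side lengths are even. The helper names are hypothetical, and combining the sub-images as [[CAA, CAD], [CDA, CDD]] is one possible layout; CR-Sparse may arrange its coefficient image differently.

```python
import jax.numpy as jnp

SQRT2 = jnp.sqrt(2.0)

def dwt_cols(X):
    """Haar DWT of every column: returns (CA, CD) with half the rows."""
    return (X[0::2, :] + X[1::2, :]) / SQRT2, (X[0::2, :] - X[1::2, :]) / SQRT2

def dwt_rows(X):
    """Haar DWT of every row: returns (approx, detail) with half the columns."""
    return (X[:, 0::2] + X[:, 1::2]) / SQRT2, (X[:, 0::2] - X[:, 1::2]) / SQRT2

def idwt_cols(A, D):
    even, odd = (A + D) / SQRT2, (A - D) / SQRT2
    # Interleave rows: row 2k <- even[k], row 2k+1 <- odd[k]
    return jnp.stack([even, odd], axis=1).reshape(-1, A.shape[1])

def idwt_rows(A, D):
    even, odd = (A + D) / SQRT2, (A - D) / SQRT2
    # Interleave columns: col 2k <- even[:, k], col 2k+1 <- odd[:, k]
    return jnp.stack([even, odd], axis=-1).reshape(A.shape[0], -1)

def dwt2(X):
    CA, CD = dwt_cols(X)       # columns first
    CAA, CAD = dwt_rows(CA)    # then rows of CA
    CDA, CDD = dwt_rows(CD)    # and rows of CD
    return CAA, CAD, CDA, CDD

def idwt2(CAA, CAD, CDA, CDD):
    CA = idwt_rows(CAA, CAD)   # IDWT on rows first
    CD = idwt_rows(CDA, CDD)
    return idwt_cols(CA, CD)   # then IDWT on columns

X = jnp.sqrt(jnp.arange(64.0)).reshape(8, 8)      # small stand-in "image"
CAA, CAD, CDA, CDD = dwt2(X)
coef_image = jnp.block([[CAA, CAD], [CDA, CDD]])  # one combined coefficient image
X_rec = idwt2(CAA, CAD, CDA, CDD)
print(CAA.shape, coef_image.shape)                # (4, 4) (8, 8)
print(jnp.allclose(X, X_rec, atol=1e-5))          # True
```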
We will use a 2D Haar wavelet transform operator with 5 levels of decomposition.
DWT2_op = lop.dwt2D(image.shape, wavelet='haar', level=5)
DWT2_op = lop.jit(DWT2_op)
Computing the wavelet coefficients amounts to applying the linear operator to the image:
coefs = DWT2_op.times(image)
Let’s keep only 1/16 of the coefficients (i.e., just 6.25% of them). In effect, we are dropping the first and second levels of detail coefficients.
h, w = coefs.shape
coefs2 = jnp.zeros_like(coefs)
coefs2 = coefs2.at[:h//4, :w//4].set(coefs[:h//4, :w//4])
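The 1/16 figure follows from simple counting. Assuming a 512 x 512 image (the assumed size of the grass sample) and a 5-level decomposition in which the approximation and coarser details occupy the top-left corner of the coefficient image, this plain-Python sketch checks that the kept (h/4) x (w/4) corner is exactly everything except the level-1 and level-2 detail coefficients.

```python
h, w, level = 512, 512, 5   # assumed image size and decomposition depth

# At level j (1 = finest), each of the 3 detail sub-images is (h/2^j) x (w/2^j).
detail_counts = {j: 3 * (h >> j) * (w >> j) for j in range(1, level + 1)}
approx_count = (h >> level) * (w >> level)

total = approx_count + sum(detail_counts.values())
kept = (h // 4) * (w // 4)                       # top-left corner kept in coefs2
dropped = detail_counts[1] + detail_counts[2]    # level-1 and level-2 details

print(total == h * w)           # True: the decomposition has h*w coefficients
print(kept == total - dropped)  # True: corner = everything except D1 and D2
print(kept / total)             # 0.0625
```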
Reconstruction involves applying the adjoint of the operator, which happens to be its inverse. After reconstruction, we will compute the peak signal-to-noise ratio (PSNR) to measure the quality of the reconstruction.
image_rec = DWT2_op.trans(coefs2)
# PSNR
psnr = metrics.peak_signal_noise_ratio(image, image_rec)
print(psnr)
PSNR is 19.38 dB.
A 19 dB PSNR with just 6% of the wavelet coefficients is not bad. Also, the details of the image are well preserved and there are no blocking artifacts.
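For reference, here is a minimal PSNR sketch using the common definition PSNR = 10 log10(peak^2 / MSE) in dB. The helper `psnr_db` is hypothetical; choosing the peak as the maximum of the reference (or 255 for 8-bit images) is an assumption, and CR-Sparse's `metrics.peak_signal_noise_ratio` may use a different convention.

```python
import jax.numpy as jnp

def psnr_db(reference, estimate, peak=None):
    """Hypothetical helper: PSNR in dB, 10 * log10(peak**2 / MSE).

    Assumption: `peak` defaults to the maximum absolute value of the reference;
    for 8-bit images, 255 is another common choice.
    """
    reference = jnp.asarray(reference, dtype=jnp.float32)
    estimate = jnp.asarray(estimate, dtype=jnp.float32)
    if peak is None:
        peak = jnp.max(jnp.abs(reference))
    mse = jnp.mean((reference - estimate) ** 2)
    return 10.0 * jnp.log10(peak ** 2 / mse)

ref = jnp.full((4, 4), 200.0)
est = ref + 10.0                       # uniform error of 10, so MSE = 100
print(psnr_db(ref, est, peak=255.0))   # 10 * log10(65025 / 100)
print(psnr_db(ref, est))               # peak = 200: 10 * log10(400)
```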
I hope that this article gives a good introduction to the wavelet transform capabilities available in CR-Sparse.
For more advanced usage, check out the example on image deblurring using the LSQR and FISTA algorithms.