TDS Archive

An archive of data science, data analytics, data engineering, machine learning, and artificial intelligence writing from the former Towards Data Science Medium publication.

Implementing Linear Operators in Python with Google JAX


Illustration by author

A review of linear operators

A linear operator, or a linear map, is a mapping from one vector space to another that preserves vector addition and scalar multiplication. In other words, if T is a linear operator, then T(x + y) = T(x) + T(y) and T(a x) = a T(x), where x and y are vectors and a is a scalar.
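As a quick numerical illustration, any matrix multiplication defines a linear map, and both properties can be checked directly (a minimal sketch; the matrix and vectors below are arbitrary):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
A = random.normal(k1, (5, 3))   # matrix representation of a linear map T
x = random.normal(k2, (3,))
y = random.normal(k3, (3,))
a = 2.5

T = lambda v: A @ v
# additivity: T(x + y) == T(x) + T(y)
print(jnp.allclose(T(x + y), T(x) + T(y)))
# homogeneity: T(a x) == a T(x)
print(jnp.allclose(T(a * x), a * T(x)))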

Linear operators have wide applications in signal processing, image processing, data sciences, and machine learning.

In signal processing, signals are often represented as linear combinations of sinusoids. The Discrete Fourier Transform is a linear operator that decomposes a signal into its individual component frequencies. The Wavelet Transform decomposes a signal into wavelets that are localized in both position and scale, so that interesting events or patterns inside the signal can be identified and localized easily.

In statistics, linear models are used to describe the observations or target variables as linear combinations of features.

We can think of a linear operator as a mapping from model space to data space. Every linear operator has a matrix representation. If a linear operator T is represented by a matrix A, then the application of a linear operator y = T(x) can be written as:

y = A x

where x is the model and y is the data. In the Fourier Transform, the columns of A are the individual sinusoids, and the model x describes the contribution of each sinusoid to the observed signal y. Usually, we are given the data/signal y and our task is to estimate the model vector x. This is known as an inverse problem. The inverse problem is easy for orthonormal bases: the solution is simply x = A^H y. However, this shortcut does not work when A is not orthonormal or when the model size and the data size differ. When the model size is smaller than the data size, the system is overdetermined and generally has no exact solution. A basic approach is to solve the least-squares problem:

minimize || A x - y ||_2^2

This leads to a system of normal equations:

A^T A x = A^T y

Libraries like NumPy and JAX provide extensive support for matrix algebra. However, for large systems, direct methods from matrix algebra become prohibitive in both time and space. Storing A itself may be very expensive for very large matrices, forming A^T A is an O(n³) operation, and inverting it to solve the normal equations becomes infeasible as n increases.
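To make the normal equations concrete, here is what the direct approach looks like on a small dense system (a sketch with arbitrary sizes); it is exactly this pattern that stops scaling when A grows large:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
k1, k2 = random.split(key)
m, n = 100, 20                    # data size and model size (arbitrary)
A = random.normal(k1, (m, n))     # dense matrix representation of the operator
x_true = random.normal(k2, (n,))
y = A @ x_true                    # observed data

# direct solution of the normal equations A^T A x = A^T y
x_hat = jnp.linalg.solve(A.T @ A, A.T @ y)
print(jnp.allclose(x_hat, x_true, atol=1e-3))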

Functional representation of linear operators

Fortunately, many linear operators that are useful in the scientific literature can be implemented in terms of simple functions. For example, consider a forward difference operator for finite-size vectors x in R⁸ (8-dimensional real vectors). A matrix representation is:

A = jnp.array([[-1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.],
               [ 0., -1.,  1.,  0.,  0.,  0.,  0.,  0.],
               [ 0.,  0., -1.,  1.,  0.,  0.,  0.,  0.],
               [ 0.,  0.,  0., -1.,  1.,  0.,  0.,  0.],
               [ 0.,  0.,  0.,  0., -1.,  1.,  0.,  0.],
               [ 0.,  0.,  0.,  0.,  0., -1.,  1.,  0.],
               [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  1.],
               [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])

However, the matrix-vector product A @ x can be computed much more efficiently as:

import jax.numpy as jnp

def forward_diff(x):
    append = jnp.array([x[-1]])
    return jnp.diff(x, append=append)

This brings down the computation from O(n²) to O(n).
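As a quick sanity check (assuming the matrix A and the function forward_diff defined above), the two computations agree:

x = jnp.arange(1.0, 9.0)                       # [1., 2., ..., 8.]
print(jnp.allclose(A @ x, forward_diff(x)))    # True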

Forward and adjoint operators

In general, we need to implement two operations for a linear operator. The forward operator from the model space to the data space:

y = A x

And the adjoint operator from the data space to the model space:

x = A^T y

For complex vector spaces, the adjoint operator will be the Hermitian transpose:

x = A^H y
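For instance, the adjoint of the forward difference operator above (whose last row is zero) can itself be written as a simple O(n) function. A minimal sketch:

import jax.numpy as jnp

def forward_diff_adjoint(y):
    # transpose of the forward difference matrix A above:
    # out[0] = -y[0], out[j] = y[j-1] - y[j] for 1 <= j <= n-2, out[n-1] = y[n-2]
    return jnp.concatenate([
        jnp.array([-y[0]]),
        y[:-2] - y[1:-1],
        jnp.array([y[-2]]),
    ])

y = jnp.arange(1.0, 9.0)
print(forward_diff_adjoint(y))   # [-1. -1. -1. -1. -1. -1. -1.  7.]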

Existing implementations

SciPy provides a very good interface for implementing linear operators in scipy.sparse.linalg.LinearOperator. PyLops builds on top of it to provide an extensive collection of linear operators.

JAX-based implementation in CR-Sparse

JAX is a new library for high-performance numerical computing based on the functional programming paradigm. It enables us to write efficient numerical programs in pure Python that can be compiled using XLA for CPU/GPU/TPU hardware to achieve state-of-the-art performance.

CR-Sparse is a new open-source library being developed on top of JAX. It aims to provide XLA-accelerated functional models and algorithms for sparse-representation-based signal processing, and it now includes a good collection of linear operators built on top of JAX (see the documentation). We represent a linear operator by a pair of functions, times and trans. The times function implements the forward operation while the trans function implements the adjoint operation.

Getting started

You can install CR-Sparse from PyPI:

pip install cr-sparse

For the latest code, install directly from GitHub:

python -m pip install git+https://github.com/carnotresearch/cr-sparse.git

In the interactive code samples below, lines starting with > contain the code, and lines without > show the output.

First derivative operator

To create a first derivative operator (using forward differences):

> from cr.sparse import lop
> n = 8
> T = lop.first_derivative(n, kind='forward')

It is possible to see the matrix representation of a linear operator:

> print(lop.to_matrix(T))
[[-1.  1.  0.  0.  0.  0.  0.  0.]
 [ 0. -1.  1.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  1.  0.  0.  0.  0.]
 [ 0.  0.  0. -1.  1.  0.  0.  0.]
 [ 0.  0.  0.  0. -1.  1.  0.  0.]
 [ 0.  0.  0.  0.  0. -1.  1.  0.]
 [ 0.  0.  0.  0.  0.  0. -1.  1.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]]

Computing the forward operation T x:

> x = jnp.array([1,2,3,4,5,6,7,8])
> y = T.times(x)
> print(y)
[1. 1. 1. 1. 1. 1. 1. 0.]

Computing the adjoint operation T^H x:

> y = T.trans(x)
> print(y)
[-1. -1. -1. -1. -1. -1. -1. 7.]
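A standard consistency check for such a forward/adjoint pair is the dot-product test <T u, v> = <u, T^T v>. A quick sketch using the operator T above and arbitrary vectors:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k1, k2 = random.split(key)
u = random.normal(k1, (8,))
v = random.normal(k2, (8,))
lhs = jnp.dot(T.times(u), v)    # <T u, v>
rhs = jnp.dot(u, T.trans(v))    # <u, T^T v>
print(jnp.allclose(lhs, rhs, atol=1e-5))   # True if times/trans are true adjoints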

A diagonal matrix multiplication operator

Diagonal matrices are extremely sparse, and a linear-operator-based implementation is ideal for them. Let’s build one:

> d = jnp.array([1., 2., 3., 4., 4, 3, 2, 1])
> T = lop.diagonal(d)
> print(lop.to_matrix(T))
[[1. 0. 0. 0. 0. 0. 0. 0.]
[0. 2. 0. 0. 0. 0. 0. 0.]
[0. 0. 3. 0. 0. 0. 0. 0.]
[0. 0. 0. 4. 0. 0. 0. 0.]
[0. 0. 0. 0. 4. 0. 0. 0.]
[0. 0. 0. 0. 0. 3. 0. 0.]
[0. 0. 0. 0. 0. 0. 2. 0.]
[0. 0. 0. 0. 0. 0. 0. 1.]]

Applying it:

> print(T.times(x))
[ 1. 4. 9. 16. 20. 18. 14. 8.]
> print(T.trans(x))
[ 1. 4. 9. 16. 20. 18. 14. 8.]

Under the hood

All linear operators are built as a named tuple Operator. See its documentation here. Below is a basic outline.

from typing import Callable, NamedTuple, Tuple
import jax.numpy as jnp

class Operator(NamedTuple):
    times : Callable[[jnp.ndarray], jnp.ndarray]
    """A linear function mapping from A to B"""
    trans : Callable[[jnp.ndarray], jnp.ndarray]
    """Corresponding adjoint linear function mapping from B to A"""
    shape : Tuple[int, int]
    """Dimensions of the linear operator (m, n)"""
    linear : bool = True
    """Indicates if the operator is linear or not"""
    real: bool = True
    """Indicates if the linear operator is real, i.e. has a matrix representation of real numbers"""

Implementation of the diagonal linear operator (discussed above) is actually quite simple:

def diagonal(d):
    assert d.ndim == 1
    n = d.shape[0]
    times = lambda x: d * x
    trans = lambda x: _hermitian(d) * x
    return Operator(times=times, trans=trans, shape=(n, n))

where the function _hermitian is as follows:

def _hermitian(a):
    """Computes the Hermitian transpose of a vector or a matrix"""
    return jnp.conjugate(a.T)

A great feature of JAX is that when it just-in-time compiles Python code, it can remove unnecessary operations. For example, if d is a real vector, then _hermitian is a no-op and can be optimized out during compilation. All operators in cr.sparse.lop have been carefully designed so that they can be easily JIT-compiled. We provide a utility function lop.jit to quickly wrap the times and trans functions of a linear operator with jax.jit.

T = lop.jit(T) 

After this, T.times and T.trans operations will run much faster (by one or two orders of magnitude).
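A rough way to see the effect is to time the operator before and after wrapping it with lop.jit (a sketch; the size is arbitrary and the exact speedup depends on the operator and the hardware). Since JAX dispatches computations asynchronously, we block on the result when timing:

import timeit
import jax.numpy as jnp
from cr.sparse import lop

n = 2 ** 16
T = lop.first_derivative(n, kind='forward')
x = jnp.ones(n)

t_plain = timeit.timeit(lambda: T.times(x).block_until_ready(), number=100)

T_jit = lop.jit(T)
T_jit.times(x).block_until_ready()   # warm-up call to trigger compilation
t_jit = timeit.timeit(lambda: T_jit.times(x).block_until_ready(), number=100)

print(f'{t_plain=:.4f}s, {t_jit=:.4f}s')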

Something like A^H A for the normal equation above can be modeled as a function:

gram = lambda x : T.trans(T.times(x))

where it is assumed that T is already created and available in the closure.

Preconditioned Conjugate Gradient with linear operators

Once we have a framework of linear operators in hand, we can use it to write algorithms like preconditioned conjugate gradient in a JAX-compatible manner (i.e. they can be JIT-compiled). This version is included in cr.sparse.opt.pcg.

CR-Sparse contains a good collection of algorithms for solving inverse problems using linear operators.
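As a simple illustration of how such operators plug into iterative solvers (using JAX's built-in jax.scipy.sparse.linalg.cg here, not the cr.sparse.opt.pcg solver), the gram function pattern shown earlier can be passed directly as the linear map when solving the normal equations A^H A x = A^H b. A minimal sketch with a small diagonal operator:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg
from cr.sparse import lop

n = 8
T = lop.jit(lop.diagonal(jnp.arange(1.0, 9.0)))   # a simple invertible operator
x_true = jnp.ones(n)
b = T.times(x_true)

gram = lambda v: T.trans(T.times(v))   # A^H A as a function
rhs = T.trans(b)                       # A^H b
x_hat, info = cg(gram, rhs)
print(jnp.allclose(x_hat, x_true, atol=1e-4))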

A Compressive Sensing example

We consider a compressive sensing example which consists of partial Walsh-Hadamard measurements, a cosine sparsifying basis, and ADMM-based signal recovery. In compressive sensing, the data size is much smaller than the model size, so the system A x = b is underdetermined. Finding a solution requires additional assumptions. One useful assumption is to look for an x which is sparse (i.e. most of its entries are zero).

Here is our signal of interest x of n=8192 samples.

A non-sparse cumulative random walk signal (similar to stock markets)

We will use a Type-II Discrete Cosine orthonormal basis for modeling this signal. Please note that the standard DCT is not orthonormal.

Psi  = lop.jit(lop.cosine_basis(n))
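For a small size we can verify the orthonormality claim directly by materializing the operator with lop.to_matrix (a quick sketch; the size 16 is arbitrary):

import jax.numpy as jnp
from cr.sparse import lop

n_small = 16
B = lop.to_matrix(lop.cosine_basis(n_small))
# the columns of an orthonormal basis satisfy B^T B = I
print(jnp.allclose(B.T @ B, jnp.eye(n_small), atol=1e-5))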

Let’s see if the signal is sparse on this basis:

alpha = Psi.trans(x)

The representation of x in the orthonormal discrete cosine basis

It is clear that most of the coefficients in the discrete cosine basis are extremely small and can be safely ignored.
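One way to quantify this (a sketch, assuming the coefficient vector alpha from above; the cutoff k is arbitrary) is to check how much of the signal energy is captured by the largest coefficients:

import jax.numpy as jnp

k = 200                                # number of largest coefficients to keep
mag = jnp.sort(jnp.abs(alpha))[::-1]   # coefficient magnitudes in decreasing order
energy_fraction = jnp.sum(mag[:k] ** 2) / jnp.sum(mag ** 2)
print(energy_fraction)                 # close to 1 for a compressible signal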

We next introduce a structured compressive sensing operator which takes measurements of x in the Walsh-Hadamard transform space, but keeps only a small number, m = 1024, of randomly selected measurements. The input x may also be randomly permuted during measurement.

from jax import random
key = random.PRNGKey(0)
keys = random.split(key, 10)
# problem sizes (from the text above)
n = 8192
m = 1024
# indices of the measurements to be picked
p = random.permutation(keys[1], n)
picks = jnp.sort(p[:m])
# make sure that the DC component is always picked up
picks = picks.at[0].set(0)
# a random permutation of the input
perm = random.permutation(keys[2], n)
# Walsh Hadamard Basis operator
Twh = lop.walsh_hadamard_basis(n)
# wrap it with picks and perm
Tpwh = lop.jit(lop.partial_op(Twh, picks, perm))

We can now perform the measurements on x with the operator Tpwh. The measurement process may also add some Gaussian noise.

# Perform exact measurement
b = Tpwh.times(x)
# Add some noise
sigma = 0.2
noise = sigma * random.normal(keys[3], (m,))
b = b + noise

Random measurements on x using a structured compressive sensing matrix based on the Walsh Hadamard Transform.

We can now use the yall1 solver included in CR-Sparse to recover the original signal x from the measurements b.

from jax.numpy.linalg import norm

# tolerance for solution convergence
tol = 5e-4
# BPDN parameter
rho = 5e-4
# run the solver
sol = yall1.solve(Tpwh, b, rho=rho, tolerance=tol, W=Psi)
# number of iterations
iterations = int(sol.iterations)
print(f'{iterations=}')
# relative error of the recovered signal w.r.t. the original x
rel_error = norm(sol.x - x) / norm(x)
print(f'{rel_error=:.4e}')

The solver converged in 150 iterations and the relative error was about 3.4e-2.

Let’s see how good the recovery is.

Please see here for the full example code.

Summary

In this article, we reviewed the concept of linear operators and the computational benefits associated with them. We presented a functional programming-based implementation of linear operators using JAX. We then looked at the application of these operators in compressive sensing problems. We saw that sophisticated signal recovery algorithms can be implemented with this approach in a way that is fully compliant with the JAX requirements for JIT compilation. We aim to provide an extensive collection of operators and algorithms for inverse problems in CR-Sparse.
