Improve Neural Networks by using Complex Numbers

Can Complex Functions be the next breakthrough in Computer Vision?

Devansh
Geek Culture
Nov 17, 2022


Join 31K+ AI People keeping in touch with the most important ideas in Machine Learning through my free newsletter over here

Recently, someone in my LinkedIn network shared this very interesting paper with me. Titled “CoShNet: A Hybrid Complex Valued Neural Network using Shearlets”, the paper proposes the use of complex functions in a hybrid neural network. If you are confused by those words, don’t worry, I was too. In this article, I will explain the idea of hybrid neural networks and how they can improve traditional Convolutional Neural Networks. Then we will cover how Complex Functions can boost the performance of these models even further. This is going to be a very fun one.

The resulting network is called Complex Shearlets Network (CoShNet). It was tested on Fashion-MNIST against ResNet-50 and Resnet-18, obtaining 92.2% versus 90.7% and 91.8% respectively. The proposed network has 49.9k parameters versus ResNet-18 with 11.18m and use 52 times fewer FLOPs. Finally, we trained in under 20 epochs versus 200 epochs required by ResNet and do not need any hyperparameter tuning nor regularization.

- In case you’re looking for a reason to be excited about this idea.

Understanding Convolutional Neural Networks

Convolutional Neural Networks have been the OG Computer Vision Architecture since their inception. In fact, the foundations of CNNs are older than I am. CNNs were literally built for vision.

The feature extraction is the true CNN revolution. Taken from IBM’s Writeup on ConvNets

So what’s so good about CNNs? The main idea behind Convolutional Neural Nets is that they go through the image, segment by segment, and extract the main features from it. The earlier layers of a CNN extract the coarser features, such as edges and colors. Adding more layers allows for feature extraction at a much finer level of detail.

CNNs use the sliding window technique to build their feature maps. As you can see, Good Machine Learning requires good software engineering. Image Source
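If the sliding-window idea feels abstract, here is a minimal NumPy sketch of what a single convolutional filter does. The 3×3 vertical-edge kernel is a hand-picked stand-in of my own; a real CNN learns its filters during training.

```python
import numpy as np

def convolve2d(image, kernel):
    """Slide `kernel` over `image` and record one response per window position."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    feature_map = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Element-wise multiply the current window by the kernel and sum.
            feature_map[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return feature_map

# A hypothetical vertical-edge kernel; a trained CNN would learn such filters itself.
edge_kernel = np.array([[1, 0, -1],
                        [1, 0, -1],
                        [1, 0, -1]])
image = np.random.rand(28, 28)                 # stand-in for a Fashion-MNIST image
print(convolve2d(image, edge_kernel).shape)    # (26, 26) feature map
```

Doing this for every filter, at every position, in every layer is where the cost comes from.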

This article goes into CNNs in more detail. For our purposes, one thing is important: CNNs have been the go-to for Computer Vision primarily due to their ability to build feature maps. Even with the rise of Vision Transformers, CNNs have held strong (provided you modernize the training pipeline with more recent techniques).

So far, so good. So what’s the catch? There is one problem with their approach. The convolutions (building feature maps) can get really expensive.

Enter Hybrid Neural Networks

If you’ve studied even a bit of Computer Science (which you should, to be effective at ML), you will realize something about the feature-mapping process: it is really expensive. You have to slide the window across the image many times, at every layer. As we’ve already stated, the earlier layers only extract the crude features; the high-resolution features are only spotted at the later levels. This is where some really smart people saw an opportunity. What if we did some Math to find a function that can spot those low-level features directly? That way we could get the features without going through the expensive early convolutions-

In a hybrid neural network, the expensive convolutional layers are replaced by a non-trainable fixed transform with a great reduction in parameters.
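To make that pattern concrete, here is a minimal PyTorch sketch of a hybrid network: a frozen, non-trainable front end feeding a small trainable head. The frozen random filter bank below is purely a stand-in of my own for illustration; in CoShNet the fixed transform is the complex shearlet transform (CoShRem), not a random convolution.

```python
import torch
import torch.nn as nn

class HybridNet(nn.Module):
    """Sketch of the hybrid pattern: fixed (non-trainable) transform + small trainable head.
    CoShNet uses a complex shearlet transform as the front end; this stand-in
    just freezes a randomly initialized filter bank for illustration."""
    def __init__(self, num_classes=10):
        super().__init__()
        # Fixed front end: frozen filters, contributes zero trainable parameters.
        self.fixed_transform = nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2)
        for p in self.fixed_transform.parameters():
            p.requires_grad = False          # never updated during training
        # Small trainable classifier head.
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 14 * 14, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        with torch.no_grad():                # the transform is just a feature extractor
            feats = self.fixed_transform(x)
        return self.head(feats)

model = HybridNet()
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")  # only the head contributes
```

Only the head’s parameters are ever updated, which is where the drastic parameter and FLOP savings come from.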

If you could find a good function, then you’ve significantly reduced your computational overhead. And we have some great functions that can do this. Turns out Complex Functions just work better. Look at the image below and the difference in results.

This image is the perfect segue into the next section. Let’s now talk about all the advantages that Complex Functions bring to our Neural Networks, and why they work so well in the first place. Some of this can get pretty mathematical, but if that intimidates you, just close your eyes and think about what the Deep Learning bros on Twitter tell you about not needing Math for Machine Learning. True Machine Learning is about overfitting big models to neat data, not this technical mathy stuff (that involves a lot of experimentation).

So let’s get into Complex Functions in Hybrid Networks (and specifically the complex shearlet transform).

The basic idea behind Hybrid NNs and this paper

The amazing CoSh Network

Before I get into the details, here is a concise look at some of the amazing things this network can accomplish. This should tell you why I’m covering this idea (and hopefully illustrate why I spend my weekend reading random ML papers).

You already know I’m very excited about these results. A cost-effective ML solution built using Math? One that generalizes very well? I’m getting excited just typing this. However, one thing that really stood out to me was this network’s resilience to noise and perturbation. Adversarial perturbation is something I’ve been covering since I started writing, and these results are very exciting as a possible way to counter it.

Take a look at this graph where they tested the network with permutations of clean and perturbed datasets. The results are shockingly stable, especially considering the relatively small training dataset size. I normally expect this robustness from bigger datasets.

Fanboying out of the way, why does this happen? What is the reason that this can work so well? Is this a fluke, or is there something about Complex Functions that works very well?

If we can understand what makes these amazing results tick, we can create much better solutions.

Let’s move on to why Complex Functions might be the next leap in Deep Learning.

This is in stark contrast to a recent paper [41] “.. the necessity to optimize jointly the architecture and the training procedure: ..having the same training procedure is not sufficient for comparing the merits of different architectures.” Which is the opposite of what one wants to have — a no-fuss, reliable training procedure for different datasets and models.

— The authors show that expensive tuning and search are not the only way.

The Magical Properties of Complex Functions

There are some very interesting properties that make Complex Neural Networks special. First, let’s talk about the decision boundaries. Complex Neurons create the following boundaries-

Nothing surprising here. However, this brings up some interesting properties, especially with generalization. According to the authors-

The decision boundary of a CVnn consists of two hypersurfaces that intersect orthogonally (Fig. 7) and divides a decision region into four equal sections. Furthermore, the decision boundary of a 3-layer CVnn stays almost orthogonal [27]. This orthogonality improves generalization. As an example, several problems (e.g. Xor) that cannot be solved with a single real neuron, can be solved with a single complex-valued neuron using the orthogonal property
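To see the XOR claim in action, here is a tiny sketch with a single complex-valued neuron. The specific encoding below (packing the two inputs into one complex number and reading the class off the sign-agreement of the output’s real and imaginary parts) is my own choice for illustration, not necessarily the construction used in [27]; the point is simply that the decision boundary is a pair of orthogonal lines, Re = 0 and Im = 0.

```python
# Toy illustration of the orthogonal-boundary idea: one complex-valued neuron
# separating XOR. The encoding (pack (x1, x2) into one complex input, read the
# class off the sign-agreement of the output's real and imaginary parts) is my
# own choice for illustration.
w = 1.0 + 0.0j             # complex weight
b = -(0.5 + 0.5j)          # complex bias

def complex_neuron(x1, x2):
    z = complex(x1, x2)    # encode the two real inputs as a single complex number
    out = w * z + b        # one complex-valued neuron
    # Decision boundary = the two orthogonal lines Re(out) = 0 and Im(out) = 0;
    # the four quadrants they create are grouped diagonally into two classes.
    return 0 if (out.real > 0) == (out.imag > 0) else 1

for x1, x2 in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print((x1, x2), "->", complex_neuron(x1, x2))   # reproduces XOR: 0, 1, 1, 0
```

A single real neuron can only draw one line through the plane, which is exactly why it cannot separate XOR.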

The next stand-out to me is the presence of Saddle-Points. Saddle points occur in multivariable functions. They are critical points where the function attains neither a local maximum value nor a local minimum value.

Image Source

Why does this matter? At a saddle point, the derivatives of the loss function are still equal to 0, so it looks like a critical point to a gradient-based optimizer. However, as the authors state, “SGD with random inits can largely avoid saddle points [29] [30], but not a local minimum.” The implicit argument is that in complex-valued networks many of the troublesome critical points show up as saddle points rather than local minima, and SGD can escape the former but not the latter. This is probably what allows for much faster convergence, since the optimizer is far less likely to get stuck. Such an approach provides benefits very similar to using Random Restarts to sample a larger search space. The authors even mention that CoShNet doesn’t need data augmentation to reach stable embeddings (with respect to perturbation).
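For a concrete picture of a saddle point, take the textbook function f(x, y) = x² - y²: the gradient vanishes at the origin, yet the point is neither a minimum nor a maximum, so an optimizer can still slide off it.

```python
import numpy as np

# Classic saddle point: f(x, y) = x**2 - y**2 at the origin.
def f(x, y):
    return x**2 - y**2

def grad(x, y):
    return np.array([2 * x, -2 * y])

print(grad(0.0, 0.0))             # [ 0. -0.] -> the gradient vanishes at the origin
print(f(0.1, 0.0), f(0.0, 0.1))   # 0.01 and -0.01: uphill along x, downhill along y

# The Hessian has one positive and one negative eigenvalue, so there is always
# a downhill direction for SGD to slide along -- unlike a true local minimum.
hessian = np.array([[2.0, 0.0],
                    [0.0, -2.0]])
print(np.linalg.eigvals(hessian))  # [ 2. -2.]
```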

One implementation detail I’d like to understand better is the split-ReLU activation used with complex-valued layers. If you have experience with split-ReLU, let me know.
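For context, a “split” activation in the complex-valued network literature usually means applying an ordinary real activation to the real and imaginary parts independently. Here is a minimal sketch under that assumption; I’m assuming, not asserting, that this is the exact variant CoShNet uses.

```python
import torch

def split_relu(z: torch.Tensor) -> torch.Tensor:
    """Split-ReLU as commonly defined for complex-valued networks:
    apply ReLU to the real and imaginary parts independently."""
    return torch.complex(torch.relu(z.real), torch.relu(z.imag))

z = torch.tensor([1.0 - 2.0j, -0.5 + 3.0j, -1.0 - 1.0j])
print(split_relu(z))   # tensor([1.+0.j, 0.+3.j, 0.+0.j])
```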

Both of these properties act in the same direction- they allow the network to achieve more with much less.

There is one final property that deserves its own section. Time to get into Phase Congruency and how it helps in adversarial robustness.

Phase Congruency

In electronic signaling, phase describes the position of a point in time (an instant) on a waveform cycle. Phase can also express the relative displacement between two or more waves of the same frequency (source). This video provides a visual representation. Phases are very important in signal processing.

If the phase can stay stable after perturbation, then we can extract stable features. This aligns well with the analysis in the MIT paper I shared earlier on why perturbation happens. “CoShRem can extract stable features — edges, ridges and blobs — that are contrast invariant. In Fig 6.b we can see a stable and robust (immune to noise and contrast variations) localization of critical features in an image by using agreement of phase.”

Gradients fluctuate wildly across scale but phase remains very stable at critical parts of the image. This makes phase a great base for detecting important features.
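Here is a small NumPy check of the contrast-invariance half of that claim, using plain Fourier phase as a simplified stand-in for the complex shearlet phase that CoShRem actually computes: rescaling an image’s contrast rescales its gradients but leaves the phase of every frequency component untouched.

```python
import numpy as np

# Toy check of the contrast-invariance argument, with plain Fourier phase
# standing in for the complex shearlet phase used by CoShRem.
rng = np.random.default_rng(0)
image = rng.random((64, 64))
low_contrast = 0.2 * image          # same content, 5x less contrast

def mean_gradient(img):
    return np.abs(np.gradient(img)[0]).mean()

def unit_phasors(img):
    # exp(i * phase) for every Fourier coefficient; comparing phasors
    # sidesteps the +/- pi wrap-around of raw angles.
    return np.exp(1j * np.angle(np.fft.fft2(img)))

# Gradient magnitudes shrink with the contrast change...
print(mean_gradient(low_contrast) / mean_gradient(image))            # ~0.2
# ...but the phase of every frequency component stays put.
print(np.allclose(unit_phasors(image), unit_phasors(low_contrast)))  # True
```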

When it comes to detecting features (and their magnitudes) in perturbed images, this works very well.

“Fig 4 shows despite the considerable perturbations (blurring and Gaussian noise), CoShRem remain stable to most of the characteristic edges and ridges (two step discontinuity in close proximity).”

This phase congruency works wonders in creating models that are robust. I would be interested in seeing how this performance stacks up against more specialized adversarial networks (like the One Pixel Attack). That would be a true test of robustness.

I would like to talk more about this, but a lot of this is related to signal processing. And I know nothing about that. I know enough Math to look through and understand the major ideas/derivations but I’m not fully confident I understand some of the details about phase and complex wavelets. If you have any experiences/resources on this topic leave them in the comments. I’d love to learn from you.

I’ll be looking more into complex functions and analysis after this paper, because it seems extremely powerful. Expect a follow-up with more details/ideas on how complex functions might be usable in networks. If you have any questions/clarifications, you can reach out to Manny Ko. He is a Principal Engineer at Apple and one of the authors of this paper. He shared this writeup with me, and definitely knows more than me about this subject.

If you liked this write-up, you would like my daily email newsletter Technology Made Simple. It covers topics in Algorithm Design, Math, AI, Data Science, Recent Events in Tech, Software Engineering, and much more to make you a better developer. I am currently running a 20% discount for a WHOLE YEAR, so make sure to check it out. Using this discount will drop the prices-

800 INR (10 USD) → 533 INR (8 USD) per Month

8000 INR (100 USD) → 6400 INR (80 USD) per year

You can learn more about the newsletter here

Reach out to me

Use the links below to check out my other content, learn more about tutoring, or just to say hi. Also, check out the free Robinhood referral link. We both get a free stock (you don’t have to put any money), and there is no risk to you. So not using it is just losing free money.

To help me understand you, fill out this survey (anonymous)

Check out my other articles on Medium: https://rb.gy/zn1aiu

My YouTube: https://rb.gy/88iwdd

Reach out to me on LinkedIn. Let’s connect: https://rb.gy/m5ok2y

My Instagram: https://rb.gy/gmvuy9

My Twitter: https://twitter.com/Machine01776819

If you’re looking to build a career in tech: https://codinginterviewsmadesimple.substack.com/

Get a free stock on Robinhood: https://join.robinhood.com/fnud75

