Deep Learning with Julia

A brief tutorial on training a Neural Network with Flux.jl

DSB
Coffee in a Klein Bottle
Jan 30, 2021

Flux.jl is the most popular Deep Learning framework in Julia. It provides a very elegant way of programming Neural Networks. Unfortunately, since Julia is still not as popular as Python, there aren’t as many tutorials on how to use it. Julia is also evolving very fast, so things can change a lot in a short amount of time.

I’ve been trying to learn Flux.jl for a while, and I realized that most tutorials out there are actually outdated. So this is a brief updated tutorial.

1. What we are going to build

The goal of this tutorial is to build a simple classification Neural Network. This is enough to get anyone interested in Flux started. After learning the very basics, the rest is mostly a matter of altering network architectures and loss functions.

2. Generating our Dataset

Instead of importing data from somewhere, let’s do everything self-contained. Hence, we write two auxiliary functions to generate our data:
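As a minimal sketch, the two generators could look like the following. The shapes of the two point clouds are an assumption made for illustration: the "real" points follow a noisy parabola, and the "fake" points fill a small disk sitting in the middle of it. The names generate_real_data, generate_fake_data and train_size are likewise illustrative.

```julia
using Random

# "Real" points: a noisy parabola x2 ≈ 3*x1^2 with x1 in [-0.5, 0.5).
function generate_real_data(n)
    x1 = rand(Float32, 1, n) .- 0.5f0
    x2 = 3f0 .* x1 .^ 2 .+ 0.1f0 .* randn(Float32, 1, n)
    return vcat(x1, x2)          # 2×n matrix, one column per sample
end

# "Fake" points: uniform angle, radius below 1/3, centred at (0, 0.5).
function generate_fake_data(n)
    θ = 2f0 * Float32(π) .* rand(Float32, 1, n)
    r = rand(Float32, 1, n) ./ 3f0
    x1 = r .* cos.(θ)
    x2 = r .* sin.(θ) .+ 0.5f0
    return vcat(x1, x2)
end

Random.seed!(42)                 # fixed seed so reruns give the same data
train_size = 5000
real = generate_real_data(train_size)
fake = generate_fake_data(train_size)
```

Each sample is stored as a column, which is the layout Flux layers expect (features × samples).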

Visualizing the Dataset

3. Creating the Neural Network

Creating Neural Network architectures with Flux.jl is very direct and clean (cleaner than in any other library I know). Here is how you do it:
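A minimal sketch of a small 2-25-1 network, assuming the positional Dense(in, out, activation) constructor from the Flux v0.12 series; the wrapper name NeuralNetwork is illustrative:

```julia
using Flux

# 2 inputs -> 25 hidden units (relu) -> 1 output squashed into (0, 1) by sigmoid.
function NeuralNetwork()
    return Chain(
        Dense(2, 25, relu),
        Dense(25, 1, sigmoid),
    )
end

m = NeuralNetwork()
probs = m(rand(Float32, 2, 5))   # 1×5 matrix, one probability per input column
```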

The code is very self-explanatory. The first layer is a dense layer with 2 inputs, 25 outputs and relu as its activation function. The second is a dense layer with 25 inputs, 1 output and a sigmoid activation function. The Chain ties the layers together. Yeah, it’s that simple.

4. Training our Model

Next, let’s prepare our model to be trained.
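Here is a sketch of that preparation step, written against the Flux v0.12 API (newer Flux releases may expose DataLoader and the optimizers differently). The small random real and fake matrices are hypothetical stand-ins so the snippet runs on its own; labels are 1 for real and 0 for fake.

```julia
using Flux

# Stand-in data (hypothetical): two 2×n point clouds, one per class.
train_size = 500
real = rand(Float32, 2, train_size)
fake = rand(Float32, 2, train_size) .- 1f0

# One dataset: features as columns, labels as a 1×2n row (1 = real, 0 = fake).
X = hcat(real, fake)
Y = hcat(ones(Float32, 1, train_size), zeros(Float32, 1, train_size))
data = Flux.Data.DataLoader((X, Y), batchsize=100, shuffle=true)

# Model, optimizer (plain gradient descent) and loss (binary cross-entropy).
m = Chain(Dense(2, 25, relu), Dense(25, 1, sigmoid))
opt = Descent(0.05)
loss(x, y) = Flux.Losses.binarycrossentropy(m(x), y)
```

The learning rate 0.05 and batch size 100 are arbitrary choices for this sketch.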

In the code above, we first organize our data into a single dataset. We use the DataLoader from Flux, which creates the batches and shuffles our data. Then we instantiate our model and define the loss function and the optimization algorithm. In this example, we use gradient descent for optimization and cross-entropy as the loss function.

Everything is ready, and we can start training the model. Here, I’ll show two ways of doing it.

Training Method 1
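A minimal sketch of this first method, using the implicit-parameter style of Flux v0.12 (later releases may use a different training interface, so treat the exact calls as version-specific). The data here are small random stand-ins, and the epoch count is an arbitrary choice.

```julia
using Flux, Statistics

# Stand-in data (hypothetical): class 1 = real, class 0 = fake.
train_size = 500
real = rand(Float32, 2, train_size)
fake = rand(Float32, 2, train_size) .- 1f0
X = hcat(real, fake)
Y = hcat(ones(Float32, 1, train_size), zeros(Float32, 1, train_size))
data = Flux.Data.DataLoader((X, Y), batchsize=100, shuffle=true)

m = Chain(Dense(2, 25, relu), Dense(25, 1, sigmoid))
opt = Descent(0.05)
loss(x, y) = Flux.Losses.binarycrossentropy(m(x), y)

ps = Flux.params(m)                    # every trainable array in the model
epochs = 20
for epoch in 1:epochs
    Flux.train!(loss, ps, data, opt)   # one pass over all batches
end

# Mean predicted probability for each class.
println(mean(m(real)), " ", mean(m(fake)))
```

To train only part of the network, one could collect a subset of the parameters instead, for example Flux.params(m[2]) for the last layer alone.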

In this code, first we declare what parameters are going to be trained, which is done using the Flux.params() function. The reason for this is that we can choose not to train a layer in our network, which might be useful in the case of transfer learning. Since in our example we are training the whole model, we just pass all the parameters to the training function.

Other than this, there is not much to be said. The final line of code just prints the mean prediction probability our model is giving.

Training Method 2
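A sketch of a hand-written training loop, again assuming the Flux v0.12 implicit-parameter API and small random stand-in data. The helper name train_model! is illustrative.

```julia
using Flux, Statistics

# Stand-in data (hypothetical): class 1 = real, class 0 = fake.
train_size = 500
real = rand(Float32, 2, train_size)
fake = rand(Float32, 2, train_size) .- 1f0
X = hcat(real, fake)
Y = hcat(ones(Float32, 1, train_size), zeros(Float32, 1, train_size))
data = Flux.Data.DataLoader((X, Y), batchsize=100, shuffle=true)

m = Chain(Dense(2, 25, relu), Dense(25, 1, sigmoid))
opt = Descent(0.05)
loss(x, y) = Flux.Losses.binarycrossentropy(m(x), y)

function train_model!(loss, ps, data, opt; epochs=20)
    for epoch in 1:epochs
        for d in data                  # d is one (x, y) batch
            gs = gradient(ps) do       # gradients of the batch loss w.r.t. ps
                loss(d...)
            end
            Flux.Optimise.update!(opt, ps, gs)
        end
    end
end

train_model!(loss, Flux.params(m), data, opt)
@show mean(m(real)) mean(m(fake))
```

Writing the loop by hand leaves room for things like logging the loss per batch or stopping early, at the cost of a few extra lines.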

This method is a bit more convoluted, because we are doing the training “manually” instead of using the training function given by Flux. This is interesting since one has more control over the training, which can be useful for more customized training methods. Perhaps the most confusing part of the code is the call to gradient.

The gradient function receives the parameters with respect to which it computes the gradient, and applies this to the loss function, which is evaluated on the batch d. The splat operator (the three dots) is just a neat way of passing x and y to the loss function. Finally, the update! function adjusts the parameters according to the gradients, which are stored in the variable gs.
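To make the three pieces concrete, here is a single training step in isolation. It is a sketch assuming the Flux v0.12 API; the batch d is a random stand-in rather than real data.

```julia
using Flux

# Stand-in batch (hypothetical): 100 points with random 0/1 labels.
x = rand(Float32, 2, 100)
y = Float32.(rand(Bool, 1, 100))
d = (x, y)

m = Chain(Dense(2, 25, relu), Dense(25, 1, sigmoid))
opt = Descent(0.05)
loss(x, y) = Flux.Losses.binarycrossentropy(m(x), y)

ps = Flux.params(m)
w_before = copy(first(ps))        # snapshot of the first parameter array

# loss(d...) unpacks the tuple, so it is the same call as loss(d[1], d[2]).
gs = gradient(() -> loss(d...), ps)

# gs holds one gradient array per parameter array, looked up by the parameter itself.
for p in ps
    println(size(p), " => ", size(gs[p]))
end

# In-place step; for Descent(0.05) this amounts to p .-= 0.05 .* gs[p].
Flux.Optimise.update!(opt, ps, gs)
```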

5. Visualizing the Results

Finally, the model is trained, and we can visualize its performance against the dataset.
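One way to draw such a picture, assuming Plots.jl is installed, is to color each point by the predicted probability. The helper plot_predictions is hypothetical, and the predict function below is only a stand-in so the snippet runs on its own; in practice one would pass the trained model m instead.

```julia
using Plots

# `predict` is any callable mapping a 2×n matrix to a 1×n matrix of
# probabilities, for example a trained Flux model.
function plot_predictions(predict, real, fake; n=100)
    plt = scatter(real[1, 1:n], real[2, 1:n];
                  zcolor=vec(predict(real[:, 1:n])), clims=(0, 1), legend=false)
    scatter!(plt, fake[1, 1:n], fake[2, 1:n];
             zcolor=vec(predict(fake[:, 1:n])))
    return plt
end

# Hypothetical stand-ins for the data and the trained model.
real = rand(Float32, 2, 200)
fake = rand(Float32, 2, 200) .- 1f0
predict(x) = 1 ./ (1 .+ exp.(-4 .* sum(x, dims=1)))

plt = plot_predictions(predict, real, fake)
```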

Neural Network prediction against the training dataset

Note that our model is performing quite well. It classifies the points in the middle with probability close to 0, implying that they belong to the “fake data”, while the rest have probability close to 1, meaning that they belong to the “real data”.
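To put a number on "quite well", one option is to threshold the probabilities at 0.5 and count correct predictions. This is a small sketch with a hypothetical accuracy helper; the probabilities below are made up and stand in for m(real) and m(fake).

```julia
# Fraction of samples on the correct side of the threshold:
# real samples should score above it, fake samples at or below it.
function accuracy(p_real, p_fake; threshold=0.5)
    correct = count(>(threshold), p_real) + count(<=(threshold), p_fake)
    return correct / (length(p_real) + length(p_fake))
end

# Hypothetical probabilities, standing in for m(real) and m(fake).
p_real = [0.97 0.97 0.40 0.88]
p_fake = [0.05 0.10 0.61 0.02]

acc = accuracy(p_real, p_fake)   # 6 of 8 correct, i.e. 0.75
```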

6. Conclusion

That’s all for our brief introduction. Hopefully this is the first article in a series on how to do Machine Learning with Julia.

Note that this tutorial is focused on simplicity, and not on writing the most efficient code. To learn how to improve performance, look here.

TL;DR
Here is the code with everything put together. You can also find it on my GitHub repository.

This article was last updated on October 2nd 2021. The code here was tested on Julia v1.6.3 and Flux v0.12.7.
