An Introduction to Training Neural Networks

Ana Lozada Behaine
May 17, 2024


Welcome! If you are new to building neural networks, you will enjoy this basic introduction to machine learning. It's a hands-on tutorial that walks you through a classic example: fitting a neural network to the MNIST dataset with TensorFlow and Keras.

Packages

First, we import the necessary packages. TensorFlow is a machine learning framework developed by Google. Keras, originally a separate library, was integrated into TensorFlow (as tf.keras) to simplify building the layers of a neural network. We import these libraries along with matplotlib for plotting and NumPy, a numerical computing library for Python.
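The original import cell isn't shown. A minimal version, assuming TensorFlow 2.x where Keras is reached through the `tf.keras` namespace, might look like this:

```python
import matplotlib.pyplot as plt  # plotting the digit images
import numpy as np  # numerical arrays (argmax, softmax, accuracy)
import tensorflow as tf  # TensorFlow; Keras is available as tf.keras

# Print the installed version so results can be compared across setups.
print("TensorFlow version:", tf.__version__)
```

Keras layers, models, losses, and datasets are then accessed as `tf.keras.layers`, `tf.keras.Sequential`, `tf.keras.losses`, and `tf.keras.datasets`.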

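Neural networks are universal function approximators: with enough hidden units, a network can approximate any continuous function on a bounded input domain to arbitrary accuracy, so in principle it can learn a mapping from inputs such as images to outputs such as class labels. In our case, we will do image classification on the MNIST dataset of handwritten digits. First, let's load the dataset.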
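Keras ships a loader for MNIST, so loading takes one call. This sketch uses `tf.keras.datasets.mnist.load_data()`; the scaling of pixels to [0, 1] is an assumed preprocessing step, since the original cell isn't shown.

```python
import tensorflow as tf

# load_data() downloads MNIST on first use and caches it locally.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# 60,000 training and 10,000 test images, each 28x28 grayscale (uint8, 0-255).
print(x_train.shape, y_train.shape)  # (60000, 28, 28) (60000,)
print(x_test.shape, y_test.shape)    # (10000, 28, 28) (10000,)

# Assumed preprocessing: scale pixel intensities from [0, 255] to [0, 1].
x_train = x_train / 255.0
x_test = x_test / 255.0
```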

Visualize the Data

Now let's visualize the first training example, which is the digit 5.
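A sketch of the plotting step with matplotlib's `imshow`, reloading the data so the snippet runs on its own:

```python
import matplotlib.pyplot as plt
import tensorflow as tf

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

# Draw the first 28x28 training image in grayscale, titled with its label.
plt.imshow(x_train[0], cmap="gray")
plt.title(f"label: {y_train[0]}")
plt.axis("off")
plt.show()
```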


Train the model

We train the model as a supervised classifier, using the digit labels y as targets. Just for fun, we set up the model to return the 10 final output logits (one raw score per digit, before any softmax) so we can have a look at them.
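The original model code isn't shown, so this is only a sketch under stated assumptions. The hidden layer sizes (128 and 64 ReLU units), the Adam optimizer and its learning rate, and the [0, 1] pixel scaling are hypothetical choices, and `load_mnist` and `build_model` are illustrative helper names. The parts that follow from the text are the 10-unit linear output layer that returns logits, paired with `SparseCategoricalCrossentropy(from_logits=True)` so the softmax is applied inside the loss, and a compile step with no metrics, which matches the loss-only training log.

```python
import tensorflow as tf


def load_mnist():
    """Load MNIST and scale pixels to [0, 1] (assumed preprocessing)."""
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    return (x_train / 255.0, y_train), (x_test / 255.0, y_test)


def build_model():
    """Hypothetical architecture; only the 10-logit output follows the text."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),  # 28x28 image -> 784-vector
        tf.keras.layers.Dense(128, activation="relu"),  # assumed size
        tf.keras.layers.Dense(64, activation="relu"),   # assumed size
        # Linear output: the model returns 10 raw logits, one per digit.
        tf.keras.layers.Dense(10, activation="linear"),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        # from_logits=True applies softmax inside the loss, which is more
        # numerically stable than putting a softmax in the output layer.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    return model
```

To train, run `(x_train, y_train), _ = load_mnist()`, then `model = build_model()` and `model.fit(x_train, y_train, epochs=25)`. With Keras's default batch size of 32, 60,000 / 32 = 1,875 steps per epoch, which matches the step count in the log; the exact loss values will differ under these assumed settings.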

Epoch 1/25
1875/1875 [==============================] - 7s 4ms/step - loss: 0.8818
Epoch 2/25
1875/1875 [==============================] - 7s 3ms/step - loss: 0.2277
Epoch 3/25
1875/1875 [==============================] - 7s 3ms/step - loss: 0.1578
...
Epoch 25/25
1875/1875 [==============================] - 7s 4ms/step - loss: 0.0569

Now that the model is trained, let's take a look at those output logits. There is one logit per digit class, 0 through 9, and argmax returns the index of the largest logit, which is the predicted digit.
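The argmax step itself needs nothing but NumPy. This sketch applies it to the 10 logits printed in the output, assuming they belong to the first training image; `softmax` is a small helper written here, not a library function. It also shows that converting logits to probabilities does not change which class wins.

```python
import numpy as np

# Logits copied from the printed model output. With a trained Keras model
# they would come from something like: logits = model.predict(x_train[0:1])
logits = np.array([[99.14376, 13.202532, 60.50213, 114.660706, -283.11496,
                    127.66067, 104.22364, -25.091702, 99.30108, 89.00973]])


def softmax(z):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


probs = softmax(logits)
predicted = np.argmax(logits, axis=1)  # index of the largest logit per row

print(f"the model classifies this as a {predicted[0]}")  # → the model classifies this as a 5
print("probability of class 5:", float(probs[0, 5]))
```

Because softmax is monotonic, `np.argmax(probs, axis=1)` gives the same answer as `np.argmax(logits, axis=1)`; the probabilities are useful only when you want a confidence score.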

1/1 [==============================] - 0s 16ms/step
[[ 99.14376 13.202532 60.50213 114.660706 -283.11496 127.66067
104.22364 -25.091702 99.30108 89.00973 ]]
the model classifies this as a 5

Finally, let's evaluate the model's prediction accuracy on the test data. Because the model was compiled with only a loss (the training log reports no accuracy), model.evaluate would return just the loss value. Instead, we call model.predict to get the logits, convert them into probabilities with a softmax, and take the argmax to extract the yhat classifications.
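The accuracy calculation can be sketched in NumPy. `accuracy_from_logits` is a hypothetical helper name, and the toy logits and labels below are made up purely to check the arithmetic; with the trained model you would pass `model.predict(x_test)` and `y_test`.

```python
import numpy as np


def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-probability class equals the label."""
    # Softmax turns logits into probabilities (stable form: subtract row max).
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)
    # yhat: the predicted class for each row.
    yhat = np.argmax(probs, axis=1)
    return float(np.mean(yhat == labels))


# Toy check: 4 examples, 3 classes. Predictions are 0, 2, 1, 2 and the
# labels are 0, 2, 1, 0, so 3 of 4 are correct.
toy_logits = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.3, 3.0],
                       [1.0, 4.0, 0.0],
                       [0.0, 0.0, 5.0]])
toy_labels = np.array([0, 2, 1, 0])
print("toy accuracy", accuracy_from_logits(toy_logits, toy_labels))  # → toy accuracy 0.75

# With the trained model (not defined in this snippet):
#   logits = model.predict(x_test)
#   print("test accuracy", accuracy_from_logits(logits, y_test))
```

Alternatively, compiling the model with `metrics=["accuracy"]` should let `model.evaluate(x_test, y_test)` report accuracy directly, since the argmax of the logits equals the argmax of the softmax probabilities.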

313/313 [==============================] - 1s 2ms/step
test accuracy 0.9764

The model achieves about 97.6% classification accuracy on the held-out test set!

If you made it to the end, congratulations on completing this introduction to modeling with neural networks.

This tutorial was adapted from a lecture in Andrew Ng’s class on Advanced Learning Algorithms.
