A Web App to Recognize Handwritten Digits — in 19 Lines of Python

Abubakar Abid
Published in The Startup
3 min read · Jul 13, 2020

Have you ever trained a machine learning model that you’ve wanted to share with the world? Maybe you wanted to set up a simple website where you (and your users) could try their own inputs and see the model’s predictions? It’s easier than you might think!

In this tutorial, I’m going to show you how to train a machine learning model to recognize handwritten digits using the TensorFlow library, and then create a web-based GUI that shows predictions from that model. You (or your users) will be able to draw arbitrary digits in a browser and see the predictions update in real time. For the web-based GUI, we’ll be using the Gradio library.

The kicker is that we’ll be able to do it in just 19 lines of code (including import statements!). Let’s get started!

Installation

If you don’t already have TensorFlow, follow the official installation instructions on the TensorFlow website. Either TensorFlow 1 or TensorFlow 2 will work fine. If you don’t already have Gradio, install it by running pip install gradio.

Loading the MNIST dataset

The first step is to load a dataset of handwritten digits (MNIST), along with the corresponding labels. We can do this with the TensorFlow library: we download the data and rescale it so that it’s ready to serve as input to the neural network.

I’ve also imported the Gradio library, which we’ll use later. That’s 5 lines of code down, 14 more to go.
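TensorFlow ships MNIST through tf.keras.datasets.mnist, whose load_data() returns the images as uint8 arrays of shape (N, 28, 28) with pixel values from 0 to 255. Assuming the usual divide-by-255 rescaling, the step looks roughly like the NumPy sketch below. The deterministic array is a hypothetical stand-in for real MNIST images, and rescale is a hypothetical helper name, not part of any library.

```python
import numpy as np

def rescale(images: np.ndarray) -> np.ndarray:
    """Map uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    return images.astype(np.float32) / 255.0

# Hypothetical stand-in for 4 grayscale 28x28 MNIST images:
# cycles through every pixel value 0..255 so both extremes appear.
fake_batch = (np.arange(4 * 28 * 28) % 256).astype(np.uint8).reshape(4, 28, 28)

scaled = rescale(fake_batch)
print(scaled.shape, scaled.dtype)  # (4, 28, 28) float32
print(scaled.min(), scaled.max())  # 0.0 1.0
```

Casting to float32 before dividing is a choice made in this sketch. Dividing the uint8 array by 255.0 directly also works, but it produces float64, which uses twice the memory.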

Let’s train a model

Now, let’s train a simple neural network. We can do this quite easily with the Keras submodule within TensorFlow (tf.keras). In the first 5 lines, we define the architecture of our model, which consists of a single hidden layer with 128 neurons. We then “compile” the model, choosing an appropriate loss function and optimizer as well as the metrics to display during training.
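To make the architecture concrete, here is a NumPy sketch of the forward pass such a network computes. It assumes the common layout of a flatten step, a 128-unit ReLU hidden layer, and a 10-unit softmax output (one unit per digit). The random weights are placeholders for what training learns; this illustrates the math and is not the tutorial’s Keras code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed layer sizes: 28x28 input pixels, 128 hidden units, 10 digit classes.
n_in, n_hidden, n_out = 28 * 28, 128, 10

# Placeholder weights; in the real model these are learned during training.
W1 = rng.normal(0, 0.05, size=(n_in, n_hidden)).astype(np.float32)
b1 = np.zeros(n_hidden, dtype=np.float32)
W2 = rng.normal(0, 0.05, size=(n_hidden, n_out)).astype(np.float32)
b2 = np.zeros(n_out, dtype=np.float32)

def softmax(z: np.ndarray) -> np.ndarray:
    # Subtract each row's max for numerical stability before exponentiating.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(images: np.ndarray) -> np.ndarray:
    x = images.reshape(len(images), -1)  # flatten: (N, 28, 28) -> (N, 784)
    h = np.maximum(0, x @ W1 + b1)       # hidden layer with ReLU activation
    return softmax(h @ W2 + b2)          # class probabilities, shape (N, 10)

batch = rng.random((3, 28, 28), dtype=np.float32)  # 3 fake rescaled images
probs = forward(batch)
print(probs.shape)                          # (3, 10)
print(np.allclose(probs.sum(axis=1), 1.0))  # True

n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)                             # 101770
```

For this assumed layout, the 101,770 parameters break down as 784 × 128 + 128 for the hidden layer and 128 × 10 + 10 for the output layer, which is the count model.summary() reports for the equivalent Keras model.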

With the final line, we train the model on our dataset. If you run this, you should see training output showing that the model reaches roughly 97% accuracy on the validation dataset.
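For a classifier like this one, the reported accuracy is essentially the fraction of examples whose most probable predicted class matches the true label. The NumPy sketch below shows that computation; the probabilities and labels are made-up illustrative data, not output from the model, and accuracy is a hypothetical helper name.

```python
import numpy as np

def accuracy(probs: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of rows whose highest-probability class matches the true label."""
    predicted = probs.argmax(axis=1)
    return float((predicted == labels).mean())

# Hypothetical predictions for 4 images over 10 classes (each row sums to 1).
probs = np.full((4, 10), 0.01)
probs[0, 7] = probs[1, 2] = probs[2, 1] = probs[3, 4] = 0.91
labels = np.array([7, 2, 1, 0])  # the last image is actually a 0

print(accuracy(probs, labels))  # 0.75
```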

We could certainly train a better model, perhaps by using a convolutional neural network and training for more epochs, but we won’t worry about that in this tutorial.

Let’s create a GUI!

Now it’s time for the fun part: creating a user interface around our model. The Gradio library requires you to define 3 things: a prediction function, an input UI component, and an output UI component. In our case, we can use the built-in Sketchpad component for the input and the Label component for the output (we’ll set up the Label to show the top 3 classes).
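Gradio’s Label component accepts a dictionary that maps class names to confidences, so the prediction function’s job is to turn the model’s 10 output probabilities into such a dictionary. The pure-Python sketch below shows that conversion and the top-3 selection the Label displays. In the real app the probabilities would come from the trained model’s prediction on the sketched image; here they are a hand-written stand-in, and both helper names are hypothetical.

```python
def to_label_dict(probs: list[float]) -> dict[str, float]:
    """Map each digit ("0".."9") to the model's confidence for it."""
    return {str(digit): float(p) for digit, p in enumerate(probs)}

def top_k(confidences: dict[str, float], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most confident (label, confidence) pairs, highest first."""
    return sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:k]

# Hypothetical softmax output for one drawing (list index = digit).
probs = [0.01, 0.02, 0.05, 0.70, 0.01, 0.12, 0.01, 0.03, 0.04, 0.01]
confidences = to_label_dict(probs)
print(top_k(confidences))  # [('3', 0.7), ('5', 0.12), ('2', 0.05)]
```

Because the Label can be configured to show only the top classes, the prediction function itself only needs to return the full dictionary; top_k here just makes visible what the interface ends up displaying.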


When creating the interface, we also pass live=True, which allows us to get real-time predictions from our model, and capture_session=True, which is needed for backwards compatibility with TensorFlow 1 (if you’re using TensorFlow 2, it doesn’t hurt to leave this argument in).

Let’s launch it!

Finally, we’re ready to launch our interface, which takes a single call to its launch() method. No creativity involved in this step!

Running this, you should see the GUI pop up on a local server (if you’re running locally) or on a public port (e.g. if you’re running within a Colab notebook). Once you open it, you’ll see a UI where you can draw digits and start making predictions. You can also explicitly pass share=True to launch() to create a public link you can share with friends and family.

Have fun experimenting! If you’d like to run this code right now, check out the accompanying Colab notebook, which contains all of the code.

More about Gradio: if you want to build more cool web applications and GUIs, read about Gradio here: https://github.com/gradio-app/gradio