Handwritten Character Recognition Web App with EMNIST
Long before the ancient Egyptians created hieroglyphs 𐦚, humans were jotting down their thoughts and ideas onto physical objects.
Fast forward a few millennia: with the advent of computers and the internet, almost everything we do involves these machines in one way or another. However, handwriting and note-taking are far from dead; on the contrary, they have been revitalized. Software such as Google Lens or Apple’s Live Text can recognize handwritten text and let the user copy it into a digital notepad. These are OCR-powered apps that use artificial intelligence for fast and accurate recognition.
For those of you who are on a journey in the field of Machine Learning, this is a guide for creating a web app for classifying handwritten digits and letters. Luckily for us, a standardized dataset has been created for this exact project, and it is called the EMNIST Dataset.
The EMNIST Dataset
The Extended MNIST Dataset, or EMNIST Dataset, is a set of handwritten letters and digits in a 28 by 28 pixel format. Derived from the MNIST Dataset, which is considered the go-to standard for machine learning benchmarks, the EMNIST dataset presents a greater challenge for ML models. Included in this dataset are 62 different classes, which comprise 26 + 26 upper and lower case letters, as well as the 10 digits. This dataset is readily available and easy to use, allowing us to create a decent ML model while spending minimal effort on preprocessing and formatting.
To get started, this dataset is available in 6 different splits:
- EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
- EMNIST ByMerge: 814,255 characters. 47 unbalanced classes.
- EMNIST Balanced: 131,600 characters. 47 balanced classes.
- EMNIST Letters: 145,600 characters. 26 balanced classes.
- EMNIST Digits: 280,000 characters. 10 balanced classes.
- EMNIST MNIST: 70,000 characters. 10 balanced classes.
Each split was created for different purposes and different models; the one that I used is the by_class split. It includes all 62 classes in the dataset (10 digits + 26 lowercase letters + 26 uppercase letters), along with the full breadth of images and data, without any reductions.
Convolutional Neural Networks
For this project, I used a Convolutional Neural Network, so before jumping into the code, a brief explanation of CNNs is required.
Convolutional Neural Networks, or CNNs for short, are a class of deep learning models that is extremely useful for Computer Vision. A CNN can “understand” the intricacies of complex images better than most other algorithms. There are 2 main parts of the CNN algorithm: Convolutions and Max Pooling.
Convolutions are essentially the CNN model detecting patterns in the image by passing many different types of filters over the image.
Here are examples of 3 different layers of filters that are passed over the image; they become more and more complex so as to detect more sophisticated patterns in the images. These are visual representations of the filters, but in actual training each one is just a matrix of numbers.
The main question to be asked is: how is a filter used to convolve an image? So, let me give you a short explanation…
Let’s say we have this 5 by 5 image and we have a 3 by 3 filter/kernel matrix here:
1 0 1
0 1 0
1 0 1
This filter matrix is then scanned across the image, and at each position the dot product of a filter-sized patch of the image and the kernel matrix is calculated (the gif above visually represents this process). Here, the dot product means multiplying the corresponding elements of the 2 matrices and then summing all the products. Each value is stored in a new matrix, which represents the convolved image.
This process is repeated many times with several different filters to pick up on all the features in the image.
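To make the sliding dot product concrete, here is a minimal NumPy sketch. The 3 by 3 kernel is the one shown above; the 5 by 5 image values are an assumption chosen purely for illustration, and `convolve2d` is a hypothetical helper name, not a library function. No padding and a stride of 1 are assumed as well.

```python
import numpy as np

def convolve2d(image, kernel):
    """Slide the kernel over the image (stride 1, no padding) and return
    the sum of element-wise products at every position.
    Note: like CNN libraries, this does not flip the kernel."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=image.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = image[i:i + kh, j:j + kw]    # filter-sized patch
            out[i, j] = np.sum(patch * kernel)   # multiply, then sum
    return out

# Assumed 5x5 example image (not taken from the article's figure).
image = np.array([
    [1, 1, 1, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 0],
    [0, 1, 1, 0, 0],
])

# The 3x3 kernel from the text.
kernel = np.array([
    [1, 0, 1],
    [0, 1, 0],
    [1, 0, 1],
])

feature_map = convolve2d(image, kernel)
print(feature_map)
# [[4 3 4]
#  [2 4 3]
#  [2 3 4]]
```

A 5 by 5 image and a 3 by 3 kernel give a 3 by 3 result, because the kernel only fits in 3 positions along each axis.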
CNNs do not use just a single filter in their training; in fact, it’s common for a CNN to learn from 32 to 512 filters in parallel for a single image. All these different filters extract different features, which allows the CNN to learn those features.
Another powerful trait is the ability to stack CNN layers on top of each other. This allows for a deeper decomposition of the image and enables the model to pick up on deeper and more complex patterns in subsequent layers.
The convolved image, or feature map, is a representation of the image with a certain feature extracted. Each filter that is applied to the image produces 1 new convolved image, so if 32 or even 64 filters are used, we end up with 32 or 64 times the amount of data. This is a HUGE amount, so to reduce the computation required and to increase efficiency, the convolved image is summarized and reduced in size. This process is called Max Pooling.
Furthermore, max pooling is another feature extraction method. By summarizing regions in the image, we are left with the sharpest or most prominent feature in each region. Max pooling is also beneficial because it gives the model a degree of translation invariance: a small shift in the input changes the pooled output very little. These benefits have led to the prominent use of this method in most CNN models.
In order to perform max pooling, a stride and a grid need to be specified. The grid is the pool size and the stride is the number of pixels by which the grid is going to slide across the image.
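As a rough illustration of the grid and the stride, here is a small NumPy sketch of max pooling. The 4 by 4 feature map values, the 2 by 2 pool size, the stride of 2, and the helper name `max_pool2d` are all assumptions made for this example.

```python
import numpy as np

def max_pool2d(feature_map, pool_size=2, stride=2):
    """Keep only the largest value inside each pool_size x pool_size window,
    moving the window `stride` pixels at a time."""
    out_h = (feature_map.shape[0] - pool_size) // stride + 1
    out_w = (feature_map.shape[1] - pool_size) // stride + 1
    out = np.empty((out_h, out_w), dtype=feature_map.dtype)
    for i in range(out_h):
        for j in range(out_w):
            top, left = i * stride, j * stride
            window = feature_map[top:top + pool_size, left:left + pool_size]
            out[i, j] = window.max()
    return out

# Assumed 4x4 feature map, for illustration only.
fm = np.array([
    [1, 3, 2, 1],
    [4, 6, 5, 0],
    [7, 2, 9, 8],
    [1, 0, 3, 4],
])

pooled = max_pool2d(fm, pool_size=2, stride=2)
print(pooled)
# [[6 5]
#  [7 9]]
```

With a 2 by 2 grid and a stride of 2, the windows do not overlap, so the feature map shrinks to a quarter of its original size.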
These are the 2 fundamental concepts that CNNs are built on. Now, let’s move on to the actual EMNIST Classifier.
Before jumping into the creation of an AI model, note that all the code and files can be found in my GitHub Repo here: https://github.com/PuravG/EMNIST-Classifier. Now, moving on to the model creation.
The first step in all AI models, or rather in all AI projects, is to import our libraries. TensorFlow is the framework that we are going to use, and it is one of the most popular tools for machine learning.
Loading the Dataset and Image Pre-Processing
From the EMNIST Dataset, I loaded the by_class split and pre-processed the data. The first step is to normalize the data, which means changing the range of the pixel values. Images generally have pixels with values from 0 to 255; we need to normalize these values and bring them into a range from 0 to 1. This is good practice in machine learning, and its benefits range from more efficient training to better model performance.
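The normalization step can be sketched in a couple of lines of NumPy. The tiny `raw_images` array below is a made-up stand-in for a batch of EMNIST images, and `normalize_images` is a hypothetical helper name.

```python
import numpy as np

def normalize_images(images):
    """Scale 8-bit pixel values from the range [0, 255] into [0, 1]."""
    return images.astype("float32") / 255.0

# Stand-in for a batch of one 2x2 image with 8-bit pixels.
raw_images = np.array([[[0, 64], [128, 255]]], dtype=np.uint8)

normalized = normalize_images(raw_images)
print(normalized.min(), normalized.max())  # 0.0 1.0
```

Casting to float before dividing matters: the pixels start out as 8-bit integers, and the model needs fractional values between 0 and 1.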
The next step is data augmentation, in which I change or augment the data to create more variance for the neural network. This allows for better generalization across data and increased performance. In our final web app the user will be drawing the character on a canvas with a mouse, so by performing data augmentation, the model will be able to recognize images that are not fully centred or straight, but are instead crooked or rotated.
The code snippet below is an illustrated example of how data augmentation can change or modify the image dataset.
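As a rough stand-in, here is a minimal NumPy sketch of one kind of augmentation: shifting an image by a few random pixels so that the character is no longer perfectly centred. The helper name `random_shift` and the maximum shift of 3 pixels are assumptions for this example. Small rotations and zooms are usually left to a library, for instance the `rotation_range` and `zoom_range` options of Keras's `ImageDataGenerator`.

```python
import numpy as np

def random_shift(image, max_shift=3, rng=None):
    """Return a copy of a 2D image moved by a random offset of up to
    max_shift pixels in each direction; uncovered pixels become 0."""
    rng = np.random.default_rng() if rng is None else rng
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    h, w = image.shape
    # Pad with zeros, then crop a window offset by (dy, dx).
    padded = np.pad(image, max_shift, mode="constant")
    top, left = max_shift - dy, max_shift - dx
    return padded[top:top + h, left:left + w].copy()

# Assumed sample: a 28x28 image with a bright square in the middle.
sample = np.zeros((28, 28), dtype="float32")
sample[10:18, 10:18] = 1.0

rng = np.random.default_rng(seed=0)
shifted = random_shift(sample, max_shift=3, rng=rng)
print(shifted.shape)  # (28, 28)
```

Each call produces a slightly different version of the same character, which is what gives the network more variety to learn from.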
Model Architecture & Training
A machine learning model has various interconnected layers that perform the computations. This is a Convolutional Neural Network and it has 2 Convolutional layers each followed by a Max Pooling layer. The Convolutional Layers are followed by 2 fully connected dense layers with 256 then 128 neurons. Finally, the model has an output layer with 62 nodes. Each node in the final layer represents a class of characters and the model will output 62 different values in the range [0,1], representing the probability that the image belongs to that class.
Next, the model is compiled and then trained for 20 epochs.
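The architecture described above could be sketched in Keras roughly as follows. Only the layer layout (2 convolution + max pooling pairs, dense layers of 256 and 128 neurons, and a 62-node output) comes from the text. The filter counts (32 and 64), the 3 by 3 kernels, the ReLU activations, the Adam optimizer and the sparse categorical cross-entropy loss (which assumes integer labels 0 to 61) are assumptions and may differ from the notebook in the repo.

```python
import tensorflow as tf

# Layer layout from the text; filter counts, kernel sizes and
# activations are assumptions for this sketch.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),               # 28x28 grayscale image
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Conv2D(64, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(62, activation="softmax"),  # one probability per class
])

# Assumed optimizer and loss; the loss expects integer class labels.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.summary()
```

Training for 20 epochs would then look something like `model.fit(train_images, train_labels, epochs=20, validation_data=(test_images, test_labels))`, where the array names are placeholders for the normalized EMNIST data.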
So our Final Accuracy is 86.75%! That is not a bad score for this dataset. Obviously, the question that you may be asking is: what happened to the other 13.25%? Well, keep reading to understand the drawbacks and the improvements that can be made to this model.
Model History Graphs
This graph looks good and the steady improvement over time is great. It’s now time to transfer these results into a tangible product.
The first method is to create a local notebook and predict images from your own system (already ahead of you 🙂; my prediction notebook is here). Yet, imagine having to run that whole notebook every time you need to predict an image. It’s neither elegant nor user friendly.
Introducing the web app…
Imagine an application that runs in your web browser and can be accessed by anyone from anywhere. This is a web app: it’s great for sharing and particularly intuitive for the user.
The first step in creating this app is to prep the model for the web.
Save model as tfjs format
Add this code snippet after you are done training to save your model as a .tfjs model.
We won’t be needing that saved file anytime soon, so let’s shift our focus to creating the web app.
HTML is the most basic building block of the web, as all websites are created using HTML. Think of HTML as the skeleton of a website: it creates a structure and holds everything together. Using HTML we can create a canvas element that will allow the user to draw the character directly in the app. This element acts as a drawing space, hence the name canvas.
Creating this canvas is extremely simple as it requires only 1 line of code.
So you have successfully created a canvas! Pretty easy right? Well, it gets a bit trickier.
Initializing the Canvas
The first step uses 2 methods: getElementById and getContext. getElementById fetches the canvas element from the page, and getContext returns the drawing context, which is the object used to actually draw on the canvas.
So now that you have a canvas ready, you need to use it. Responding to what the user does on the canvas requires something called an event listener. Event listeners wait in the background for a specific event, such as a mouse click or movement, and run a function whenever it occurs.
The if(canvas) condition checks that the canvas element was actually found. Only when the canvas is ready will the event listeners be attached. These event listeners listen for the different actions that the user can perform, and once one of those actions happens, the function registered in that listener is executed.
Getting the Position of the Cursor
To draw a line or any shape on the canvas, the program needs to know the exact position of the cursor at all times. To do that, we can create an object (JavaScript’s equivalent of a dictionary) to store the position of the cursor and then keep updating it as the cursor moves.
Drawing is a culmination of all the previous functions; it calls and uses them and the end result is the ability to draw on the canvas!
startPainting and stopPainting are the functions that are called in the event listeners. Moving on to the sketch function, if (!paint) return is a guard: if the paint flag is false (the user is not currently drawing), the function stops immediately and returns nothing, so drawing only happens while paint is true.
And there you have it — the ability to draw on a canvas in a web app! Now that we can draw, let’s go to the main purpose of this web app, which is predicting a character from a drawing that the user made.
Integrating the CNN Model
It is often said that 80% of an ML engineer’s time is spent collecting and processing the data, while only 20% is spent coding the AI model and creating predictions. So after completing all that grueling work of creating the website, the canvas and everything else, here we are at the last step: integration of the CNN model.
Remember that TensorFlow.js model that you saved for later? Well, it’s time to take it out now.
Loading the Model
The first requirement is to upload your saved TensorFlow.js model to a server such as GitHub. The program then accesses the file on GitHub and loads it using an async function.
The character that is drawn on the canvas has the dimensions 280 by 280. This image needs to be preprocessed and converted into the format the model requires before it can create a prediction. First, the tf.browser.fromPixels() method is used to create a tensor that can flow into the first input layer of the model. tf.image.resizeNearestNeighbor() resizes the image into the new shape of (28,28), so that the image matches the data the original model was trained on. tf.mean() averages the 3 colour channels, converting the coloured image into a grayscale image. The tf.toFloat() function casts the image array to type float. And finally, the tensor.div() function divides the array by the maximum RGB value of 255, effectively normalizing the image.
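To show what each of those steps does to the data, here is the same pipeline mirrored in NumPy. This is an illustrative re-creation in Python, not the TensorFlow.js code used in the app. The helper name `preprocess_canvas`, the white-on-black test drawing, and the final reshape to (1, 28, 28, 1) (a batch of one single-channel image) are assumptions.

```python
import numpy as np

def preprocess_canvas(rgb, target=28):
    """Mirror the browser pipeline: nearest-neighbour resize, average the
    colour channels, cast to float, and divide by 255."""
    step = rgb.shape[0] // target              # 280 // 28 = 10
    # Nearest-neighbour downscale: keep every `step`-th pixel.
    small = rgb[::step, ::step, :][:target, :target]
    gray = small.mean(axis=2)                  # 3 channels -> 1
    scaled = gray.astype("float32") / 255.0    # [0, 255] -> [0, 1]
    # Assumed model input layout: (batch, height, width, channels).
    return scaled.reshape(1, target, target, 1)

# Assumed stand-in for the canvas: a white vertical bar on black.
canvas = np.zeros((280, 280, 3), dtype=np.uint8)
canvas[100:180, 130:150] = 255

model_input = preprocess_canvas(canvas)
print(model_input.shape)  # (1, 28, 28, 1)
```

The order matters less than the end result: whatever reaches the model must look like the training data, which means 28 by 28 pixels, a single channel, and values between 0 and 1.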
Predictions and Output
All that is left is to create a function for the model to predict the image. Only when the predict button is clicked will the model first take the image from the canvas, preprocess it and then create its prediction on it.
The output is a dictionary of each class and its corresponding prediction. Therefore, to find the prediction that the model is most confident in, we must iterate through the dictionary to find the class with the highest confidence.
Once we find that class, the element on the web app is changed to the prediction and the confidence of the model (in percent) is displayed.
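The lookup itself is small enough to sketch in Python. Here the model output is represented as a plain array of 62 probabilities rather than a dictionary, and the class order is assumed to follow the standard EMNIST ByClass labelling (digits first, then uppercase, then lowercase letters). `decode_prediction` is a hypothetical helper name.

```python
import string
import numpy as np

# Assumed label order for the 62 ByClass outputs: 0-9, A-Z, a-z.
CLASS_LABELS = string.digits + string.ascii_uppercase + string.ascii_lowercase

def decode_prediction(probabilities):
    """Return the most likely character and the model's confidence in percent."""
    probabilities = np.asarray(probabilities, dtype="float64")
    best = int(np.argmax(probabilities))       # index of the highest score
    return CLASS_LABELS[best], float(probabilities[best]) * 100.0

# Made-up model output: 90% for class 10, 10% for class 22.
fake_output = np.zeros(62)
fake_output[10] = 0.9
fake_output[22] = 0.1

label, confidence = decode_prediction(fake_output)
print(f"{label} ({confidence:.1f}%)")  # A (90.0%)
```

In the web app, the same two values are what end up on the page: the predicted character and how sure the model is about it.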
Final Thoughts + Improvements
So you finally created a working and mostly accurate classifier of handwritten characters. After testing and playing with this web app, a pattern should start to come up. This model is not the best 😅… it’s not the greatest at classifying these characters. The final accuracy of the model was only 87%; while that is good for an ML model, it is not comparable to humans or the latest OCR software. So let’s break it down…
The root of this problem actually comes from the way ML models learn. An ML model learns by recognizing patterns in the data and associating those patterns with a specific label. In our case, the CNN extracts the features of the image and those features are what the model learns.
Take the letter ‘N’ for example: it has 2 parallel lines joined by 1 diagonal line. These are the types of features that the model would pick up on; however, the letter ‘M’ is very similar to that description. If you think about it, the letter ‘M’ is just the fusion of 2 ‘N’s. So the model gets confused between these 2 letters and mixes them up. Some other confusing patterns include N vs W vs M; 1 and I; L and 7; and so many others. Because all these characters are so similar, the model often gets these images wrong. However, letters like ‘B’ or ‘X’ have very distinctive patterns, resulting in an accurate prediction.
So how does handwriting detection software work then? For one, these systems are trained on far more powerful hardware, so they can use much deeper networks (15 or more CNN layers) and train for hundreds of epochs. Secondly, whole words or sentences are fed in as input; the context of a word in a sentence can greatly help with the prediction.
Don’t feel disheartened just yet, as there is a way to circumvent this problem. The solution is transfer learning! You can leverage and use those trained models that big companies or universities have already created and tweak them to your own situation. There are models that have been created by various university researchers which are open source and free for you to use!
Phew! That was a lot, but the main thing is that you finally made it to the end! This was a long, yet very fruitful 🍎 journey with its own set of challenges. I hope that this guide aided you along your journey in machine learning.
Once again here is the link for the Github Repo where you can find all the files and the code for this project. Thank you for reading and have a good one! 👋