Multiclass Text Classification From Start To Finish

So you have some text and you want to classify it. So you have multiple classes for your text and you want to classify it. Well, what are you waiting for?

I’ll be using Python and scikit-learn, and as always, my Jupyter notebooks can be found on GitHub along with the original dataset.


Text classification is a supervised learning technique, so we’ll need some labeled data to train our model. I’ll be using this public news classification dataset. It’s a manually labeled dataset of news articles, each of which fits into one of 4 classes: Business, SciTech, Sports or World.

This is what the dataset looks like:

Exploratory Data Analysis & Text Processing

Let’s look at how many articles we have per class:

All of the classes are perfectly balanced, which is something you will almost never find in the wild, so I will take a subsample of the Business and Sports categories to make the data imbalanced (i.e. more realistic). I’ll keep 1K articles from Business and 800 from Sports.

I’ll also hold out 5 articles from each category to use for predictions at the end, so we can evaluate how well the classifiers do on unseen data, which is the true test.
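A minimal pandas sketch of the hold-out and subsampling steps. It assumes the articles sit in a DataFrame with hypothetical `text` and `category` columns, and it uses a tiny synthetic frame with scaled-down sample sizes rather than the real dataset:

```python
import pandas as pd

# Synthetic stand-in for the news dataset; "text" and "category" are hypothetical column names.
df = pd.DataFrame({
    "text": [f"article {i}" for i in range(40)],
    "category": ["Business", "SciTech", "Sports", "World"] * 10,
})

# Hold out 5 articles per class for the final check on unseen data.
holdout = df.groupby("category").sample(n=5, random_state=42)
train_pool = df.drop(holdout.index)

# Downsample two classes to make the training data imbalanced
# (the article keeps 1K Business and 800 Sports; scaled down for the toy data).
target_sizes = {"Business": 4, "Sports": 3}
parts = [
    group.sample(n=target_sizes.get(category, len(group)), random_state=42)
    for category, group in train_pool.groupby("category")
]
train_df = pd.concat(parts)

print(train_df["category"].value_counts())
```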

Let’s visually inspect the JSON file:

Looks like we have some HTML in there that needs to be removed; we’ll take that out in the text processing step. Let’s look at the average word count by category:

Ok well those are definitely short articles lol. Let’s look at the word count distribution by category to see if anything is off there:

The word count distributions look pretty uneventful, and nothing crazy is going on between categories, which is good. I imagine the main challenge will be the shorter articles, especially in the sampled categories, but we’ll see. There’s just not a ton of information in articles with fewer than 20 words, but it is what it is, and each dataset brings its own surprises, so let’s roll with it.

Next we’ll do the text processing so we can look at the most frequent words/bigrams. The processed text will also be what we use to create the features. Text processing is extremely important, but I’m not going to go into detail about it here because this article is more about the classification task. In short, the text will be tokenized, lowercased and lemmatized, and punctuation, numbers and stop words will be removed. Contractions will also be expanded. Through the magic of TV:
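The processing steps above can be sketched roughly as follows. This is a simplified stand-in, not the notebook's code: the `preprocess` function and the tiny `CONTRACTIONS` map are hypothetical, stop words come from scikit-learn's built-in `ENGLISH_STOP_WORDS` list, HTML is stripped with a simple regex, and lemmatization is left out because it needs an external library such as NLTK or spaCy:

```python
import re

from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

# Hypothetical, deliberately tiny contraction map; order matters ("can't" before "n't").
CONTRACTIONS = {"can't": "can not", "won't": "will not", "it's": "it is", "n't": " not"}


def preprocess(text: str) -> list[str]:
    """Hypothetical helper: strip HTML, lowercase, expand contractions, tokenize, drop stop words."""
    text = re.sub(r"<[^>]+>", " ", text)          # strip HTML tags
    text = text.lower().replace("\u2019", "'")    # lowercase; normalize curly apostrophes
    for short, full in CONTRACTIONS.items():      # expand contractions
        text = text.replace(short, full)
    tokens = re.findall(r"[a-z]+", text)          # alphabetic runs only: drops punctuation and numbers
    # Lemmatization (e.g. with NLTK or spaCy) would go here; omitted to stay self-contained.
    return [tok for tok in tokens if tok not in ENGLISH_STOP_WORDS]


print(preprocess("<p>Stocks can't rally, it's 2024!</p>"))
```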

Now let’s look at the top words by category:

There seem to be distinct differences between the content of the categories, so that’s good. These also take bigrams into account (e.g. new_york). N-grams are sequences of n consecutive words, and the useful ones are words that frequently appear together. Bigrams are two words (e.g. artificial intelligence), trigrams are three words, etc. The words ‘artificial’ and ‘intelligence’ by themselves have different meanings than when they appear together.

Let’s look at the top bigrams now by category:

Again we see nice distinct differences in the content and the labels seem representative of what the content is.


Creating a good training set with correct labels can be a whole separate topic, and something I will not cover here in order to focus on the classification problem. In short though: garbage in, garbage out. It’s important to have good labels or you will never get the results you want.

The labeling here is obviously done for us so all we need to do is turn those text labels into numbers for the classifiers. Scikit-learn’s label encoder is helpful for this. You can also just map these with a dictionary but that may get tedious if you have a lot of categories.
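A short sketch of scikit-learn's `LabelEncoder` on the four category names. It sorts the classes alphabetically before numbering them, so Business becomes 0 and Sports becomes 2, which lines up with the class numbers used when evaluating the models:

```python
from sklearn.preprocessing import LabelEncoder

labels = ["Business", "SciTech", "Sports", "World", "Sports", "Business"]

le = LabelEncoder()
y = le.fit_transform(labels)   # classes are sorted alphabetically, then numbered 0..3
print(list(le.classes_))       # ['Business', 'SciTech', 'Sports', 'World']
print(y)                       # [0 1 2 3 2 0]

# Map numeric predictions back to readable labels.
print(le.inverse_transform([2, 0]))  # ['Sports' 'Business']
```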


We need to create features from the text. In order to do this, we need to turn the words into numbers because machines like numbers. The features will be created from the processed text, not the raw text.

There are a few ways to turn words into numbers. One is to simply create a matrix with the count of each word by article; you would use a count vectorizer for that. Another is to use tf-idf, which stands for ‘term frequency-inverse document frequency’. It sounds fancy but it’s pretty straightforward, and it’s what I’ll be using because I’ve had good results with it. There are also other approaches like word embeddings, so feel free to experiment.

The ‘term frequency’ is just the number of times a word appears in a document, divided by the total number of words in that document. The ‘inverse document frequency’ is the logarithm of the total number of documents divided by the number of documents the word appears in. Pretty straightforward.
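The definition above can be worked through by hand in plain Python. The `tf`, `idf` and `tfidf` helpers are illustrative names that follow the textbook formulas described here, applied to three tiny tokenized documents:

```python
import math

docs = [
    "rainbow light color rainbow".split(),
    "rainbow spectrum".split(),
    "stock market rally".split(),
]


def tf(term, doc):
    # Occurrences of the term divided by the number of words in the document.
    return doc.count(term) / len(doc)


def idf(term, docs):
    # log(total documents / documents containing the term).
    df = sum(1 for d in docs if term in d)
    return math.log(len(docs) / df)


def tfidf(term, doc, docs):
    return tf(term, doc) * idf(term, docs)


print(round(tfidf("rainbow", docs[0], docs), 4))  # 0.5 * ln(3 / 2)
print(round(tfidf("color", docs[0], docs), 4))    # 0.25 * ln(3)
```

Note that ‘rainbow’ appears twice in the first document but gets a lower weight than ‘color’, because it also appears in another document. Also note that scikit-learn's `TfidfVectorizer` does not use exactly this formula by default: it uses raw counts for the term frequency, a smoothed idf of ln((1 + n) / (1 + df)) + 1, and then L2-normalizes each row, so its numbers will differ from these.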

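Multiplying the two gives you a weight that is a good indicator of how important a word is to a document, relative to all the documents. That last part is important: a word that shows up in nearly every document gets a low idf and therefore a low weight, no matter how often it appears in a single article.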

For example, any article may contain the word rainbow, but the Wikipedia article about rainbows contains that word a lot. Therefore, if we’re looking for articles about rainbows (because face it, who isn’t?) out of all articles on the internet, that may be the one we want. It will also have a lot of other frequently occurring words that appear when describing rainbows (light, color, spectrum, double, etc.) that will help with the classification.

Double Rainbow Photo by Corey Hearne on Unsplash

Sklearn’s tf-idf vectorizer has a few other helpful parameters we will use, especially ‘ngram_range’ and ‘max_df’/‘min_df’. We’ll use an ngram_range of (1, 2), which takes unigrams and bigrams into account. max_df tells the vectorizer to ignore terms whose document frequency is higher than the threshold; we’ll set it to .95, which ignores terms that appear in more than 95% of all documents. min_df ignores terms whose document frequency is lower than the threshold; we’ll set it to 2, which ignores terms that appear in only 1 document (an integer is an absolute document count, a float is a proportion). Together these get rid of very frequent terms (max_df) and very rare terms (min_df).

You can grid search the tf-idf parameters just like you can grid search the hyperparameters of the classification model, which I suggest you do. For now though, these are a good place to start.
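One way to grid search the vectorizer settings is to put the vectorizer and classifier in a scikit-learn `Pipeline`, so the `tfidf__` parameters are searched with the same syntax as the model's. The toy corpus, labels and grid values below are assumptions for illustration:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# Toy corpus; in practice this is the processed article text and the encoded labels.
docs = ["stocks rally on earnings", "market slump hits banks", "team wins the final",
        "coach praises striker", "new chip speeds ai", "startup builds robot"] * 5
y = [0, 0, 1, 1, 2, 2] * 5

pipe = Pipeline([
    ("tfidf", TfidfVectorizer()),
    ("clf", SGDClassifier(random_state=0)),
])

# Vectorizer settings use the same "step__param" syntax as model hyperparameters.
param_grid = {
    "tfidf__ngram_range": [(1, 1), (1, 2)],
    "tfidf__min_df": [1, 2],
    "tfidf__max_df": [0.95, 1.0],
}
search = GridSearchCV(pipe, param_grid, cv=5, scoring="f1_macro")
search.fit(docs, y)
print(search.best_params_)
```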

Dimensionality Reduction

Even though we removed stop words and are applying thresholds in the tf-idf vectorizer, we are still left with a lot of unique words (~15K), many of which are probably redundant or unnecessary. So, let’s also do Latent Semantic Analysis (LSA), which is a dimensionality reduction technique. LSA uses SVD, or Singular Value Decomposition (in particular, Truncated SVD), to reduce the number of dimensions while keeping the most informative ones.

“LSA is known to combat the effects of synonymy and polysemy (both of which roughly mean there are multiple meanings per word), which cause term-document matrices to be overly sparse and exhibit poor similarity under measures such as cosine similarity” Source

How does it help us select the best features? At a very high level, you’re taking your matrix of tf-idf weights and factorizing it into 3 separate matrices. The singular values (the diagonal of the middle matrix) tell you which dimensions are the most important and which ones you don’t necessarily need. If the data is highly correlated, we should expect many singular values to be small, and those dimensions can be ignored. We also need to manually select the total number of dimensions (components) to keep. A good starting point is 100, so that’s what we’ll go with.
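A minimal LSA sketch with scikit-learn's `TruncatedSVD`, which factorizes the tf-idf matrix A as approximately U·Σ·Vᵀ and keeps only the components with the largest singular values. `n_components` is the total number of dimensions kept; it is set to 3 here only because the made-up corpus is tiny:

```python
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer

docs = [
    "stocks rally as market gains",
    "market slump hits stocks",
    "team wins league final",
    "striker scores in league match",
    "new chip speeds up ai models",
    "ai startup unveils chip",
]

X = TfidfVectorizer().fit_transform(docs)           # sparse: one column per term
svd = TruncatedSVD(n_components=3, random_state=0)  # the article uses 100
X_lsa = svd.fit_transform(X)                        # dense: one column per component

print(X.shape, "->", X_lsa.shape)
print(svd.singular_values_)                         # largest first
print(svd.explained_variance_ratio_.sum())          # share of variance the 3 components keep
```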

Model Selection and Evaluation

Alrighty then. We are ready to start looking at some classifiers! Exciting! Everyone has their own process, but the first thing I like to do is try a bunch of different kinds of classifiers and compare them using their default parameters. The huge caveat here is that an algorithm may not perform well right out of the box but will with the right hyperparameters. This step will, however, give you a good preliminary understanding of which types of classifiers inherently work better, though I generally do not dismiss any outright at this step.

I selected 6 different classifiers to test out, along with sklearn’s dummy classifier, which makes random predictions, as a baseline. With 4 categories you would expect its accuracy to be around .25, and it is.
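A sketch of this kind of first-pass comparison, with default hyperparameters and 5-fold macro F1. The article doesn't list all six classifiers, so this uses the ones named elsewhere in the text (SGD, Random Forest, AdaBoost, K-NN) on synthetic features from `make_classification` rather than the real LSA features. Note that `DummyClassifier` only predicts at random with `strategy="uniform"` (or `"stratified"`); its default in recent scikit-learn versions is `"prior"`, which always predicts the most frequent class:

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Synthetic 4-class data standing in for the 100 LSA features (not the article's data).
X, y = make_classification(n_samples=400, n_features=100, n_informative=20,
                           n_classes=4, random_state=0)

models = {
    "Dummy (uniform)": DummyClassifier(strategy="uniform", random_state=0),
    "SGD": SGDClassifier(random_state=0),
    "Random Forest": RandomForestClassifier(random_state=0),
    "AdaBoost": AdaBoostClassifier(random_state=0),
    "K-NN": KNeighborsClassifier(),
}

scores = {}
for name, model in models.items():
    # Default hyperparameters (apart from seeds), 5-fold CV, macro-averaged F1.
    scores[name] = cross_val_score(model, X, y, cv=5, scoring="f1_macro").mean()
    print(f"{name:16s} {scores[name]:.2f}")
```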

In terms of the metrics to use to evaluate the different classifiers we’re looking at:

  • Accuracy - simply the fraction of samples predicted correctly
  • Precision - the ratio of true positives to all predicted positives, TP / (TP + FP), i.e. the ability of the classifier not to label a negative sample as positive
  • Recall - the ratio of true positives to all actual positives, TP / (TP + FN), i.e. the ability of the classifier to find all the positive samples
  • F1 Score - the harmonic mean of precision and recall
  • Precision-Recall Curve (graph): It shows the trade-off between precision and recall. A high area under the curve (AUC) represents both high recall and high precision. High scores for both show that the classifier is returning accurate results (high precision), as well as returning a majority of all positive results (high recall)
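The precision, recall and F1 definitions above can be checked by hand on a small binary example (toy labels) and compared with scikit-learn's `precision_score`, `recall_score` and `f1_score`:

```python
import math

from sklearn.metrics import f1_score, precision_score, recall_score

# Binary toy labels: 1 = the class of interest (say Business), 0 = everything else.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]

pairs = list(zip(y_true, y_pred))
tp = sum(t == 1 and p == 1 for t, p in pairs)  # 3 true positives
fp = sum(t == 0 and p == 1 for t, p in pairs)  # 2 false positives
fn = sum(t == 1 and p == 0 for t, p in pairs)  # 1 false negative

precision = tp / (tp + fp)                          # 3 / 5 = 0.6
recall = tp / (tp + fn)                             # 3 / 4 = 0.75
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean, about 0.667

# scikit-learn's metrics agree with the hand calculation.
print(math.isclose(precision, precision_score(y_true, y_pred)),
      math.isclose(recall, recall_score(y_true, y_pred)),
      math.isclose(f1, f1_score(y_true, y_pred)))
```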

With all of the above metrics, the closer to 1 the better, with 0 being the worst. When the class distribution is imbalanced (like we have here), accuracy is usually considered a poor choice because it can give high scores to models that just predict the most frequent class, so F1 score is a better choice in our case. We will look at the precision-recall curve by category later, once we have our final candidates for model selection.

For multiclass classification you also need to select the type of averaging for these metrics, since they are calculated per class. What’s best depends on your situation and is beyond the scope of this article, but in general I find macro averaging the most useful (which is what I’m using here). It computes the F1 score for each class and returns the unweighted average of those scores, so each class counts equally regardless of its size. Remember though, the real test is how the models perform on unseen articles.
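A small sketch of the averaging options with scikit-learn's `f1_score` on toy 4-class labels: `average=None` returns the per-class scores, `"macro"` is their unweighted mean, and `"micro"` pools all samples (for single-label multiclass data it equals accuracy):

```python
from sklearn.metrics import f1_score

# Toy labels: 0 Business, 1 SciTech, 2 Sports, 3 World (not the article's predictions).
y_true = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3]
y_pred = [0, 1, 0, 1, 1, 2, 2, 2, 3, 0]

per_class = f1_score(y_true, y_pred, average=None)
macro = f1_score(y_true, y_pred, average="macro")
micro = f1_score(y_true, y_pred, average="micro")

print(per_class.round(3))
print(round(macro, 3), round(per_class.mean(), 3))  # macro = unweighted mean of per-class F1
print(round(micro, 3))                              # 8 of 10 correct, so this equals accuracy
```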

Random Forest has the highest F1 score (.77), followed by AdaBoost (.74) and SGD (.73). For the sake of brevity I’m going to continue with just two classifiers: Random Forest and SGD, which here implements a logistic regression. (In scikit-learn that means SGDClassifier with a logistic loss; its default hinge loss fits a linear SVM instead.) I’ve found SGD with logistic regression works well for text classification because it copes well with sparse data like we have with text.

The objective of gradient descent is to find the parameters that minimize a given function, in our case the loss function. It does this iteratively, taking steps proportional to the negative of the gradient of the function. The downside is that plain (batch) gradient descent computes the gradient over every data point in the training set at each step, which gets expensive with large amounts of data, and the data may not even fit into memory.

One solution to this is stochastic gradient descent, which doesn’t compute the true gradient but estimates it from one randomly selected training example at a time. After several passes like this over shuffled data, the algorithm can converge much faster and often be just as accurate, and since each update only needs a single example, the full dataset doesn’t have to be processed at once.
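To make the idea concrete, here is a minimal NumPy sketch of SGD for binary logistic regression on well-separated toy data. The `sgd_logistic_regression` function and its settings are hypothetical; scikit-learn's `SGDClassifier` additionally handles regularization, learning-rate schedules and multiclass (one-vs-rest) problems:

```python
import numpy as np


def sgd_logistic_regression(X, y, lr=0.1, epochs=20, seed=0):
    """Hypothetical helper: binary logistic regression fitted with plain SGD."""
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(epochs):
        for i in rng.permutation(len(y)):              # reshuffle on every pass
            p = 1.0 / (1.0 + np.exp(-(X[i] @ w + b)))  # predicted probability for one example
            grad = p - y[i]                            # gradient of log loss w.r.t. the score
            w -= lr * grad * X[i]                      # step against the gradient
            b -= lr * grad
    return w, b


# Toy data: two well-separated Gaussian blobs.
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(-2, 1, (50, 2)), rng.normal(2, 1, (50, 2))])
y = np.array([0] * 50 + [1] * 50)

w, b = sgd_logistic_regression(X, y)
accuracy = ((X @ w + b > 0).astype(int) == y).mean()
print(accuracy)
```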

One of the most important parameters to tune for stochastic gradient descent is the learning rate, which is essentially how big of a ‘step’ the algorithm takes towards the minimum. If the step is too large, it can miss the minimum by repeatedly ‘stepping over’ it, bouncing from one side of the (convex) function to the other. If the step is too small, it can take a long time to reach the minimum. Here is a good visual example:
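The effect of the learning rate can also be seen numerically on the simplest convex function, f(x) = x², whose gradient is 2x. This is a toy illustration, not SGD on real data: each step multiplies x by (1 - 2·lr), so small rates crawl, moderate rates converge, rates above 0.5 overshoot and zig-zag across the minimum, and rates above 1 diverge:

```python
def gradient_descent(lr, start=5.0, steps=20):
    """Minimize f(x) = x**2 (gradient 2x) with a fixed learning rate."""
    x = start
    for _ in range(steps):
        x -= lr * 2 * x  # step against the gradient
    return x


# Too small, reasonable, overshooting but converging, overshooting and diverging.
for lr in (0.01, 0.1, 0.9, 1.1):
    print(lr, gradient_descent(lr))
```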


Random Forests are an ensemble method that trains a bunch of decision trees (hence the ‘forest’) to combat the overfitting that individual decision trees are prone to. It does this by averaging the predictions of the different trees, which can reduce the variance. Tuning models is all about finding the best balance between bias and variance.

Some important parameters to start with when tuning RFs are max depth (the deeper the trees, the more prone they are to overfitting; too shallow and they will underfit), the number of estimators or trees (more is generally better, but it comes at the expense of computing time and there is a point of diminishing returns) and max features, the number of features considered when looking for the best split (too many can lead to overfitting; a good starting place is the square root of the total number of features).
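A sketch of those three parameters in scikit-learn's `RandomForestClassifier`, scored on synthetic data. The specific values are hypothetical starting points, not tuned ones; `max_features="sqrt"` is also the classifier's default in recent versions:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# Synthetic 4-class data (not the article's features).
X, y = make_classification(n_samples=300, n_features=100, n_informative=15,
                           n_classes=4, random_state=0)

rf = RandomForestClassifier(
    n_estimators=200,     # more trees: lower variance, more compute, diminishing returns
    max_depth=10,         # depth cap; None grows trees until leaves are pure (or too small to split)
    max_features="sqrt",  # features considered at each split: sqrt(100) = 10 here
    random_state=0,
)
score = cross_val_score(rf, X, y, cv=5, scoring="f1_macro").mean()
print(round(score, 2))
```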

Hyperparameter Tuning

The next step I usually do is to tune the hyperparameters for the classifiers I want to explore. Remember we are just using the default parameters in our initial assessment so they are not performing the best that they can.

In terms of hyperparameter tuning, that again is a whole subject in and of itself, but basically you want to search through a decent range of values for the different hyperparameters to see which work best. I’ll use sklearn’s GridSearchCV with k-fold cross-validation for that.

In k-fold cross-validation, the data is split into k folds (usually 5, which is what I’ll use). One fold is held out for testing and the other 4 are used for training. This is repeated k times, with a different fold used as the test set each time, and the results are averaged. Grid search simply goes through every combination of the values you give it for each hyperparameter and returns the best one based on a score.

I’m going to cheat a little here since this can potentially take a long time, especially if you’re just starting out. Once you do some tuning you will get a feel for which hyperparameters are really improving your model based on your dataset.

To cheat, I’m going to use DataLab, which runs Jupyter notebooks on a virtual machine on Google Cloud Platform. I’m going to use 96 virtual CPUs, which speeds things up quite a bit. 🏃


I’m currently using SGD for a text classifier in production, so I had a head start on which hyperparameters work best. These values are close to what I am using, so they are definitely a good place to start.

Second Evaluation

OK now that we have the best hyperparameters for each model, let’s look at how they perform again. Below is the ROC curve for SGD with the micro and macro averages along with each class:

The closer the curve hugs the top-left corner, the better the classifier performs, so overall it’s doing a good job. You can see it’s struggling the most with class 0 (green), which is Business, and doing the best with class 2, which is Sports. Now the same for RF:

Same story here. After tuning both models they are performing almost exactly the same on the test data.

Next I like to look at the confusion matrix to see where the classifier is mixing up (confusing) categories. First for SGD:

You can see that the F1 score at the top increased from .73 to .82 after tuning, which is a good increase. In the ROC curve we saw that Business was the category it struggled with the most, and it looks like it confuses Business most often with SciTech and hardly at all with Sports. Sports in general is the easiest category for the classifier, which is no surprise since it’s the most distinct from the others.

For RF it’s a similar story so we know that it’s confusing Business with SciTech the most. Even though Sports has the fewest samples, it’s the easiest to distinguish because the content is so different.
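A small sketch of reading a confusion matrix with scikit-learn's `confusion_matrix`. The labels are made up (constructed so Business is mostly confused with SciTech, echoing the pattern above); rows are true classes and columns are predicted classes:

```python
from sklearn.metrics import confusion_matrix

class_names = ["Business", "SciTech", "Sports", "World"]
# Toy true/predicted labels (not the article's results).
y_true = [0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 3]
y_pred = [0, 1, 1, 0, 1, 1, 0, 2, 2, 3, 3, 0]

cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3])
print(cm)  # rows: true class, columns: predicted class

# The largest off-diagonal cell is the pair of classes confused most often.
off_diag = [(cm[i, j], class_names[i], class_names[j])
            for i in range(4) for j in range(4) if i != j]
print(max(off_diag))  # 2 Business articles predicted as SciTech
```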


So now we are finally ready for the last step, prediction! Exciting! I will evaluate the models we selected and tuned with data they have never seen before to see how they perform.

Below are the 5 articles from each class that we held out at the beginning, along with their correct class and the predictions from the models trained on the full dataset.

SGD performs better, with an accuracy of 75%, which is in line with testing, while RF only had an accuracy of 55%, so it does indeed look like RF is overfitting the data. There are some measures you can take against that, like limiting the complexity of the trees in your forest or even pruning them when they grow too much.
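One way to check for, and push back on, that kind of overfitting is to compare training and cross-validation scores while constraining the trees. This sketch uses synthetic data, and the settings (`max_depth`, `min_samples_leaf`, and cost-complexity pruning via `ccp_alpha`) are illustrative values, not recommendations:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

# Synthetic 4-class data with some label noise (not the article's features).
X, y = make_classification(n_samples=300, n_features=100, n_informative=10,
                           n_classes=4, flip_y=0.1, random_state=0)

settings = {
    "unconstrained": {},
    "shallower trees": {"max_depth": 5, "min_samples_leaf": 5},
    "cost-complexity pruned": {"ccp_alpha": 0.005},
}

gaps = {}
for name, params in settings.items():
    rf = RandomForestClassifier(n_estimators=100, random_state=0, **params)
    res = cross_validate(rf, X, y, cv=5, return_train_score=True)
    # A large gap between training and validation accuracy is a symptom of overfitting.
    gaps[name] = res["train_score"].mean() - res["test_score"].mean()
    print(f"{name:24s} train-val gap: {gaps[name]:.2f}")
```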

Remember that we sampled the Business and Sports categories to create imbalanced classes. If you didn’t do that, below is how these would have performed:

SGD’s accuracy on unseen articles increases to 90% which is great and RF increases to 75%. As you can see, gathering more training data is a great way to increase performance (up to a certain point). But it’s not always possible.

Next Steps

So what can you do to improve performance if you can’t increase the size of your training set? There are usually a few things you can try in any situation:

  • Try some different models - AdaBoost would be a good candidate to try, as well as maybe K-NN. They performed well out of the box and may perform much better once tuned. We also didn’t delve into any neural networks, which may or may not improve performance but are worth a shot. For something like text classification, I would definitely try more traditional algorithms first before going the deep learning route because of the increased training time, increased cost/need for GPUs, etc.
  • Try different features - We used tf-idf to create the features from the words, but some have had success with word embeddings (Word2vec or GloVe), so you can give those a try. Maybe even give a count vectorizer a try.
  • Try supplementing with additional features - Besides the actual text of the article, sometimes you will also have additional data like author or publication which can help. Adding the title to the body of the article, if you have it, can also increase performance since it usually contains a lot of pertinent info.
  • Grid search the processing and dimensionality reduction - We didn’t tune those at all, so that can also help improve performance.
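Following on from that last point, the dimensionality reduction can be searched in the same `Pipeline` + `GridSearchCV` setup by adding a `TruncatedSVD` step. The corpus, labels and grid values here are toy assumptions:

```python
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# Toy corpus and labels (0 Business, 1 Sports, 2 SciTech); not the article's data.
docs = ["stocks rally on strong earnings", "bank shares slump on weak outlook",
        "team wins the league final", "striker scores twice in derby",
        "new chip speeds up ai models", "startup unveils home robot"] * 5
y = [0, 0, 1, 1, 2, 2] * 5

pipe = Pipeline([
    ("tfidf", TfidfVectorizer()),
    ("svd", TruncatedSVD(random_state=0)),
    ("clf", SGDClassifier(random_state=0)),
])
param_grid = {
    "tfidf__min_df": [1, 2],
    "svd__n_components": [2, 5, 10],  # the article fixes this at 100
}
search = GridSearchCV(pipe, param_grid, cv=5, scoring="f1_macro")
search.fit(docs, y)
print(search.best_params_)
```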