Simple Transformers — Multi-Class Text Classification with BERT, RoBERTa, XLNet, XLM, and DistilBERT

Simple Transformers is the “it just works” Transformer library. If you are looking to use Transformers for real applications (in 3 lines of code), without worrying about the technical details, this is for you.

Thilina Rajapakse
Oct 13

Preface

The Simple Transformers library is built on top of the excellent Transformers library by Hugging Face. The Hugging Face Transformers library is the library for researchers and other people who need extensive control over how things are done. It is also the best choice when you need to stray off the beaten path, do things differently, or do new things altogether. Simple Transformers is, well, a lot simpler.

Introduction

Simple Transformers is designed for when you need to get something done and you want it done now. No mucking about with source code, no hours of hair-pulling while trying to figure out how to even set the damn thing up. Text classification is so common that it should be easy, right? Simple Transformers thinks so and is here to do exactly that!

One line to set up the model, another to train the model, and a third to evaluate. Honestly, how much easier could it be?

All source code is available on the GitHub repo. If you have any issues or questions, that’s the place to resolve them. Please do check it out!

Installation

  1. Install Anaconda or Miniconda Package Manager from here
  2. Create a new virtual environment and install the required packages.
    conda create -n transformers python pandas tqdm
    conda activate transformers
    If using CUDA:
    conda install pytorch cudatoolkit=10.0 -c pytorch
    else:
    conda install pytorch cpuonly -c pytorch
    conda install -c anaconda scipy
    conda install -c anaconda scikit-learn
    pip install transformers
    pip install tensorboardx
  3. Install simpletransformers.
    pip install simpletransformers

Usage

Let’s see how we can perform multiclass classification on the AG News dataset.

For Binary Classification with Simple Transformers you can refer to this article.

  1. Download the dataset from Fast.ai.
  2. Extract train.csv and test.csv and place them in a directory data/.

Simple Transformers requires data to be in Pandas DataFrames with at least two columns. You can simply name your columns text and labels, and Simple Transformers will take care of handling the data. Alternatively, you can follow the convention below.

  • The first column contains the text and is of type str.
  • The second column contains the labels and is of type int.

For multiclass classification, the labels should be integers starting from 0. If your data has other labels, you can use a Python dict to keep a mapping from the original labels to the integer labels, as in the sketch below.
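
For AG News, preparing the DataFrames might look like this (a sketch, assuming the fast.ai CSVs ship without a header row and with the columns class, title, description; adjust if your copy differs):

import pandas as pd

# The fast.ai AG News CSVs have no header row;
# columns are (class index 1-4, title, description)
train_df = pd.read_csv('data/train.csv', header=None, names=['label', 'title', 'description'])
eval_df = pd.read_csv('data/test.csv', header=None, names=['label', 'title', 'description'])

# Combine title and description into a single text column,
# and shift the labels from 1-4 down to 0-3
for df in (train_df, eval_df):
    df['text'] = df['title'] + ' ' + df['description']
    df['labels'] = df['label'] - 1

train_df = train_df[['text', 'labels']]
eval_df = eval_df[['text', 'labels']]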

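Setting up the model is then a one-liner (a sketch, assuming the import path simpletransformers.classification; 'roberta-base' is one possible model_name):

from simpletransformers.classification import ClassificationModel

# model_type 'roberta', model_name 'roberta-base', 4 AG News classes
model = ClassificationModel('roberta', 'roberta-base', num_labels=4)
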
This creates a ClassificationModel that is used for training, evaluation, and prediction. The first parameter is the model_type, the second is the model_name, and the third is the number of labels in the data.

  • model_type may be one of ['bert', 'xlnet', 'xlm', 'roberta', 'distilbert'].
  • For a full list of pretrained models that can be used for model_name, please refer to Current Pretrained Models.

To load a previously saved model instead of a default model, change model_name to the path of a directory containing a saved model.

model = ClassificationModel('xlnet', 'path_to_model/', num_labels=4)

A ClassificationModel has a dict args which contains many attributes that provide control over hyperparameters. For a detailed description of each attribute, please refer to the repo. The default values are given below.
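
A representative subset of the defaults (the values below reflect the repo at the time of writing; treat the repo as authoritative):

args = {
    'output_dir': 'outputs/',
    'cache_dir': 'cache_dir/',

    'max_seq_length': 128,
    'train_batch_size': 8,
    'eval_batch_size': 8,
    'gradient_accumulation_steps': 1,
    'num_train_epochs': 1,
    'learning_rate': 4e-5,
    'warmup_ratio': 0.06,
    'max_grad_norm': 1.0,

    'logging_steps': 50,
    'save_steps': 2000,

    'overwrite_output_dir': False,
    'reprocess_input_data': False,
}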

Any of these attributes can be modified when creating a ClassificationModel or when calling its train_model method by simply passing in a dict containing the key-value pairs to be updated. An example is given below.
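
For instance (a sketch; the hyperparameter values here are purely illustrative):

# Override selected defaults when creating the model...
model = ClassificationModel(
    'roberta', 'roberta-base', num_labels=4,
    args={'learning_rate': 2e-5, 'num_train_epochs': 2}
)

# ...or when calling train_model (train_df as prepared earlier)
model.train_model(train_df, args={'overwrite_output_dir': True})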

That’s all you have to do to train the model. You can also change the hyperparameters by passing a dict containing the relevant attributes to the train_model method. Note that these modifications will persist even after training is completed.

The train_model method will create a checkpoint (save) of the model at every nth step where n is self.args['save_steps']. Upon completion of training, the final model will be saved to self.args['output_dir'].

To evaluate the model, just call eval_model; a usage sketch follows the list below. This method has three return values.

  • result: The evaluation result in the form of a dict. By default, only the Matthews correlation coefficient (MCC) is calculated for multiclass classification.
  • model_outputs: A list of raw model outputs for each item in the evaluation dataset. This is useful if you need per-class scores rather than a single prediction: applying a softmax over the outputs turns them into probabilities, and the predicted class is the argmax.
  • wrong_predictions: A list of InputFeature objects, one for each incorrect prediction. The text can be obtained from the InputFeature.text_a attribute. (The InputFeature class can be found in the utils.py file in the repo.)
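
The call itself is a one-liner (a sketch, using the eval_df prepared earlier):

result, model_outputs, wrong_predictions = model.eval_model(eval_df)
print(result)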

You can also include additional metrics to be used in the evaluation. Simply pass the metric functions as keyword arguments to the eval_model method. Each metric function should take two parameters, the first being the true labels and the second being the predictions, following the sklearn convention.

For any metric function that needs additional parameters (such as f1_score in sklearn), wrap it in your own function with the additional parameters fixed, and pass your function to eval_model, as shown below.
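
A sketch using sklearn metrics (f1_multiclass is a hypothetical wrapper name, and average='micro' is one possible choice):

from sklearn.metrics import accuracy_score, f1_score

def f1_multiclass(labels, preds):
    # f1_score needs an averaging strategy for multiclass data,
    # so fix it here and keep the (labels, preds) signature
    return f1_score(labels, preds, average='micro')

result, model_outputs, wrong_predictions = model.eval_model(
    eval_df, f1=f1_multiclass, acc=accuracy_score
)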

For reference, the results I obtained are as follows:

{'mcc': 0.937104098029913, 'f1': 0.9527631578947369, 'acc': 0.9527631578947369}

Not bad considering I didn’t really do any hyperparameter tuning. Kudos to RoBERTa!

In real-world applications, we often have no idea what the true label is. To perform predictions on arbitrary examples, use the predict method. It is similar to eval_model, except that it takes a simple list of texts and returns a list of predictions and a list of model outputs.

predictions, raw_outputs = model.predict(['Some arbitrary sentence'])

Conclusion

Multiclass Classification is a common NLP task in many real-world applications. Simple Transformers is a painless way to apply the power of Transformers to real-world tasks, without needing a PhD in Artificial Intelligence!

In the Works

I have plans to add Question Answering to the Simple Transformers library in the near future. Stay tuned!
