(Hugging Face) — Text classification, going further than the tutorial
Hugging Face is an American company working in AI and Machine Learning, but it is also a platform with a lot of great resources (including libraries) to help you use, or even create (and share), your own machine learning models.
One of those resources is the following tutorial page, which teaches you how to fine-tune a BERT-based (*) model to classify a text (a movie review) into one of two categories corresponding to a sentiment (positive or negative).
Could this tutorial be used as a base to classify a text into one of a much larger number of categories? That was the question.
After a preamble about text-classification models, this article will walk you through what was tested and what lessons were learned.
(*) BERT stands for Bidirectional Encoder Representations from Transformers. It was introduced in 2018 by Google and is great for tasks like text classification, as each word is considered both with the context before it and the context after it (the Bidirectional part of the name).
Preamble: what is a text-classification model? Why use one?
The text-classification model
We will focus on single-class classification models.
These are created by fine-tuning a base model (like BERT or its smaller version DistilBERT (*)) pre-trained on a large document set. A layer is added (or the last one is replaced) and trained using a dataset built for a specific task (text classification in this case).
They are trained to provide n output values, where n is the number of classes. Each of those values is a floating-point number between 0 and 1: the greater the value, the greater the probability that the evaluated text belongs to the associated class.
So even a single-class text-classification model does not directly provide a single result, but gives a score for each class. In the end, we take the class with the highest output value as the result.
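As a minimal sketch of that last point (the checkpoint name below is a placeholder, not a real model), here is how the per-class scores and the final predicted class can be obtained with the transformers library:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hypothetical fine-tuned checkpoint; replace with your own model.
model_name = "my-org/my-text-classifier"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

inputs = tokenizer("A text to classify", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits          # one raw score per class

scores = torch.softmax(logits, dim=-1)[0]    # probabilities between 0 and 1
predicted_class_id = int(scores.argmax())    # the class with the highest score
print(model.config.id2label[predicted_class_id], float(scores.max()))
```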
(*) Many BERT-like models are available on Hugging Face. The differences are mostly about the data they were trained on (if you are not working with English sentences, you can look for language-specific models like CamemBERT for French) and whether they have been distilled (a smaller model trained from an original one in order to achieve nearly the same quality while requiring less memory/processing power).
Why is it useful?
Text classification can be an effective way of automatically choosing the most appropriate way to process a given text.
In simple scenarios, this can be used to find the most suitable person/service to answer a question, decide where to store a file on disk, …
In more complex use cases, it can be the first step of a larger processing pipeline: selecting the value for a filter and/or choosing the most appropriate task to perform, …
The first step to train or fine-tune a model: have a dataset to work with
Have a database / datasource with what you want to have in your dataset
To test with a large number of categories, it was decided to build a large dataset from a single content source: arXiv.org. It is a free service that provides access to more than 2 million scientific articles across different fields; a big thanks to them.
There were 155 categories in the official list when the database construction was started. For each of them, a paginated query retrieved the articles it contains. It was quite a long process, given the 3-second delay the Terms of Use require between two API calls.
In the end, the database (MongoDB in this case) contains, for each article, its ID, title, category and summary, but only the titles and categories will actually be used.
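Below is a rough sketch of what such a harvesting loop could look like. The database/collection names and the page sizes are assumptions, and the arXiv Atom feed is parsed with feedparser here, which is not necessarily what the original script used:

```python
import time
import requests
import feedparser
from pymongo import MongoClient

ARXIV_API = "http://export.arxiv.org/api/query"
collection = MongoClient()["arxiv"]["articles"]   # hypothetical database/collection names

def harvest_category(category: str, page_size: int = 100, max_pages: int = 50) -> None:
    for page in range(max_pages):
        response = requests.get(ARXIV_API, params={
            "search_query": f"cat:{category}",
            "start": page * page_size,
            "max_results": page_size,
        })
        entries = feedparser.parse(response.text).entries
        if not entries:
            break
        for entry in entries:
            collection.update_one(
                {"_id": entry.id},
                {"$set": {"title": entry.title, "summary": entry.summary, "category": category}},
                upsert=True,
            )
        time.sleep(3)  # delay required between calls by the arXiv API Terms of Use

harvest_category("cs.LG")
```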
Convert the database / datasource to a Hugging Face Dataset
A row in the dataset is basically a dictionary. There are multiple ways to construct a complete Dataset (non-exhaustive list):
- from a csv file,
- from a pandas DataFrame,
- from a list of dictionaries,
- from a json file,
- from a function returning a generator,
- …
As the datasource in this case was a database, the generator approach was used, as sketched below.
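A minimal sketch, assuming a pymongo collection like the one above (field and collection names are hypothetical):

```python
from datasets import Dataset
from pymongo import MongoClient

collection = MongoClient()["arxiv"]["articles"]   # hypothetical names

def samples():
    # Yield one dictionary per row; the title is the text and the category the label.
    for doc in collection.find({}, {"title": 1, "category": 1}):
        yield {"text": doc["title"], "label": doc["category"]}

dataset = Dataset.from_generator(samples)
```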
Notice the class_encode_column and train_test_split methods used here: they aren't mentioned in the original tutorial, which uses a ready-to-use dataset fetched from Hugging Face.
- class_encode_column("label"), to be called before train_test_split, will convert the "label" column to a ClassLabel, that is, a data type suitable for text classification.
- train_test_split(test_size=0.2, stratify_by_column="label") will split the dataset into a "train" and a "test" part, with the "test" part getting 20% of the items from the original dataset (and "train" the remaining 80%). The stratify_by_column="label" parameter tells the splitting mechanism to respect the class distribution of the original dataset (i.e. if there were 10 elements with class "A" in the original dataset, there should be 2 "A" in the test split and 8 "A" in the train split). Both calls are shown in the sketch below.
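Put together, the two calls look like this (continuing from the hypothetical dataset built above):

```python
# Convert the string labels to a ClassLabel feature, then do a stratified 80/20 split.
dataset = dataset.class_encode_column("label")
splits = dataset.train_test_split(test_size=0.2, stratify_by_column="label")

train_dataset = splits["train"]
test_dataset = splits["test"]
```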
Addendum
- 155 different categories were expected in the database after populating it, but the results contained only 152 different categories from the expected list, plus some that weren't in said list.
- It would have been better to train a model for summary classification, as titles do not always really reflect the content of an article, but it would have taken too much time, so the titles were used instead (nearly one year of training versus three days on the available hardware).
The second step: the training itself
For this part we will start with the end result and then add the required elements one by one.
The training itself
After having constructed a Trainer object (see the sketch after this list):
- the train() method is used to launch the training (pass the optional resume_from_checkpoint=True argument if you are resuming a training),
- save_model() saves the final model.
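A condensed sketch of that part, with hypothetical names, only a few of the TrainingArguments actually needed, and most hyper-parameters left out; exact argument names (eval_strategy vs evaluation_strategy, processing_class vs tokenizer) depend on your transformers version:

```python
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

num_labels = train_dataset.features["label"].num_classes
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=num_labels
)

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True)

training_args = TrainingArguments(
    output_dir="arxiv-title-classifier",   # hypothetical output directory
    num_train_epochs=10,
    eval_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset.map(tokenize, batched=True),
    eval_dataset=test_dataset.map(tokenize, batched=True),
    processing_class=tokenizer,
)

trainer.train()                      # or trainer.train(resume_from_checkpoint=True)
trainer.save_model("final-model")    # save the fine-tuned model for later use
```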
Prepare the training parameters and elements
The functions used to compute evaluation metrics
This is one of the points where the source code differs quite a lot from the one provided in the original tutorial.
During the first training tests, the memory consumption was slowly but steadily increasing to the point where the OS was forced to kill the process (it seems that, even with disk space available for swapping, a process can't use more than 100 GB of virtual memory on an M1 Mac).
So after a bit of searching on the internet and multiple tests, the chosen solution was:
- to provide a preprocess_logits_for_metrics callback so that only the most probable class is kept instead of the values for all of them,
- to not compute the accuracy in the compute_metrics callback over the whole evaluated dataset at once, but to proceed batch by batch (see the eval_accumulation_steps and batch_eval_metrics parameters of the TrainingArguments object) (*). Both callbacks are sketched after the note below.
(*) As a matter of fact, it might have been more appropriate here to directly use a customized evaluation strategy instead of relying on those provided by the Hugging Face evaluate library.
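Here is a sketch of what those two callbacks can look like. The extra compute_result argument passed to compute_metrics when batch_eval_metrics=True is set depends on your transformers version, so treat this as an assumption to check against the documentation rather than the exact code used for this model:

```python
import torch

def preprocess_logits_for_metrics(logits, labels):
    # Keep only the index of the most probable class instead of the full logits,
    # which drastically reduces what has to be kept in memory during evaluation.
    return logits.argmax(dim=-1)

class BatchedAccuracy:
    """Accumulate accuracy batch by batch instead of over the whole eval set."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, eval_pred, compute_result: bool = True):
        predictions = torch.as_tensor(eval_pred.predictions)
        labels = torch.as_tensor(eval_pred.label_ids)
        self.correct += int((predictions == labels).sum())
        self.total += int(labels.numel())
        if compute_result:  # last batch: return the aggregated metric and reset
            accuracy = self.correct / max(self.total, 1)
            self.correct, self.total = 0, 0
            return {"accuracy": accuracy}
```

These would then be passed to the Trainer as compute_metrics=BatchedAccuracy() and preprocess_logits_for_metrics=preprocess_logits_for_metrics, together with batch_eval_metrics=True (and possibly eval_accumulation_steps) in the TrainingArguments.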
When you are ready, launch the training and enjoy your night/weekend/vacation.
Training such a model can take a very long time, depending on the size of the dataset, the length of the sample texts, and last but not least, the available hardware.
For the model this article is about, the last training run lasted about 80 hours in total, including both the training itself and the evaluation phase that occurred at the end of each of the 10 epochs. (One epoch means all the samples in the dataset have been used once for training; with the provided parameters this is done 10 times.)
How to use the model
To use the model, you can use the pipeline function; it gives you a callable (sketched after this list) that accepts:
- a single text value,
- a list of values,
- a Dataset,
- an iterator,
- …
Depending on the input type you use, you will either get a list of values or an iterator.
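A minimal sketch (the model path is hypothetical and the example labels/scores are illustrative):

```python
from transformers import pipeline

classifier = pipeline("text-classification", model="final-model")

# Single text: returns a list with one dict, e.g. something like
# [{"label": "cs.LG", "score": 0.93}].
print(classifier("A new gradient descent variant for deep networks"))

# List of texts: returns one result per input.
print(classifier(["Title number one", "Title number two"]))
```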
Evaluation of the model
Using numeric metrics
The usual text-classification metrics are calculated for each label rather than as a single global value.
For a given label, we can consider 4 base indicators:
- TP: True Positive, how many times this label was correctly predicted,
- TN: True Negative, how many times we correctly predicted the sample was not about this label,
- FP: False Positive, how many times this label was wrongly predicted,
- FN: False Negative, how many times we wrongly predicted another label instead of this one.
We can then calculate four metrics from those indicators:
- recall measures how often the label is correctly predicted when it should be,
- precision measures how many of the predictions for this label were correct,
- f1 is a mix of recall and precision that takes both factors into consideration,
- accuracy measures how often this label is correctly predicted or correctly not predicted over the whole evaluated dataset.
Which metric is best suited depends on the intended usage. I recommend this page to find out more.
After having computed the metrics you want for all the labels, you might consider building a global indicator using either mean values or weighted mean values. You might prefer the latter if the distribution of elements across labels is not uniform.
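As a sketch of how those per-label and global indicators can be obtained (using scikit-learn here, which is an assumption; the actual numbers for this model come from the evaluation data linked at the end of the article):

```python
from sklearn.metrics import precision_recall_fscore_support

# y_true / y_pred are the expected and predicted label ids over the evaluation split.
y_true = [0, 0, 1, 2, 2, 2]
y_pred = [0, 1, 1, 2, 2, 1]

# Per-label values: one precision/recall/f1/support entry per label.
precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred)

# Plain mean over labels ("macro") vs mean weighted by the number of samples per label.
macro = precision_recall_fscore_support(y_true, y_pred, average="macro")
weighted = precision_recall_fscore_support(y_true, y_pred, average="weighted")
```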
(A link to the full stats is available at the end of the article.)
Using a visual representation of the quality of the model
As the aim of creating this model was to test whether training such a model would work, it was interesting to count how often each label was predicted for samples of a given expected label. The idea was to be able to determine the model's biases and whether or not they made it unusable.
For a small number of labels, a chord chart can be used; for a larger number, prefer a Sankey chart. Both are diagrams useful to display flows. Here, what is interesting are the flows between expected labels and predicted ones.
More explanations below the chart.
- On the left are the categories present in the evaluation dataset. A taller bar means a greater number of samples with this category.
- The right part represents the number of predictions for each category.
- Lines are drawn between the expected categories and the predicted ones. The width of a line corresponds to the number of predictions.
- Horizontal lines represent good predictions while curved lines are bad predictions.
- Categories are grouped by their prefixes, each group having a distinct color. For better visualization, lines between categories of different groups are highlighted with a higher opacity.
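The chart itself can be produced with any flow-diagram library. As an illustration only (not the tool used for the chart above, and fed with made-up counts), here is a minimal Sankey sketch with plotly, built from (expected, predicted) pairs:

```python
from collections import Counter
import plotly.graph_objects as go

# Hypothetical evaluation results: pairs of (expected label, predicted label).
pairs = [("cs.LG", "cs.LG"), ("stat.ML", "cs.LG"), ("stat.ML", "stat.ML"), ("math.ST", "math.ST")]
counts = Counter(pairs)

labels = sorted({label for pair in pairs for label in pair})
index = {label: i for i, label in enumerate(labels)}

fig = go.Figure(go.Sankey(
    node=dict(label=labels + labels),  # first half: expected labels, second half: predicted labels
    link=dict(
        source=[index[expected] for expected, _ in counts],
        target=[len(labels) + index[predicted] for _, predicted in counts],
        value=list(counts.values()),
    ),
))
fig.show()
```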
What can we conclude from the graph? While there are bad classifications, they do not seem to be completely off the mark. The most frequent mistakes are “stat.ML” articles being wrongly labeled as “cs.LG”; considering that both of those categories are about Machine Learning, we can’t really call these classification mistakes.
So the model does appear to be quite good.
Full evaluation data as a json
For the curious ones
https://gist.github.com/jeromediaz/7f26f0ed00339999a6abf02e8cf00896
That’s all folks!