Contextual brand safety is an ongoing series, and this is the second post in it. The series covers the steps involved in building a multi-label text classifier in industry. This post focuses on model training and evaluation.
Brand safety is an important offering of GumGum. Contextual Brand Safety-I covers the problem and data preprocessing techniques in depth. In this post, we discuss model training, evaluation and steps toward production.
2. Experimental setup
We set up a multi-step MLflow project to track and store artifacts for each step, i.e.:
- Data loading and preprocessing (1 step)
- Model training (1 step for each layer)
We used Weights & Biases for hyperparameter tuning and visualization. For compute, we used a GPU-enabled AWS EC2 instance to run parallel experiments, connected to an AWS EFS volume that stored our experiment results. We used the fastai and PyTorch frameworks for rapid prototyping and experimentation.
Although we tried and tested several machine learning and deep learning models, we will focus on the two best models, ULMFiT and BERT, and the experiments we conducted with them.
Universal Language Model Fine-tuning for Text Classification
BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
3. Experiments with ULMFiT
3.1 Language model experiments
ULMFiT is a transfer-learning method built on an LSTM-based language model (AWD-LSTM), and it has three broad steps:
- Pretraining the language model on Wikipedia
- Fine-tuning the language model on production data
- Training a classifier on production data
In all our experiments, we used the language model pretrained on WikiText-103. While fine-tuning the language model, we experimented with two hyperparameters:
- Dropout multiplier: a multiplier applied to the model's default dropout values
- Vocabulary size: the number of words in the vocabulary used to build the model
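The sketch below shows what a dropout multiplier does: it scales every default dropout probability of the AWD-LSTM by one number. The default values are assumed to match fastai's AWD-LSTM language-model config at the time of writing. `scale_dropouts` is a hypothetical helper. In fastai itself, the multiplier is passed as `drop_mult` to `language_model_learner`.

```python
# Assumed AWD-LSTM language-model dropout defaults (believed to mirror fastai's
# awd_lstm_lm_config; verify against your installed version).
DEFAULT_LM_DROPOUTS = {
    "input_p": 0.25,   # dropout on the embedding outputs fed to the LSTM
    "embed_p": 0.02,   # drops whole rows of the embedding matrix (word dropout)
    "weight_p": 0.2,   # DropConnect on the LSTM hidden-to-hidden weights
    "hidden_p": 0.15,  # dropout between stacked LSTM layers
    "output_p": 0.1,   # dropout on the encoder output before the decoder
}


def scale_dropouts(config: dict, drop_mult: float) -> dict:
    """Return a copy of config with every key ending in '_p' multiplied by drop_mult."""
    return {k: (v * drop_mult if k.endswith("_p") else v) for k, v in config.items()}


if __name__ == "__main__":
    # One knob moves all five dropouts together, keeping their relative sizes.
    for mult in (0.1, 0.3, 0.5, 1.0):
        scaled = scale_dropouts(DEFAULT_LM_DROPOUTS, mult)
        print(mult, {k: round(v, 4) for k, v in scaled.items()})
```

A single multiplier keeps the search one-dimensional while preserving the relative strengths of the individual dropouts.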
Dropout multiplier
We conducted experiments with two corpus sizes, 100,000 and 250,000 records. The plot below shows how language model performance varies with the dropout multiplier.
From the plot above, we observe the following:
- For both corpus sizes, performance degrades as the dropout multiplier increases.
- With more data, the degradation is slower.
The larger corpus does appear to perform better. However, the evaluation sets for the 100K and 250K runs are different (though drawn from the same distribution), so it is inconclusive whether there is a real difference and, if there is, how large it is.
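The post does not state which metric the language-model plots use. Perplexity is a common choice, so this sketch assumes it.

The caveat about different evaluation sets can be avoided by scoring every fine-tuned model on one shared held-out set. In the sketch, a "model" is any hypothetical callable that returns the natural-log probability it assigns to each token of a text.

```python
import math
from typing import Callable, Dict, List, Sequence

# Hypothetical interface: tokens in, per-token natural-log probabilities out.
LogProbFn = Callable[[Sequence[str]], List[float]]


def perplexity(log_probs: Sequence[float]) -> float:
    """Perplexity = exp(mean negative log-likelihood per token)."""
    if not log_probs:
        raise ValueError("need at least one token")
    return math.exp(-sum(log_probs) / len(log_probs))


def compare_on_shared_eval(models: Dict[str, LogProbFn],
                           eval_texts: Sequence[Sequence[str]]) -> Dict[str, float]:
    """Score every model on the *same* tokenized evaluation texts."""
    results = {}
    for name, log_prob_fn in models.items():
        # Pool all tokens so long and short texts are weighted per token.
        all_lp = [lp for text in eval_texts for lp in log_prob_fn(text)]
        results[name] = perplexity(all_lp)
    return results
```

A model that spreads probability uniformly over 100 tokens has a perplexity of 100. Lower is better, and scores are only comparable when computed on the same `eval_texts`.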
We conducted another experiment to understand how LM performance varies with vocabulary size for both corpus sizes. The plot below shows language model performance as a function of vocabulary size.
From the plot above, we observe the following:
- With a corpus size of 100K, performance increases initially and then decreases gradually, whereas with a corpus size of 250K, performance decreases monotonically as vocabulary size grows.
A larger vocabulary can be seen as a more complex model. With the 100K corpus, this added complexity initially improved LM performance, which later decreased due to overfitting. With the 250K corpus, the larger amount of data lets the model generalize with fewer features, i.e., a smaller vocabulary.
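A vocabulary-size cap is usually implemented by keeping only the most frequent tokens and mapping everything else to an unknown token. This sketch assumes pre-tokenized text.

`build_vocab`, `numericalize`, the `min_freq` threshold and the special tokens are all illustrative, not our pipeline. fastai exposes similar `max_vocab` and `min_freq` arguments.

```python
from collections import Counter
from typing import Iterable, List

# Hypothetical special tokens; index 0 is reserved for unknown words.
SPECIAL_TOKENS = ["<unk>", "<pad>"]


def build_vocab(docs: Iterable[List[str]], max_vocab: int, min_freq: int = 3) -> List[str]:
    """Keep at most max_vocab of the most frequent tokens seen >= min_freq times.

    Special tokens are prepended and do not count toward max_vocab.
    """
    counts = Counter(tok for doc in docs for tok in doc)
    frequent = [tok for tok, c in counts.most_common(max_vocab) if c >= min_freq]
    return SPECIAL_TOKENS + [t for t in frequent if t not in SPECIAL_TOKENS]


def numericalize(doc: List[str], vocab: List[str]) -> List[int]:
    """Map tokens to ids; tokens outside the capped vocabulary become <unk> (id 0)."""
    index = {tok: i for i, tok in enumerate(vocab)}
    return [index.get(tok, 0) for tok in doc]
```

Shrinking `max_vocab` sends more rare words to `<unk>`, which is the reduction in model complexity discussed above.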
3.2 Experiments with Classifier
We conducted a total of 104 experiments while training our classifier, the third stage of ULMFiT. During this stage, we used Bayesian hyperparameter tuning to find the best set of hyperparameters for the task. The hyperparameters we tuned were:
- Language model encoders
- Final layer dropouts
- Hidden layer size in the classifier
Note: Encoder naming convention: a language model fine-tuned with a dropout multiplier of 0.3 is named fwd_enc_fp16_0.3_AWD.
From these experiments, we observed that:
- An encoder fine-tuned with a higher dropout multiplier performs best with a lower classifier dropout, and vice versa.
- The best-performing encoder is not necessarily part of the best-performing classifier.
Keep several diverse, equally well-performing encoders from the fine-tuning stage, and include the encoder choice as a hyperparameter when tuning the classifier to get the best performance.
The dataset characteristics, the encoder and the classifier hyperparameters need to be tuned together: even though we develop the classifier in three steps, each step depends on the outcomes of the previous ones.
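One way to treat the encoder as a classifier hyperparameter is to list the candidate encoders as a categorical parameter in a Bayesian sweep. The sketch below builds a Weights & Biases-style sweep configuration dictionary and follows the encoder naming convention above.

The metric key `valid_f1_micro`, the parameter names and all value ranges are hypothetical placeholders, not the values we used.

```python
from typing import Dict, List


def encoder_name(drop_mult: float) -> str:
    """Name an encoder using the post's convention, e.g. 0.3 -> 'fwd_enc_fp16_0.3_AWD'."""
    return f"fwd_enc_fp16_{drop_mult}_AWD"


def build_sweep_config(encoder_mults: List[float]) -> Dict:
    """Bayesian sweep over encoder choice, final-layer dropout and hidden size.

    All names and ranges here are illustrative placeholders.
    """
    return {
        "method": "bayes",
        "metric": {"name": "valid_f1_micro", "goal": "maximize"},  # hypothetical key
        "parameters": {
            # Several diverse, similarly good encoders become one categorical choice.
            "encoder": {"values": [encoder_name(m) for m in encoder_mults]},
            # Dropout before the classifier's output layer, searched continuously.
            "final_layer_dropout": {"min": 0.0, "max": 0.5},
            # Width of the classifier head's hidden layer.
            "hidden_size": {"values": [50, 100, 200]},
        },
    }
```

Such a dictionary could be passed to `wandb.sweep(config, project=...)`. `wandb.agent(sweep_id, function=train, count=...)` would then run the trials. Here `train` is a hypothetical function that reads `wandb.config`, loads the chosen encoder, trains the classifier and logs the metric.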
4. Comparison of ULMFiT vs. BERT
We used the fastbert library for our BERT experiments and fine-tuned BERT on the same dataset. Unlike ULMFiT, BERT does not need an unsupervised language-model fine-tuning step; fine-tuning BERT corresponds to the supervised classifier training in ULMFiT. We compared the best ULMFiT model against BERT on two datasets. The first is a multi-label dataset whose labels not only distinguish between safe and unsafe content but also identify the different types of threat in articles, following the nomenclature in Contextual Brand Safety-I. The second is a dataset that mimics our production distribution.
1. Multi-label performance (micro-averaged)
This dataset is less imbalanced across classes and is labelled for all types of threatening content.
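Micro-averaging pools true positives, false positives and false negatives over every (document, label) pair before computing precision and recall. As a result, frequent labels carry more weight than rare ones.

The sketch below computes these metrics from 0/1 indicator rows in plain Python. It is an illustration of the metric, not our evaluation code.

```python
from typing import Dict, Sequence


def micro_prf(y_true: Sequence[Sequence[int]],
              y_pred: Sequence[Sequence[int]]) -> Dict[str, float]:
    """Micro-averaged precision, recall and F1 for multi-label 0/1 indicator rows."""
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t and p:
                tp += 1      # label present and predicted
            elif p:
                fp += 1      # predicted but not present
            elif t:
                fn += 1      # present but missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}
```

For indicator arrays, scikit-learn's `f1_score(y_true, y_pred, average="micro")` and `recall_score(..., average="micro")` compute the same quantities.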
2. Performance on production-like data
This performance was calculated on a dataset that mimics our production distribution, where only 10%–15% of the corpus has threatening content.
On dataset-1, BERT has better recall, which is good for our task, but its recall is worse on the production dataset (dataset-2). Even though BERT is the state-of-the-art (SOTA) model, it does not outperform ULMFiT here; its performance is similar, if not worse. This trend could be attributed to the following reasons:
- Transformers outshine LSTMs on medium-to-long text because they capture longer-term dependencies. The median text length of our production traffic is only about 50 words, which LSTMs handle well.
- The ULMFiT encoders were fine-tuned in an unsupervised way to capture the language differences between their pretraining corpus, WikiText-103, and our production traffic, whereas BERT was not given any such treatment.
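The second reason suggests giving BERT the same treatment. A standard way to do this is continued masked-language-model (MLM) training on unlabeled production text, using BERT's masking recipe:
- Select about 15% of the tokens.
- Replace 80% of those with `[MASK]`.
- Replace 10% with a random token.
- Leave the remaining 10% unchanged.

The sketch below assumes integer token ids and a hypothetical `mask_id`. Hugging Face `transformers` provides the same behaviour through `DataCollatorForLanguageModeling`.

```python
import random
from typing import FrozenSet, List, Optional, Tuple

IGNORE = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def mask_tokens(token_ids: List[int], mask_id: int, vocab_size: int,
                special_ids: FrozenSet[int] = frozenset(), mlm_prob: float = 0.15,
                rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """BERT-style masking for one sequence; returns (inputs, labels).

    Labels hold the original id only at selected positions, so the MLM loss
    is computed there alone; special tokens (e.g. [CLS], [SEP]) are never selected.
    """
    rng = rng or random.Random()
    inputs, labels = list(token_ids), [IGNORE] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if tok in special_ids or rng.random() >= mlm_prob:
            continue
        labels[i] = tok
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id                    # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels
```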
Memory footprint & Throughput comparison
We measure a system's throughput as the number of messages completed per second. ULMFiT is about 10 times faster than BERT, which means that deploying BERT in production would cost about 10 times more than deploying ULMFiT.
Both ULMFiT and BERT were quantized before throughput measurement.
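A throughput number like this comes from timing a batch predictor over a fixed set of messages. The sketch below is a generic harness in which `predict_batch` stands in for any model's batch inference function; it is not our benchmarking code.

The cost helper assumes the same hardware for both models, so the number of instances needed for fixed traffic scales with 1/throughput.

```python
import time
from typing import Callable, List, Sequence, Tuple


def measure_throughput(predict_batch: Callable[[List[str]], list],
                       messages: Sequence[str], batch_size: int = 32,
                       warmup_batches: int = 2) -> Tuple[float, int]:
    """Return (messages per second, messages processed) for a batch predictor."""
    batches = [list(messages[i:i + batch_size]) for i in range(0, len(messages), batch_size)]
    for batch in batches[:warmup_batches]:  # untimed warm-up (lazy init, caches)
        predict_batch(batch)
    start = time.perf_counter()
    done = 0
    for batch in batches:
        predict_batch(batch)
        done += len(batch)
    elapsed = max(time.perf_counter() - start, 1e-9)  # guard against a zero timer delta
    return done / elapsed, done


def relative_serving_cost(fast_tput: float, slow_tput: float) -> float:
    """How many times more the slower model costs for the same traffic and hardware."""
    return fast_tput / slow_tput
```

The post does not say how the models were quantized. With PyTorch, one option is dynamic quantization via `torch.quantization.quantize_dynamic(model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)`, applied before timing.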
BERT's memory footprint is about twice that of ULMFiT. This is also a concern because we process messages asynchronously in batches, and we could scale the system better by increasing the batch size. A model with a smaller memory footprint therefore scales better.
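A simple capacity estimate makes the batching argument concrete: whatever memory the model weights do not occupy is available for per-message activations. The linear model and all numbers below are hypothetical; real activation memory also depends on sequence length and framework overhead.

```python
def max_batch_size(memory_budget_mb: float, model_mb: float, per_message_mb: float) -> int:
    """Largest batch that fits: (budget - model weights) // per-message memory.

    A deliberately simple linear estimate, not a measurement.
    """
    if per_message_mb <= 0:
        raise ValueError("per_message_mb must be positive")
    free = memory_budget_mb - model_mb
    if free <= 0:
        return 0  # the model alone does not fit
    return int(free // per_message_mb)


if __name__ == "__main__":
    # Hypothetical: 4 GB budget, a 500 MB model vs. one twice as large.
    print(max_batch_size(4000, 500, 10), max_batch_size(4000, 1000, 10))
```

Under these made-up numbers, doubling the model size shrinks the largest batch from 350 to 300 messages. The gap widens as the model takes up more of the budget.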
Conclusion
- By following the fine-tuning procedure from the ULMFiT paper, we achieved performance on par with state-of-the-art models such as BERT.
- Quantized ULMFiT has faster inference than BERT.
Future work
- BERT's performance could be improved by unsupervised fine-tuning on the production corpus before supervised training.
- Explore distillation, as in DistilBERT, to speed up BERT inference with minimal loss in quality.
- Amortize the cost of the BERT encoder through multi-task learning, using BERT as a centralized encoder for all downstream NLP tasks.
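For the distillation idea, the core ingredient is a soft-target loss that trains a small student to match a large teacher's temperature-softened output distribution. This is the standard knowledge-distillation term, scaled by T² so gradient magnitudes stay comparable across temperatures.

The sketch computes it for one example in plain Python. DistilBERT combines a term like this with other losses, and the temperature of 2.0 is an arbitrary placeholder.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; higher temperature gives a softer distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
                      temperature: float = 2.0) -> float:
    """Soft-target loss: T^2 * KL(teacher_T || student_T) for a single example."""
    p = softmax(teacher_logits, temperature)  # teacher's soft targets
    q = softmax(student_logits, temperature)  # student's softened predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return temperature ** 2 * kl
```

The loss is zero when the student reproduces the teacher's logits exactly. It grows as the student's softened distribution diverges from the teacher's.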
In this blog post, we discussed the experiments we conducted and the steps we took to arrive at a model with state-of-the-art quality that also scales well. In the next post in this series, we will discuss the steps we took to put this model into production. In the meantime, check out my previous post on contextual brand safety to understand the task better.
Contextual Brand Safety solution- I