How to Fine-Tune an NLP Transformer Model
There's been a lot of buzz around Natural Language Processing (NLP) in the last few years, following important technological advances that have enabled more performant models even in situations with limited access to data. Interest exploded in November 2022 when OpenAI's ChatGPT was launched. As a result, I'd like to take the opportunity to show how you can fine-tune a pre-trained model on a task of your choice, on your own.
Additionally, I will leverage both structured and unstructured data by processing them in the same model architecture. This can yield important performance improvements on some tasks.
As an example, I will use the disaster tweets dataset, which can be downloaded from Kaggle. For a real project you probably won't simply download your data like this, but rather spend a significant amount of time preparing it by querying databases or accessing APIs. Still, it serves our purpose in this case.
Analyse and Clean the Data
Our task is to build a model to predict whether a tweet is about a real disaster or not. The data contains the following columns:
- id - a unique identifier for each tweet.
- text - the tweet itself.
- location - the location the tweet was sent from (may be blank).
- keyword - a particular keyword from the tweet (may be blank).
- target - only included in train.csv, denotes whether a tweet is about a real disaster (1) or not (0).
I will additionally only use the data in the train.csv file since the test dataset doesn't contain any labels. The id column can be excluded as it doesn't contain any predictive value.
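A minimal sketch of loading the data, assuming the Kaggle train.csv file sits in the working directory:
import pandas as pd

# Load the labelled training data downloaded from Kaggle
data = pd.read_csv('train.csv')

# Drop the id column since it carries no predictive value
data = data.drop(columns=['id'])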
It's a fairly small dataset with around 7,600 samples. There are missing values in both the keyword (<1%) and location (>33%) columns, which we can replace with something as simple as no_keyword and no_location.
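A sketch of that simple imputation, assuming the pandas DataFrame data from above:
# Replace missing keywords and locations with placeholder strings
data['keyword'] = data['keyword'].fillna('no_keyword')
data['location'] = data['location'].fillna('no_location')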
There are a total of 110 duplicate tweets, and some of their labels are inconsistent. This may cause problems during model training since the model won't know which label to trust. We could look up all the duplicate tweets individually and correct their labels. However, for the sake of this example, I will go with a simpler approach and drop them, as sketched below.
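A minimal sketch of dropping them, keeping the first occurrence of each tweet text (one could also drop all copies of tweets with conflicting labels):
# Drop duplicated tweets, keeping only the first occurrence of each text
data = data.drop_duplicates(subset='text', keep='first').reset_index(drop=True)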
Clean up the keyword and location columns
The keyword column contains over 200 unique words, while the location column contains more than 3,300 different locations. Displaying the top 15 highest counts for each results in the following.
We can replace strings such as %20, which seem to be the only undesired characters, with a space in the keyword column.
The location column is very inconsistent; sometimes it refers to a continent, sometimes a country and sometimes a city. Additionally, there's nonsense data such as World Wide!!, Live On Webcam and milky way (not shown in the plot though).
The column is probably not very useful in its current state. There's a lot we could do to extract more meaningful information from it, but one simple approach is to use a library to extract real cities or countries and use that as input to our model. This library will probably make several mistakes (such as not recognising a city name, or falsely interpreting an abbreviation as a city or country name), but it might still be better than what we have now. Depending on how much time we want to spend on this, the result will vary. There are many libraries available for this, each with their pros and cons. I will use Geotext since it's comparably fast. Other, likely better, options are spaCy and geograpy3.
If several cities are found, we use the first; if no city is found, we fall back to the first country; otherwise we fill with no_location. This is a very naive approach, but it still shows what can be done in terms of extracting a location from a text. Other, perhaps more thoughtful, approaches might be to map the location to coordinates or geographical areas instead.
import re
from geotext import GeoText

def extract_location(text):
    """ Extract the first city from a location text.
    If no city exists, the first country is extracted.
    If neither exists, 'no_location' is returned.
    """
    places = GeoText(text)
    if places.cities:
        return places.cities[0]
    if places.countries:
        return places.countries[0]
    return 'no_location'

# Find city and country names in the location column
data['location'] = data['location'].apply(extract_location)

# Clean up the keyword column by replacing '%20' with a space
data['keyword'] = data['keyword'].apply(lambda x: re.sub('%20', ' ', x))
This brings the number of unique locations in the location column down to 726 from over 3,300.
Target Variable
From before, we know that the target variable doesn't contain any nulls. As shown, it's a fairly evenly distributed dataset, with 57% belonging to the Not Disaster class and 43% to the Real Disaster class. Had one of the classes been significantly overrepresented, we would have needed to take measures such as over- or undersampling, or think more carefully about different evaluation metrics.
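A quick way to check this distribution with pandas:
# Show the share of each target class
print(data['target'].value_counts(normalize=True))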
Feature Engineering
By cleaning the tweet text before extracting information such as the length of the tweet, the number of punctuation marks, hashtags, etc., we might lose important information. For that reason, we will first extract the additional features and clean the text afterwards. Having said that, it's worth experimenting with the opposite order as well.
We've already extracted relevant locations from the location column, which we hope will improve the classifier. However, most of the useful information is probably contained in the tweets themselves. Perhaps tweets that in general have longer words and fewer punctuation marks are an indication of real disaster tweets.
There's certainly a lot that can be done here, and I won't do any thorough feature engineering other than creating a few features to show what can be done. For example:
- Number of punctuation characters in the tweet (!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~)
- Tweet length
- Average word length per tweet
- Word count in each tweet
- Number of stop words
- Number of hashtags in the tweet
I'm using the string and nltk libraries to get common punctuation characters and stop words in the English language.
import string
import numpy as np
import nltk

# Download the NLTK stop word list if not already available
nltk.download('stopwords')

# Count the number of punctuation characters
data['nbr_punctuations'] = data['text'].apply(lambda x: len([p for p in str(x) if p in string.punctuation]))

# Get tweet length
data['tweet_length'] = data['text'].apply(lambda x: len(x))

# Get average word length per tweet, rounded to 3 decimals
data['avg_word_length'] = data['text'].apply(lambda x: np.round(np.mean([len(c) for c in str(x).split()]), 3))

# Get word count
data['word_count'] = data['text'].apply(lambda x: len(str(x).split()))

# Get the number of stop words in each tweet
stop_words = nltk.corpus.stopwords.words('english')
data['nbr_stopwords'] = data['text'].apply(lambda x: len([w for w in str(x).split() if w.lower() in stop_words]))

# Get the number of hashtags in each tweet
data['nbr_hashtags'] = data['text'].apply(lambda x: len([c for c in str(x) if c == '#']))
Applying that results in the following dataframe.
Now that we've used the original text to create some additional features, we can clean it up a little. Depending on the modelling approach we take, we might choose to clean the text more or less. For example, sequence models often do very well with only minor data cleaning, while bag-of-words models tend to prefer slightly more. We will only do minor text cleaning to keep it simple.
def text_cleaner(t):
    """ Normalise a text by converting it to lowercase,
    removing unicode characters and replacing URLs with 'link'.
    """
    # Lowercase the text
    t = t.lower()
    # Remove tokens of non-alphanumeric characters and strip non-ASCII characters
    t = re.sub(r' [^a-z0-9]+ ', ' ', t)
    t = t.encode('ascii', 'ignore').decode()
    # Replace links with 'link'
    t = re.sub(r'http*\S+', 'link', t)
    return t

# Normalise, or clean, the tweet text by applying the above function
data['text'] = data['text'].apply(text_cleaner)
Visualise the distribution of the engineered features
To get a better overview of the features we just engineered, it's a good idea to plot them. We can do that using histograms. By adding the value ranges to the titles, we get a more exact overview of each variable's distribution.
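Here is a minimal sketch of such a plot using matplotlib; the library choice and styling are assumptions, and the value ranges in the titles are omitted:
import matplotlib.pyplot as plt

engineered = ['nbr_punctuations', 'tweet_length', 'avg_word_length',
              'word_count', 'nbr_stopwords', 'nbr_hashtags']

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
for ax, col in zip(axes.flatten(), engineered):
    # Overlay the two classes to compare their distributions
    ax.hist(data.loc[data['target'] == 0, col], bins=30, alpha=0.5, label='Not Disaster')
    ax.hist(data.loc[data['target'] == 1, col], bins=30, alpha=0.5, label='Real Disaster')
    ax.set_title(col)
    ax.legend()
plt.tight_layout()
plt.show()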
Takeaways from the plot above:
- In general, the distributions are pretty similar between Not Disaster and Real Disaster across all variables. The exception might be avg_word_length, where Real Disaster tweets in general seem to have longer words. This might perhaps be explained by newspapers and journalists posting about real disasters more frequently, and using less slang and more sophisticated words than the general public. However, the difference in distribution should be verified using statistical methods such as a t-test or Kruskal-Wallis.
- Most variables seem to be non-normally distributed. The exceptions might be word_count and perhaps nbr_stopwords. This, again, should be verified using statistical methods such as a normality test or Jarque-Bera.
- The most frequent tweet length is around 140 characters, while the longest in the dataset is 157. 99% are shorter than 143 characters. Real Disaster tweets might be slightly longer on average; a plausible explanation could be the same as for the difference in avg_word_length.
- Most tweets have no hashtags and rather few punctuation characters.
Correlations
In order to get a better sense of which of the engineered features contribute most to the target variable, we can calculate the Pearson correlation and display it in a matrix.
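A sketch of computing and plotting the correlations, assuming seaborn for the heatmap; note that the categorical location column would need to be encoded (e.g. label encoded) before it can be included:
import matplotlib.pyplot as plt
import seaborn as sns

# Pearson correlations between the engineered features and the target
numeric_cols = ['nbr_punctuations', 'tweet_length', 'avg_word_length',
                'word_count', 'nbr_stopwords', 'nbr_hashtags', 'target']
corr = data[numeric_cols].corr(method='pearson')

# Display the correlation matrix as an annotated heatmap
sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm')
plt.show()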
Several of the features have a (positive) correlation with the target variable, with tweet_length, avg_word_length, nbr_stopwords and nbr_punctuations being the strongest. In general, as the value of these features increases, the probability of a real disaster also increases.
location has a very weak (negative) correlation with the target. While disasters can strike anywhere, there's probably more that can be done to extract valuable information from this feature.
In general, there seems to be little multicollinearity between the features. That's good, because if it becomes too high, it might negatively affect model performance.
It's important to note that the above correlations only consider each variable separately. It's possible that two or more weakly correlated variables become very important when combined.
Model Fine-Tuning
There are two main approaches we can take when building the classifier: 1) a more traditional bag-of-words model (often classical machine learning), and 2) a sequence model (i.e. deep learning). The Transformer architecture is probably the most popular sequence model for NLP today. Depending on the size of the dataset, the tweet length and perhaps the importance of context in the tweet, each approach may have its advantages.
Although we should ideally evaluate both approaches, I will choose a Transformer model to keep it simple. I also suspect that it's important for the model to understand the sentence context and its meaning in order to perform as well as possible on this task, and Transformer models tend to perform slightly better in such situations. I highly encourage you to read François Chollet's Deep Learning with Python, 2nd Ed. to learn more about this and about deep learning in NLP in general.
Although BERT might be the most famous Transformer model out there, I will choose an ALBERT architecture instead. It's in many ways similar to BERT and differs mainly in that it shares parameters across layers. This leads to a lighter model that is faster to train and often more performant. To save training time, we will additionally choose a smaller ALBERT architecture. Obviously, we will need to train the model on a GPU; GPUs often give at least 10x speed-ups compared to CPUs for tasks like this thanks to parallelisation.
Although some of the features showed little correlation with the target variable, I will use them all. They might become more important when working in combination with the other features. Additionally, they take up little space compared with the tweet text itself.
A very naive model that only predicts the largest class each time would get 57% accuracy. This will thus be our target to beat.
First of all, we need to split the dataset into train and validation splits. Although we already have a test set, we can't evaluate the model on it because it doesn't have any labels. We should therefore split the train set into an additional third set, a test set. However, since the dataset is fairly small, we would likely need to implement cross-validation in order to accurately assess the model performance. This would take quite a lot of time though: training the model once takes around 20 minutes on a free GPU, so with 10-fold CV we would spend around 200 minutes, or over three hours, on it. Although suboptimal, we will therefore evaluate the model performance on the validation set only.
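A sketch of such a split with scikit-learn; the 80/20 ratio and the stratification on the target are assumptions:
from sklearn.model_selection import train_test_split

# Split the labelled data into train and validation sets, stratified on the target
train_set, validation_set = train_test_split(data,
                                             test_size=0.2,
                                             stratify=data['target'],
                                             random_state=42)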
One way to prepare the data for a Transformer model when there are text, categorical and continuous columns is to combine them all into one single column and separate the values with the [SEP] token. We could also include this step in the pre-processing pipeline, but in order to display the results more clearly, I will go with the first approach.
Using 🤗 Hugging Face's TabularConfig object also works very well when combining structured and unstructured data.
# Combine the features into one string, separating each with the token '[SEP]'.
# 'feats' is assumed to hold the names of the columns to combine
# (the cleaned text plus the location, keyword and engineered feature columns)
data['features'] = data[feats].astype(str).apply(' [SEP] '.join, axis=1)
Applying the above and selecting only the resulting features (which are now all in the same column) and the corresponding target yields the following. Note how each individual feature is separated with the [SEP] token, a standard token used by ALBERT.
Pre-processing Steps
The ALBERT model needs some further pre-processing of the data. Among other things, the text needs to be tokenised and the sequences trimmed (or padded) to the same length. The pre-processing step is downloaded from TensorFlow Hub, and we can then combine it all in the following function.
import tensorflow as tf
import tensorflow_hub as hub

# Specify ALBERT preprocessing layer URL
albert_preprocess = 'https://tfhub.dev/tensorflow/albert_en_preprocess/3'

def make_albert_preprocessing_model(sentence_features, seq_length=145):
    """ Returns a Model mapping string features to ALBERT inputs.
    """
    # Convert each input feature to a TensorFlow Input layer
    input_segments = [tf.keras.layers.Input(shape=(),
                                            dtype=tf.string,
                                            name=ft)
                      for ft in sentence_features]

    # Tokenise the text using the ALBERT preprocessing layer
    bert_preprocess = hub.load(albert_preprocess)
    tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name='tokenizer')
    segments = [tokenizer(s) for s in input_segments]

    # Pack the inputs and trim to 'seq_length'
    packer = hub.KerasLayer(bert_preprocess.bert_pack_inputs,
                            arguments=dict(seq_length=seq_length),
                            name='packer')
    model_inputs = packer(segments)
    return tf.keras.Model(input_segments, model_inputs)
Build the Model
Next, define a function that leverages a pre-trained ALBERT model base. We allow fine-tuning of it by specifying trainable=True, and stack a single Dense layer on top that outputs the probability of one of two classes, 1 or 0, representing disaster or non-disaster. Additionally, we use the commonly used Adam optimiser, which often works great out of the box.
To get a more holistic picture of the model's performance, we measure three metrics in addition to the loss: accuracy, precision and recall. We are particularly interested in precision and recall since those metrics tell us how well the model classifies real disasters.
# Specify ALBERT model base URL
albert_base = 'https://tfhub.dev/tensorflow/albert_en_base/3'

def build_classifier_model():
    # These are the inputs the model will accept
    inputs = dict(
        input_word_ids=tf.keras.layers.Input(shape=(None,), dtype=tf.int32),
        input_mask=tf.keras.layers.Input(shape=(None,), dtype=tf.int32),
        input_type_ids=tf.keras.layers.Input(shape=(None,), dtype=tf.int32),
    )

    # Load the pre-trained ALBERT layer and allow fine-tuning
    albert_encoder = hub.KerasLayer(albert_base,
                                    trainable=True,
                                    name='albert_encoder')

    # Stack one Dense layer on top of the pooled ALBERT output
    net = albert_encoder(inputs)['pooled_output']
    net = tf.keras.layers.Dense(1, activation='sigmoid',
                                name='classifier')(net)

    # Create the final model by combining the above
    model = tf.keras.Model(inputs, outputs=net, name='ALBERT_classifier')

    # Compile the model with the Adam optimiser and binary crossentropy as loss
    model.compile(tf.keras.optimizers.Adam(learning_rate=1e-5),
                  loss='binary_crossentropy',
                  metrics=['accuracy',
                           tf.keras.metrics.Precision(name='precision'),
                           tf.keras.metrics.Recall(name='recall')])
    return model
The resulting ALBERT model has 11.7 million parameters and looks as follows. Note the three inputs the model is expecting, the pre-trained model in the middle and the single output layer.
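The model can be instantiated and its summary printed as follows; this classifier object is the one fitted further down:
# Build the classifier defined above and print its architecture
classifier = build_classifier_model()
classifier.summary()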
Data Loading and Model Fine-tuning
Create a function for loading the dataset into the model in batches. It's important to load the data in batches for memory reasons. Although it might be possible to load this rather small dataset into GPU memory directly, such a solution wouldn't scale well as the data grows larger.
def load_dataset(pd_dataset, batch_size, preprocessing_model):
    # Convert the features and target into TensorFlow tensors
    dataset = tf.data.Dataset.from_tensor_slices(
        (pd_dataset['features'].values,
         pd_dataset['target'].values)
    )

    # Shuffle the data, make sure we can repeatedly train over it,
    # and make sure it returns 'batch_size' samples each time
    dataset = dataset.shuffle(len(pd_dataset))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)

    # Apply the preprocessing steps to the data
    dataset = dataset.map(lambda ex0, ex1: (preprocessing_model(ex0), ex1))

    # Cache and prefetch the data for performance
    dataset = dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset
Next, we specify the model parameters and path, create the preprocessing model for data preprocessing, and load the data in batches. We use a seq_length of 145 to capture the whole length of over 99% of the tweets (143 is enough as we saw before, but 145 is a more even number). Longer sequence lengths lead to longer training times, but also potentially more performant models because more of the information in the tweet is captured.
Specify a model checkpoint that saves the best model based on validation accuracy during training. That way, we can easily access the most performant model afterwards. Although a metric such as the F1 score might be more in line with our goal, we will use validation accuracy as it's easier to understand.
# Specify parameters
batch_size = 16
steps_per_epoch = len(train_set) // batch_size
validation_steps = len(validation_set) // batch_size
model_path = 'best_model'

# Create the preprocessing model
preprocessing_model = make_albert_preprocessing_model(['features'],
                                                      seq_length=145)

# Load the training and validation data in batches
train_data = load_dataset(train_set[['features', 'target']],
                          batch_size,
                          preprocessing_model)
validation_data = load_dataset(validation_set[['features', 'target']],
                               batch_size,
                               preprocessing_model)

# Specify a model checkpoint that saves the best model only
model_checkpoint = [tf.keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    monitor='val_accuracy',
    save_best_only=True,
    mode='max')
]
Lastly, initiate the model training, or fine-tuning, with the previously defined parameters. Since we are fine-tuning the model, we don't need to, and shouldn't, train it for long. I choose 5 epochs as it doesn't take too long while it also seems to be enough for the performance to flatten out. Depending on the GPU you're using, this will take a different amount of time; in my case, using a free GPU, it took around 20 minutes.
# Initiate the model training. Batching is already handled by the tf.data
# pipeline, so no batch_size is passed to fit()
history = classifier.fit(train_data,
                         validation_data=validation_data,
                         epochs=5,
                         callbacks=model_checkpoint,
                         steps_per_epoch=steps_per_epoch,
                         validation_steps=validation_steps)
Analysing Model Performance
To get a holistic view of the model's performance, we can plot each metric at the end of each epoch on the train and validation data.
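A sketch of plotting these curves from the Keras history object returned by fit (matplotlib is assumed):
import matplotlib.pyplot as plt

# Plot train and validation curves per epoch for each tracked metric
metrics = ['loss', 'accuracy', 'precision', 'recall']
fig, axes = plt.subplots(1, len(metrics), figsize=(20, 4))
for ax, metric in zip(axes, metrics):
    ax.plot(history.history[metric], label='train')
    ax.plot(history.history['val_' + metric], label='validation')
    ax.set_title(metric)
    ax.set_xlabel('epoch')
    ax.legend()
plt.tight_layout()
plt.show()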
Focusing on val_accuracy, we note that there's a peak after the third epoch, after which it declines slightly. It's a fairly small dataset and we are only fine-tuning the model. Chances are that it starts overfitting after the third epoch, even though we're using a low learning rate, which results in larger differences between the train and validation sets in later epochs.
By loading the best-performing model (after three epochs), we can take a deeper look at its performance using a confusion matrix and a classification report.
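One way to produce these, assuming scikit-learn and the objects defined above; since load_dataset shuffles and repeats, the raw validation features are preprocessed directly here:
from sklearn.metrics import confusion_matrix, classification_report

# Reload the best checkpoint saved during training
best_model = tf.keras.models.load_model(model_path)

# Preprocess the raw validation features and predict probabilities
val_inputs = preprocessing_model(tf.constant(validation_set['features'].values))
val_probs = best_model.predict(val_inputs)
val_preds = (val_probs > 0.5).astype(int).flatten()

print(confusion_matrix(validation_set['target'], val_preds))
print(classification_report(validation_set['target'], val_preds,
                            target_names=['Not Disaster', 'Real Disaster']))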
We see that the model does fairly well at correctly predicting both real disasters and non-disasters. It correctly identifies 494 of the 638 disasters while also correctly identifying 753 of the 850 non-disasters. It does seem to do a little worse on real disaster tweets though. Although there's surely still room for improvement, this first model does fairly well.
We could move forward by looking into the tweets the model fails on. Some tweets might be very hard even for a human to classify correctly, or they might have incorrect labels in the first place, which will negatively affect the model's performance. There might also be a pattern among the tweets it is misclassifying; if so, we could collect more tweets like that to improve the performance. For now though, we're happy with these results.
Summary
Although there's certainly more we could do in terms of building more features, experimenting with various text cleaning approaches, using more powerful models, analysing the results, etc., we've already learned some interesting things about fine-tuning Transformer models on both structured and unstructured data. Here's a short summary of what we've done:
- We're only working with the train set since the test set doesn't contain any labels. This results in around 7,600 tweets.
- We only apply basic cleaning to keep as much as possible of the original information. For example, tweets that contain many spelling mistakes might be less likely to have been written by journalists, and thus possibly less likely to be about real disasters. We use a third-party library to extract and normalise the locations in the location column; much more work could be done on this though.
- We created a couple of new features in the hope of better separating the two classes. Although much more can be done here, we experimented with features such as nbr_punctuations, tweet_length, avg_word_length, word_count, nbr_stopwords and nbr_hashtags.
- Initial visual observations indicate that the distributions of the created features are fairly similar between the two target classes. The exception might be avg_word_length, where Real Disaster tweets in general seem to have longer words, which we hypothesised might be because they are more frequently written by journalists.
- The most frequent tweet length is around 140 characters, while the longest is 157. 99% are shorter than 143 characters.
- We use a pre-trained ALBERT base in our model to leverage its embeddings and speed up training.
- By combining structured data (the engineered features) and unstructured data (the tweets), we attempted to boost model performance.
- Although the model does slightly worse on real disaster tweets, an accuracy of close to 84% is achieved on the validation set.
Have feedback or want to connect?
I'm open to feedback and thoughts of any kind. Leave a comment below or connect with me on LinkedIn or my personal website.