Published in The Startup

Photo by Giammarco Boscaro on Unsplash

Fine Tuning BERT for Text Classification and Question Answering Using TensorFlow Framework


Google BERT (Bidirectional Encoder Representations from Transformers) and other transformer-based models further improved the state of the art on eleven natural language processing tasks under broad categories of single text classification (e.g., sentiment analysis), text pair classification (e.g., natural language inference), question answering (like SQuAD 1.1) and text tagging (e.g., named entity recognition).

The BERT model is based on a few key ideas:

  • An attention-only model without RNNs (LSTM/GRU) is computationally more attractive (it processes the input in parallel rather than sequentially) and even performs better than RNNs, since it can retain information well beyond roughly 100 words.
  • BERT represents words as subwords (word pieces, or character n-grams). A vocabulary of roughly 8k to 30k subwords can typically represent any word in a large corpus, which is a significant advantage in terms of memory.
  • It eliminates the need for task-specific architectures. A pre-trained BERT model can be used as is for a wide variety of NLP tasks with fine-tuning. Earlier approaches, including feature-based ones like ELMo, still required a task-specific architecture on top: for example, a model for Q&A would have a very different architecture from a model that solved NER.
  • Word2vec and GloVe word embeddings are context independent: these models output just one vector (embedding) for each word, combining all the different senses of the word into one vector. Given the abundance of polysemy and complex semantics in natural languages, context-independent representations have obvious limitations. For instance, the word “crane” has completely different meanings in “a crane is flying” and “a crane driver came”, so the same word should be assigned different representations depending on context. BERT generates a different embedding for each occurrence of a word, one that captures its context, that is, the surrounding words in the sentence.
  • Unlike the GPT model, which is also an effort to design a general task-agnostic model for context-sensitive representations, BERT encodes context bidirectionally. Because of the autoregressive nature of its language modeling objective, GPT reads left-to-right and each token only sees the context that precedes it.
  • Transfer learning. This advantage is not tied directly to the model architecture: because these models are pre-trained on a language modeling task (and, in the case of BERT, on other tasks too), they can be used for downstream tasks that have very little labeled data. During supervised learning of downstream tasks, BERT is similar to GPT in two aspects. First, BERT representations are fed into an added output layer, with minimal changes to the model architecture depending on the nature of the task, such as predicting for each token vs. predicting for the entire sequence. Second, all the parameters of the pre-trained Transformer encoder are fine-tuned, while the additional output layer is trained from scratch.
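
The subword idea from the list above can be illustrated with a minimal sketch of greedy longest-match-first tokenization, the scheme used by WordPiece tokenizers. The toy vocabulary and the function name `wordpiece_tokenize` are made up for this illustration; a real BERT vocabulary is loaded from the model's vocabulary file and is far larger.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Split one word into subwords by greedy longest-match-first search."""
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                # Pieces that continue a word carry a "##" prefix.
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            # No piece matched: the whole word is treated as unknown.
            return [unk_token]
        tokens.append(piece)
        start = end
    return tokens


# Hypothetical toy vocabulary, for illustration only.
toy_vocab = {"play", "##ing", "##ed", "un", "##related", "##s"}

print(wordpiece_tokenize("playing", toy_vocab))
print(wordpiece_tokenize("unrelated", toy_vocab))
print(wordpiece_tokenize("xyz", toy_vocab))
```

Because rare words decompose into frequent pieces, the vocabulary (and hence the embedding table) can stay small without producing many unknown tokens.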

Here are some very useful articles that helped me to understand various aspects of the BERT model:

  • For any Transformer-based model like BERT, one should start with a solid understanding of the attention mechanism, first introduced in a paper by Dzmitry Bahdanau et al. The paper aimed to improve the sequence-to-sequence model in machine translation by using attention to align each decoder step with the relevant parts of the input sentence.
  • This beautifully illustrated article helped me a lot to understand how attention works, from the very basic concepts to specific details like tensors’ dimensions and computational complexity of each step.
  • Then it’s time for the original paper on the Transformer architecture. Here you can watch a quite interesting explanation of this paper by Yannic Kilcher.
    The main idea of the Transformer was to combine the advantages of both CNNs and RNNs in a novel architecture using the attention mechanism. The Transformer achieves parallelization by replacing recurrence with attention while still encoding each item’s position in the sequence. As a result, it leads to a competitive model with significantly shorter training time.
  • In this article Lena Voita et al. evaluated the contribution made by individual attention heads in the Transformer’s encoder to the overall performance of the model and analyzed the roles played by them. They identified three functions which heads might be playing:
    1. Positional: the head points to an adjacent token
    2. Syntactic: the head points to tokens in a specific syntactic relation
    3. Rare words: the head points to the least frequent tokens in a sentence.
  • Finally, here comes the original paper (and presentation) on the BERT model. I think this video (also by Yannic Kilcher) and this article were the most helpful resources for gaining a deeper understanding of the key differences between the vanilla Transformer and BERT models.
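
To make the attention mechanism discussed in these resources more concrete, here is a minimal sketch of scaled dot-product attention, the building block of the Transformer. It is written in plain Python for readability rather than speed, handles a single head, and omits the learned projection matrices; the function names and the toy vectors are illustrative assumptions.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(queries, keys, values):
    """Return (outputs, weights) for lists of query, key and value vectors."""
    d_k = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
                  for k in keys]
        weights = softmax(scores)
        # Output is the attention-weighted average of the value vectors.
        out = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights


# Toy self-attention over three 2-dimensional token vectors.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = scaled_dot_product_attention(tokens, tokens, tokens)
for row in weights:
    print([round(w, 3) for w in row])
```

Every query is compared with all keys independently of the others, which is why the computation parallelizes over the sequence instead of stepping through it as an RNN does.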


I’ve made two complete examples of fine-tuning the BERT-Base model, one for classification and one for question answering. The BERT-Large model requires significantly more memory than BERT-Base, so it cannot be trained on a consumer-grade GPU like the RTX 2080 Ti (and the RTX 3090 is not yet supported by TensorFlow). The two model sizes compare as follows:

  • BERT-Base: 12-layer, 768-hidden, 12-heads, 110M parameters
  • BERT-Large: 24-layer, 1024-hidden, 16-heads, 340M parameters
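
As a rough sanity check on these figures, the parameter counts can be reproduced from the layer sizes. This is a back-of-the-envelope sketch, not taken from the original code: it assumes a vocabulary of 30,522 word pieces, 512 positions, 2 segment types, a feed-forward width of 4 times the hidden size, and a pooler layer on top. The number of attention heads does not change the count, because the heads split the hidden dimension between them.

```python
def bert_param_count(num_layers, hidden, vocab_size=30522,
                     max_positions=512, type_vocab_size=2):
    """Approximate number of trainable parameters in a BERT encoder."""
    intermediate = 4 * hidden  # feed-forward width (assumed 4x hidden size)
    # Token, position and segment embeddings plus one LayerNorm (scale, bias).
    embeddings = (vocab_size + max_positions + type_vocab_size) * hidden + 2 * hidden
    # Query, key, value and output projections, each with a bias vector.
    attention = 4 * (hidden * hidden + hidden)
    # Two dense layers of the feed-forward block, each with a bias vector.
    feed_forward = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden)
    # Two LayerNorms per encoder layer.
    layer_norms = 2 * (2 * hidden)
    per_layer = attention + feed_forward + layer_norms
    # Dense pooler applied to the first token's hidden state.
    pooler = hidden * hidden + hidden
    return embeddings + num_layers * per_layer + pooler


print(f"BERT-Base:  {bert_param_count(12, 768):,}")
print(f"BERT-Large: {bert_param_count(24, 1024):,}")
```

Under these assumptions the counts come out at roughly 109M and 335M, close to the rounded figures quoted above. The roughly threefold difference, together with the optimizer state and activations, is what pushes BERT-Large past the memory of a single consumer GPU.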

The main goal was to write the examples in pure Python, combine data processing, training and testing in a single script file, and make them compatible with pre-trained TF Hub models. This makes the examples easier to understand and adapt to new tasks, and keeps the model up to date with TF Hub. This project is also available on my GitHub.

The original model can be found here, and a pre-trained English version is available here on TF Hub. Pre-trained multilingual versions are also available.

Both examples were trained on an RTX 2080 Ti using tensorflow-gpu:2.3.1. The hyperparameters have been adjusted for a reasonable balance between validation accuracy, training time, and available memory.


The BERT input sequence unambiguously represents both single text and text pairs. In the former, the BERT input sequence is the concatenation of the special classification token CLS, tokens of a text sequence, and the special separation token SEP. In the latter, the BERT input sequence is the concatenation of CLS, tokens of the first text sequence, SEP, tokens of the second text sequence, and SEP.

The BERT model expects three inputs:

  • The input ids: for a classification problem, the two input sentences should be tokenized and concatenated together (remember the special tokens mentioned above).
  • The input masks: these allow the model to cleanly differentiate between the content and the padding. The mask has the same shape as the input ids, and contains 1 anywhere the input ids are not padding.
  • The input types: these also have the same shape as the input ids, but inside the non-padded region they contain 0 or 1, indicating which sentence the token is a part of.
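
The three inputs can be sketched for a sentence pair as follows. The helper name `encode_pair` is hypothetical, the token ids in the example are placeholders rather than real vocabulary entries, and the ids 101 and 102 for CLS and SEP are an assumption based on the standard English uncased vocabulary; in practice they should be looked up in the tokenizer.

```python
def encode_pair(tokens_a, tokens_b, max_len, cls_id=101, sep_id=102, pad_id=0):
    """Build (input_ids, input_mask, input_type_ids) for a pair of sentences.

    tokens_a and tokens_b are lists of token ids produced by the tokenizer.
    """
    # Layout: CLS, first sentence, SEP, second sentence, SEP.
    input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
    # Segment 0 covers CLS, the first sentence and its SEP; segment 1 the rest.
    type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    if len(input_ids) > max_len:
        raise ValueError("sequence longer than max_len; truncate the inputs first")
    # 1 marks real tokens, 0 marks padding.
    input_mask = [1] * len(input_ids)
    padding = max_len - len(input_ids)
    return (input_ids + [pad_id] * padding,
            input_mask + [0] * padding,
            type_ids + [0] * padding)


# Placeholder token ids, for illustration only.
ids, mask, types = encode_pair([7, 8], [9], max_len=8)
print(ids)    # [101, 7, 8, 102, 9, 102, 0, 0]
print(mask)   # [1, 1, 1, 1, 1, 1, 0, 0]
print(types)  # [0, 0, 0, 0, 1, 1, 0, 0]
```

For a single-text task the second sentence and its separator are simply left out, so all type ids are 0.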

This model returns two outputs:

  • Pooled output: the final hidden state corresponding to the CLS token. It is used as the aggregate sequence representation for classification tasks (roughly speaking, it is an embedding for the whole sentence).
  • Sequence output: an embedding for each token in the given sentence (768-dimensional for BERT-Base).

Sparse categorical cross-entropy is used as the loss function for both the text classification and the question answering tasks.
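
For reference, here is a minimal plain-Python sketch of what this loss computes when it is applied to raw logits. “Sparse” means the labels are integer class indices rather than one-hot vectors. For classification the classes are the two labels; for question answering the “classes” are token positions, with one loss for the start position and one for the end. The function below illustrates the formula and is not the TensorFlow implementation.

```python
import math


def sparse_categorical_crossentropy(labels, logits):
    """Mean cross-entropy between integer labels and rows of raw logits."""
    losses = []
    for label, row in zip(labels, logits):
        # log(sum(exp(row))) computed stably by subtracting the max first.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        # Negative log-probability of the correct class.
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses)


# Two examples with two classes each; the second is confidently correct.
print(sparse_categorical_crossentropy([0, 1], [[0.0, 0.0], [-2.0, 3.0]]))
```

A row of equal logits gives a loss of ln(2) for two classes, which is the value to expect from an untrained binary classifier.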

Text Classification

For this example I’ve used the GLUE MRPC dataset from TFDS, a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in each pair are semantically equivalent.

  • Number of labels: 2.
  • Size of training dataset: 3668.
  • Size of evaluation dataset: 408.
  • Maximum sequence length of training and evaluation dataset: 128.
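
The example script consists of the following parts (the full code is in the GitHub project):

  • Tokenization and encoding of sentences into the form required by the BERT model
  • Classification model
  • Dataset preparation
  • Training loop and validation

Model sanity test results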


Training and validation results

Question Answering

This BERT model, trained on SQuAD 1.1, is quite good for question answering tasks. SQuAD 1.1 contains over 100,000 question-answer pairs on 500+ articles. In the SQuAD dataset, a single sample consists of a paragraph and a set of questions. The goal is to find, for each question, a span of text in the paragraph that answers that question. Model performance is measured as the percentage of predictions that closely match any of the ground-truth answers.
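
One common way to make “closely match” precise is the exact-match score: a prediction counts as correct if, after light normalization, it equals any of the reference answers. The sketch below follows the usual normalization steps (lowercasing, removing punctuation and English articles, collapsing whitespace). It is an illustration with made-up examples, not necessarily the metric code used in this project.

```python
import re
import string


def normalize_answer(text):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match_score(predictions, references):
    """Percentage of predictions equal to any of their reference answers."""
    hits = sum(
        any(normalize_answer(pred) == normalize_answer(ref) for ref in refs)
        for pred, refs in zip(predictions, references)
    )
    return 100.0 * hits / len(predictions)


# Made-up predictions and reference answers, for illustration only.
predictions = ["The Eiffel Tower.", "Paris"]
references = [["Eiffel Tower"], ["London", "Berlin"]]
print(exact_match_score(predictions, references))  # 50.0
```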

The BERT model is fine-tuned to perform this task in the following way:

  1. Context and the question are preprocessed and passed as inputs.
  2. Take the state of the last hidden layer and feed it into the start token classifier. The start token classifier has only a single set of weights, which it applies to every token. After taking the dot product between the output embeddings and the start weights, we apply the softmax activation to produce a probability distribution over all of the tokens. Whichever token has the highest probability of being the start token is the one that we pick.
  3. We repeat this process for the end token — we have a separate weight vector for this.
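
The steps above can be sketched in plain Python as follows. The 2-dimensional “embeddings” and the weight vectors are toy values chosen for illustration (real BERT-Base embeddings have 768 dimensions and the weights are learned), and the function names are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def token_logits(sequence_output, weights):
    """Dot product of every token embedding with one shared weight vector."""
    return [sum(h * w for h, w in zip(hidden, weights))
            for hidden in sequence_output]


def predict_span(sequence_output, start_weights, end_weights):
    """Pick the most probable start and end token positions independently."""
    start_probs = softmax(token_logits(sequence_output, start_weights))
    end_probs = softmax(token_logits(sequence_output, end_weights))
    start = max(range(len(start_probs)), key=start_probs.__getitem__)
    end = max(range(len(end_probs)), key=end_probs.__getitem__)
    return start, end


# Toy sequence output for four tokens, with separate start and end weights.
sequence_output = [[0.1, 0.0], [2.0, 0.1], [0.3, 1.5], [0.0, 0.2]]
start, end = predict_span(sequence_output, [1.0, 0.0], [0.0, 1.0])
print(start, end)  # 1 2
```

Picking the two positions independently can occasionally yield an end before the start; a fuller implementation would typically search over valid (start, end) pairs instead.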
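
The example script consists of the following parts (the full code is in the GitHub project):

  • Class for preprocessing a single example from the SQuAD dataset
  • Preparing the SQuAD dataset
  • Validation callback
  • Question answering model

Model sanity test results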


Training and validation results

As you can see, both models showed performance close to that indicated in the original papers.

P.S. I will update this article and the related GitHub project with the BERT-Large model when RTX 3090 support is available in TensorFlow.


