Adapting BERT question answering for the medical domain
BERT, one of the breakthrough NLP models of the past year, has changed the way we deal with textual data. Although it has surpassed state-of-the-art results on a wide range of NLP tasks, some questions around its domain adaptation and its extension to the Question Answering (QA) task remain unexplored. This article summarizes the research done during a summer internship at Peltarion on adapting BERT to a limited amount of domain-specific data for a QA task with long input contexts.
Question Answering is one of the many challenging NLP tasks where BERT has made major strides. Models that perform well on this task have direct applications in industry. For instance, one of Peltarion’s clients, a large telecom company, wants to use deep learning to optimize its RFP process. A model that can look into a product document and highlight the relevant answers or evidence for the questions asked in RFPs could significantly speed up that process. The expected data from the client would be product documents, questions from RFPs, and annotations in the form of answers extracted from the respective product documents. Before asking the client to carry out the expensive task of data collection, we started with BERT to build a proof of concept on the emrQA dataset. You can read more about BERT and the question answering model for the SQuAD dataset in [1] and [2]. Assuming a decent understanding of the BERT model, this post walks you through the research questions that guided our experiments, their results, and our conclusions.
Research Questions
Earlier this year, several papers [3, 4] focused on adapting BERT to different domains such as science and biomedicine. The common approach was either to train BERT from scratch on a large domain-specific corpus (e.g. scientific papers) with a domain-specific vocabulary, or to finetune BERT’s language model on a large domain-specific corpus while keeping BERT’s original vocabulary. The adapted model was then finetuned on a task-specific dataset and evaluated on sentence classification and NER, which are relatively easier tasks than QA. Also, the datasets used in existing work usually have short contexts, typically under 512 tokens. BERT has a limit of 512 input tokens, but a product document from the client or a clinical note from the emrQA dataset is typically much longer than that. With these challenges in the way of building a QA model from a limited amount of domain-specific data, we defined the following research questions:
- How should we approach the Question Answering task when the input context or paragraph is several times longer than 512 tokens?
- How can we adapt the BERT model to a domain-specific QA dataset when only a limited domain-specific corpus is available (only product documents or only clinical notes)?
- Does replacing the placeholders in BERT’s vocabulary with the frequent domain-specific words help?
- How much training data is needed to achieve decent accuracy?
Dealing with a long context
The BERT-QA model takes input in the following format: a “[CLS]” token, the tokenized question, a “[SEP]” token, tokens from the content of the document, and a final “[SEP]” token, limiting the total size of each instance to 512 tokens. As mentioned earlier, a clinical note is roughly 5x longer than 512 tokens. To deal with this, we modify the training data so that it can be used by the BERT model without changing its architecture. The trick is to split each clinical note into multiple sub-notes such that every sub-note, when concatenated with the input question, comes close to 512 tokens. After concatenating the input question with a sub-note, there are two possibilities: 1) if the sub-note contains the ground-truth answer, label the answer with its start and end positions in the sub-note; 2) if the sub-note does not contain the ground-truth answer, label the start and end positions as 0 (pointing to the [CLS] token). During inference, we break the input clinical note into sub-notes in the same fashion as during training and concatenate each part with the input question. The start and end locations of the answer are then predicted on all sub-notes of the clinical note by computing span_score(c, s, e) for a sub-note (context) c, start position s, and end position e. The answer span (s, e) from the sub-note c with the highest span_score(c, s, e) is predicted as the final answer. This approach is partially inspired by [5].
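To make this concrete, here is a minimal sketch of the inference step under the assumptions above. The model interface, the answer-length cap, and the use of start logit + end logit as span_score are illustrative choices, not necessarily the exact implementation used in these experiments.

```python
# Minimal sketch of the long-context workaround at inference time.
# `model` is assumed to return start/end logits for a (question, sub-note) pair.

from typing import List

MAX_LEN = 512  # BERT's input limit, including [CLS] and [SEP] tokens


def split_into_subnotes(note_tokens: List[str], question_tokens: List[str]) -> List[List[str]]:
    """Split a long clinical note so that each chunk + question fits in 512 tokens."""
    budget = MAX_LEN - len(question_tokens) - 3  # room for [CLS] and two [SEP]
    return [note_tokens[i:i + budget] for i in range(0, len(note_tokens), budget)]


def best_answer_span(question_tokens, note_tokens, model, max_answer_len=30):
    """Run the QA model on every sub-note and keep the highest-scoring span."""
    best_score, best_answer = float("-inf"), None
    for sub_note in split_into_subnotes(note_tokens, question_tokens):
        start_logits, end_logits = model(question_tokens, sub_note)
        for s in range(len(sub_note)):
            for e in range(s, min(s + max_answer_len, len(sub_note))):
                score = start_logits[s] + end_logits[e]  # span_score(c, s, e)
                if score > best_score:
                    best_score, best_answer = score, sub_note[s:e + 1]
    return best_answer
```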
Experiments and Results
After devising the approach to handle long contexts in QA, we ran several experiments to explore the remaining research questions. For evaluation, the F1 score and the EM_5 metric are used. EM_5 is a relaxed version of the exact match (EM) metric: a prediction counts as a match if both the predicted start and end positions lie within +/-5 tokens of the ground-truth positions.
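For reference, an EM_5 computation along these lines could look like the following simple sketch over predicted and ground-truth token positions:

```python
# Illustrative implementation of the EM_5 metric described above: a prediction
# counts as correct if both its start and end positions fall within +/-5 tokens
# of the ground-truth positions.

def em_5(predictions, ground_truths, tolerance=5):
    """predictions / ground_truths: lists of (start, end) token positions."""
    hits = sum(
        abs(ps - gs) <= tolerance and abs(pe - ge) <= tolerance
        for (ps, pe), (gs, ge) in zip(predictions, ground_truths)
    )
    return hits / len(ground_truths)
```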
In the plots below, each experiment can be read as multiple training stages applied in temporal order from left to right, where the color of an arrow indicates the type of training and two colors in one arrow indicate an additional modification applied during that stage. For all experiments, the model is initialized from the pre-trained BERT base uncased weights.
All experiments can be broadly divided into four parts:
1. Preliminary experiments
The preliminary experiments were done to find out how a SQuAD-finetuned BERT-QA model performs on a domain-specific dataset like emrQA. They validated our intuition that the model would perform poorly due to its lack of domain knowledge. We also wanted to find out how much domain knowledge can be acquired through task-driven finetuning on emrQA, and whether the model could adapt to the domain by training its language model on the limited amount of text data available, i.e. the clinical notes from emrQA.
2. Effect of LM training
As observed in the preliminary experiments, task-related finetuning on the domain-specific dataset is the most promising step, and training the LM on even a limited amount of domain text pushed the performance further. The next sentence prediction task was not used during LM training, since two sentences from a clinical note would mostly preserve the structure and intent of the note even when they are not adjacent to each other. The first set of experiments aimed at improving adaptation through LM training. One idea was to replace 1k placeholders in BERT’s original vocabulary with the most frequent domain-specific words, to see if it would help the model learn those words better (a sketch of this swap follows below). Another idea was to augment the domain corpus with clinical notes from another dataset called CliCR. The low performance with the CliCR augmentation can be attributed to the difference in structure and content between the clinical notes from CliCR (source) and the emrQA dataset (target). Below are the results for both modifications:
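As an aside, the placeholder swap can be illustrated with a small script. The sketch below assumes the stock BERT vocab.txt with its “[unusedN]” placeholder tokens; the file paths and the word tokenization of the domain corpus are purely illustrative.

```python
# A minimal sketch of swapping BERT's [unusedN] placeholder tokens for the most
# frequent domain-specific words. Paths and tokenization are illustrative.

from collections import Counter
import re


def swap_placeholders(vocab_path, corpus_path, out_path, max_swaps=1000):
    with open(vocab_path, encoding="utf-8") as f:
        vocab = [line.rstrip("\n") for line in f]
    known = set(vocab)

    # Count simple lowercase word occurrences in the domain corpus.
    with open(corpus_path, encoding="utf-8") as f:
        words = re.findall(r"[a-z0-9]+", f.read().lower())
    frequent = [w for w, _ in Counter(words).most_common() if w not in known][:max_swaps]

    # Overwrite the [unusedN] placeholders with the most frequent domain words.
    it = iter(frequent)
    for i, token in enumerate(vocab):
        if token.startswith("[unused"):
            try:
                vocab[i] = next(it)
            except StopIteration:
                break

    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(vocab) + "\n")
```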
3. Effect of finetuning
With the goal of improving task-related finetuning, we explored two directions. The first was adding a classifier on top of BERT to shortlist the relevant sub-parts of a clinical note, i.e. those that could contain the answer to the input question. In simple terms, the classifier takes the context (a window of 512 tokens from a clinical note) and the question and predicts whether the context is relevant to the question; only if it is are the start and end scores computed. The second question was whether finetuning on a large general-domain QA dataset, i.e. SQuAD, before finetuning on the domain dataset emrQA adds any value. Here are the results:
“we hypothesize that when the model is fine-tuned directly on the downstream tasks and uses only a very small number of randomly initialized additional parameters, the task-specific models can benefit from the larger, more expressive pre-trained representations even when downstream task data is very small.” (excerpt from the BERT paper)
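To make the first direction concrete, here is a minimal sketch of such a relevance classifier framed as a BERT sentence-pair classifier. It uses the Hugging Face transformers API purely for illustration, which is not necessarily the toolkit used in these experiments, and the model would of course first need to be finetuned on (question, sub-note, relevant/not relevant) pairs.

```python
# Illustrative (question, sub-note) -> relevant / not relevant classifier,
# built as a standard BERT sentence-pair classifier.

import torch
from transformers import BertTokenizerFast, BertForSequenceClassification

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# num_labels=2: class 0 = not relevant, class 1 = relevant.
# In practice this model would be finetuned on labelled pairs before use.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)


def is_relevant(question: str, sub_note: str, threshold: float = 0.5) -> bool:
    """Return True if the classifier thinks this sub-note may contain the answer."""
    inputs = tokenizer(question, sub_note, truncation=True, max_length=512,
                       return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    prob_relevant = torch.softmax(logits, dim=-1)[0, 1].item()
    return prob_relevant > threshold
```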
4. Effect of the amount of training data
Annotating data is expensive and time-consuming, so it is important to understand how much annotated data is really needed. The training setup from the experiments above was used to train the model on different sizes of training data while keeping the test data fixed. The graph below shows the results:
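A sketch of such a sweep, with hypothetical train_qa_model and evaluate helpers standing in for the actual training and evaluation code:

```python
# Illustrative training-size sweep: train on increasing fractions of the
# (shuffled) training set while evaluating on the same fixed test set.

import random


def data_size_sweep(train_examples, test_examples,
                    fractions=(0.1, 0.25, 0.5, 1.0), seed=42):
    random.Random(seed).shuffle(train_examples)
    results = {}
    for frac in fractions:
        subset = train_examples[:int(len(train_examples) * frac)]
        model = train_qa_model(subset)                   # hypothetical training routine
        results[frac] = evaluate(model, test_examples)   # hypothetical F1 / EM_5 eval
    return results
```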
Conclusions
- Models trained on the general domain dataset do not perform well on the domain-specific datasets.
- To adapt to the domain, task-driven finetuning on a domain-specific QA dataset is one of the most important steps.
- Domain adaptation by LM training with limited data (only the available paragraphs or clinical notes from the QA dataset) gives a marginal improvement in performance.
- Adding domain-specific words by replacing placeholders in the original vocabulary helps the model learn those words better.
- Finetuning the BERT-QA model on a large general-domain QA dataset before finetuning on the domain-specific QA dataset can prove helpful when the domain-specific dataset is limited.
Thanks to Anders for his guidance during the internship and for his valuable feedback on this article.
P.S. The code for all the approaches can be found on my GitHub account. All experiments were run on the free TPU from Google Colaboratory.
References
[1] https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html
[3] SciBERT: A Pretrained Language Model for Scientific Text, Beltagy et al.
[4] BioBERT: a pre-trained biomedical language representation model for biomedical text mining, Lee et al.
[5] A BERT Baseline for the Natural Questions, Alberti et al.