Fastai integration with BERT: Multi-label text classification identifying toxicity in texts

Photo by Jules D. on Unsplash

PROLOGUE

DATA

LIBRARIES AND MAJOR DEPENDENCIES

INTEGRATION TECHNIQUES

# Load the pretrained BERT WordPiece tokenizer; the "uncased" checkpoint
# lower-cases input, so the tokenizer must match it.
from pytorch_pretrained_bert import BertTokenizer

bert_tok = BertTokenizer.from_pretrained(
    "bert-base-uncased",
)
class FastAiBertTokenizer(BaseTokenizer):
    """Wrapper around BertTokenizer to be compatible with fast.ai.

    fast.ai's ``Tokenizer`` expects a ``tok_func`` it can call to obtain an
    object exposing a ``tokenizer(t)`` method; this adapter satisfies that
    contract while delegating the actual WordPiece splitting to the wrapped
    ``BertTokenizer``.
    """

    def __init__(self, tokenizer: BertTokenizer, max_seq_len: int = 128, **kwargs):
        self._pretrained_tokenizer = tokenizer
        # Hard cap on the number of tokens per example (including specials).
        self.max_seq_len = max_seq_len

    def __call__(self, *args, **kwargs):
        # fast.ai invokes tok_func to create a tokenizer per worker; returning
        # self reuses this instance instead of constructing a new one.
        return self

    def tokenizer(self, t: str) -> List[str]:
        """Tokenize ``t``, truncating so the sequence (with the two special
        tokens) never exceeds ``max_seq_len``."""
        # Reserve 2 slots for BERT's mandatory [CLS]/[SEP] sentence markers.
        return ["[CLS]"] + self._pretrained_tokenizer.tokenize(t)[: self.max_seq_len - 2] + ["[SEP]"]
# Reuse BERT's own vocabulary so token ids line up with the pretrained weights.
fastai_bert_vocab = Vocab(list(bert_tok.vocab.keys()))

# Disable fast.ai's default pre/post processing rules: BERT's tokenizer must
# see raw text, and [CLS]/[SEP] already play the role of BOS/EOS markers.
fastai_tokenizer = Tokenizer(
    tok_func=FastAiBertTokenizer(bert_tok, max_seq_len=256),
    pre_rules=[],
    post_rules=[],
)

# One column per toxicity label (multi-label: a comment may carry several).
label_cols = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

databunch_1 = TextDataBunch.from_df(
    ".", train, val,
    tokenizer=fastai_tokenizer,
    vocab=fastai_bert_vocab,
    include_bos=False,  # [CLS] already serves as the start token
    include_eos=False,  # [SEP] already serves as the end token
    text_cols="comment_text",
    label_cols=label_cols,
    bs=32,
    # BERT expects right-padding with the [PAD] id (0), not fast.ai's default.
    collate_fn=partial(pad_collate, pad_first=False, pad_idx=0),
)
def bert_clas_split(model) -> List[nn.Module]:
    """Split a ``BertForSequenceClassification`` model into six layer groups
    for fast.ai discriminative learning rates / gradual unfreezing.

    Groups: [embeddings], three contiguous thirds of the encoder stack,
    [pooler], [dropout, classifier].

    Note: the original version named its parameter ``self`` but read the
    global ``model``; it is now a proper parameter (call sites pass it
    positionally, so this is compatible).
    """
    bert = model.bert
    embedder = bert.embeddings
    pooler = bert.pooler
    encoder = bert.encoder
    classifier = [model.dropout, model.classifier]
    n = len(encoder.layer) // 3
    # Contiguous thirds: the previous slicing ([n+1:2*n] and [(2*n)+1:])
    # silently dropped encoder layers n and 2n from every group.
    groups = [
        [embedder],
        list(encoder.layer[:n]),
        list(encoder.layer[n:2 * n]),
        list(encoder.layer[2 * n:]),
        [pooler],
        classifier,
    ]
    return groups
from pytorch_pretrained_bert.modeling import (
    BertConfig,
    BertForSequenceClassification,
    BertForNextSentencePrediction,
    BertForMaskedLM,
)
from fastai.callbacks import *

# Pretrained BERT with a fresh 6-way classification head (one unit per label).
bert_model_class = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=6)
model = bert_model_class

learner = Learner(
    databunch_1, model,
    loss_func=loss_func,      # multi-label loss, defined elsewhere in the article
    model_dir="/temp/model",
    metrics=acc_02,           # thresholded-accuracy metric, defined elsewhere
)

# Register the layer groups so freeze_to() / discriminative LRs act per group.
# NOTE(review): x[4] (the pooler) is deliberately skipped here, merging it
# into the final group boundary — confirm this matches the intended schedule.
x = bert_clas_split(model)
learner.split([x[0], x[1], x[2], x[3], x[5]])

MODEL PERFORMANCE

BERT’s performance on multi-label classification task
The model is performing really well! It correctly identified that the first text contains no abusive slang or threats, while in the second sentence it detected toxicity, obscenity, and insult.
Fastai’s performance on multi-label classification task
Pretty neat! It predicted one additional category beyond what BERT alone had predicted.

CONCLUSION