šŸ¦„ How to build a State-of-the-Art Conversational AI with Transfer Learning

Online demo of the pretrained model we’ll build in this tutorial at convai.huggingface.co. The ā€œsuggestionsā€ (bottom) are also powered by the model putting itself in the shoes of the user.
In this post, we will walk through:
  • How you can use Transfer Learning to build a State-of-the-Art dialog agent based on OpenAI GPT and GPT-2 Transformer language models,
  • How you can reproduce the model we used in the NeurIPS 2018 dialog competition ConvAI2, which won the automatic metrics track,
  • How we distilled 3k+ lines of competition code into less than 250 lines of commented training code (with distributed & FP16 options!), and
  • How you can train this model for less than $20 on a cloud instance, or just use our open-sourced pre-trained model.

An AI with a personality 🤠

Our two-step recipe will be to:
  • start by pretraining a language model on a very large corpus of text, so that it can generate long stretches of contiguous, coherent text, then
  • fine-tune this language model to adapt it to our end task: dialog.

What would be a good pretrained model for our purpose?

šŸ¦„ OpenAI GPT and GPT-2 models

A decoder/causal Transformer attends to the left context to generate next words
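To make this concrete, here is a minimal sketch, assuming the `transformers` library is installed, of loading a pretrained OpenAI GPT and greedily predicting the next word of a prompt (the prompt itself is just an example):

```python
# Minimal sketch: load a pretrained causal Transformer (OpenAI GPT) and
# greedily predict the next word of a prompt with the `transformers` library.
import torch
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
model.eval()

prompt = "my favourite food is"
input_ids = torch.tensor([tokenizer.encode(prompt)])

with torch.no_grad():
    logits = model(input_ids)[0]          # (1, seq_len, vocab_size)
next_token_id = int(logits[0, -1].argmax())
print(tokenizer.decode([next_token_id]))  # the model's most likely next word
```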

šŸ‘» Adapting a language model to a dialog task

To generate a relevant reply, the model needs to condition on several types of context:
  • one or several persona sentences,
  • the history of the dialog, with at least the last utterance from the user,
  • the tokens of the output sequence that have already been generated, since we generate the output sequence word by word.
The simplest way to feed all of this to a single-input model is to concatenate the context segments into one sequence, putting the reply at the end.
Input sequence: a concatenation of persona (blue), history (pink) and reply (green) with delimiters (light pink). Here we generate the word ā€œyouā€ to complete the reply.
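Here is an illustrative sketch of building such a sequence; the function name, the delimiter tokens and the 0/1/2 segment ids are chosen for readability rather than taken from the training code:

```python
# Illustrative sketch: flatten persona, history and (partial or gold) reply
# into one token list with speaker delimiters, and record a segment id per token.
BOS, EOS, SPEAKER_USER, SPEAKER_BOT = "<bos>", "<eos>", "<speaker1>", "<speaker2>"

def build_input(persona, history, reply):
    """persona: list of tokenized sentences; history: alternating user/bot
    utterances (most recent last); reply: the bot reply tokens."""
    sequence = [[BOS] + [tok for sent in persona for tok in sent]]
    utterances = history + [reply]
    for i, utterance in enumerate(utterances):
        # The last utterance is the bot's reply; alternate speakers backwards from it.
        speaker = SPEAKER_BOT if (len(utterances) - 1 - i) % 2 == 0 else SPEAKER_USER
        sequence.append([speaker] + utterance)
    sequence[-1] = sequence[-1] + [EOS]   # gold replies end with <eos> at training time
    words = [tok for segment in sequence for tok in segment]
    # Segment ids: 0 for the persona, 1 for user turns, 2 for bot turns.
    segments = [0] * len(sequence[0])
    for segment in sequence[1:]:
        segments += [1 if segment[0] == SPEAKER_USER else 2] * len(segment)
    return words, segments

persona = [["i", "like", "football", "."], ["i", "live", "in", "nyc", "."]]
history = [["hello", "how", "are", "you", "?"]]
reply = ["i", "am", "fine", "thanks", "."]
words, segments = build_input(persona, history, reply)
print(list(zip(words, segments)))
```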
This simple concatenation raises two problems:
  • Our transformer is color-blind! The delimiter tokens only give it a weak idea of which segment each word belongs to. For example, the word ā€œNYCā€ is shown in blue (persona) in our illustration, but the model will have a hard time extracting this information from the delimiters alone: we should add more information about the segments.
  • Our transformer is position-blind! The attention dot-product has no built-in notion of token order, so we should add position information for each token.
We address both issues by learning segment and position embeddings and summing them with the word embeddings, as illustrated below.
Summing three types of input embeddings indicating words (grey), position (gradient) and segments (blue/pink/green)
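With the `transformers` library, the word and position embeddings are already part of the pretrained model, so this mostly amounts to registering our delimiter tokens, resizing the embedding matrix, and passing one segment id per token as `token_type_ids`. The token names below are illustrative, a sketch rather than the exact training setup:

```python
# Sketch: add delimiter tokens, resize the embedding matrix, and pass segment
# ids as token_type_ids so they are embedded and summed with the word and
# position embeddings inside the model.
import torch
from transformers import OpenAIGPTModel, OpenAIGPTTokenizer

SPECIAL_TOKENS = {"bos_token": "<bos>", "eos_token": "<eos>", "pad_token": "<pad>",
                  "additional_special_tokens": ["<speaker1>", "<speaker2>"]}

tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
model = OpenAIGPTModel.from_pretrained("openai-gpt")

tokenizer.add_special_tokens(SPECIAL_TOKENS)
model.resize_token_embeddings(len(tokenizer))   # new rows for the added tokens

# One token id and one segment id per position (here: everything tagged as bot).
input_ids = torch.tensor([tokenizer.encode("<bos> i live in nyc <speaker2> hello !")])
token_type_ids = torch.full_like(input_ids,
                                 tokenizer.convert_tokens_to_ids("<speaker2>"))
hidden_states = model(input_ids, token_type_ids=token_type_ids)[0]
print(hidden_states.shape)                      # (1, seq_len, hidden_size)
```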

šŸ‘‘ Multi-task losses

We will use a multi-task loss combining language modeling with a next-sentence prediction objective.

Multi-task training objective — the model is provided with two heads for language modeling prediction (orange) and next-sentence classification (blue)
  • Language modeling: we project the hidden states onto the word embedding matrix to get logits and apply a cross-entropy loss on the portion of the target corresponding to the gold reply (green labels in the figure above).
  • Next-sentence prediction: we pass the hidden state of the last token (the end-of-sequence token) through a linear layer to get a score and apply a cross-entropy loss to correctly classify the gold reply among distractors.
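Here is a minimal sketch of how these two losses can be obtained from a double-heads GPT model in a recent version of the `transformers` library; the toy batch, the -100 label masking and the 2.0/1.0 loss weights are illustrative choices rather than the competition settings:

```python
# Sketch of the multi-task objective with a GPT model carrying two heads:
# a language-modeling loss on the gold reply and a multiple-choice loss that
# scores the gold reply against a distractor.
import torch
from transformers import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-gpt")

# Two candidate replies for one example: a distractor first, the gold reply last.
candidates = ["i like to ski , do you ?", "i read twenty books a year ."]
encoded = [tokenizer.encode(c) for c in candidates]
max_len = max(len(ids) for ids in encoded)
input_ids = torch.tensor([[ids + [0] * (max_len - len(ids)) for ids in encoded]])
mc_token_ids = torch.tensor([[len(ids) - 1 for ids in encoded]])  # last real token

# LM labels: -100 (ignored) everywhere except the gold reply's tokens.
lm_labels = torch.full_like(input_ids, -100)
lm_labels[0, 1, : len(encoded[1])] = input_ids[0, 1, : len(encoded[1])]
mc_labels = torch.tensor([1])             # index of the gold candidate

outputs = model(input_ids, mc_token_ids=mc_token_ids,
                labels=lm_labels, mc_labels=mc_labels)
total_loss = outputs.loss * 2.0 + outputs.mc_loss * 1.0   # weighted sum of the two losses
total_loss.backward()
```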

🦊 Training on a dialog dataset

Organization of the JSON version of PERSONA-CHAT
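Concretely, each split of the JSON file holds a list of dialogs, each with a persona and a list of per-turn training examples. Here is a small sketch of walking that structure; the local file name is an assumption, and the field names follow the layout shown above:

```python
# Sketch: inspect one training example of the JSON version of PERSONA-CHAT.
import json

with open("personachat_self_original.json") as f:   # assumed local copy of the dataset
    dataset = json.load(f)                           # {"train": [...], "valid": [...]}

dialog = dataset["train"][0]
persona = dialog["personality"]         # list of persona sentences
example = dialog["utterances"][-1]      # one training example for this dialog
history = example["history"]            # dialog turns so far, user turn last
candidates = example["candidates"]      # distractor replies + the gold reply (last)
gold_reply = candidates[-1]

print(persona)
print(history[-1], "->", gold_reply)
```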

šŸ‘» Talking with the Model — the Decoder

Generating a sentence word by word (source)
Left: Probability assigned to tokens generated by humans and beam search using GPT-2 (Note the strong variance in human text not reproduced by beam-search). Right: N-gram distributions in human and machine-generated texts (Note the complete separation between greedy/beam-search and sampling decoding methods).
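In line with these observations, a better decoder than greedy or beam search is to sample from the distribution after restricting it with top-k and/or nucleus (top-p) filtering. Here is a sketch of such a filtering function for a single vector of next-token logits; the defaults and the function name are illustrative:

```python
# Sketch: top-k / nucleus (top-p) filtering of next-token logits before sampling.
import torch
import torch.nn.functional as F

def top_filtering(logits, top_k=0, top_p=0.9, filter_value=-float("inf")):
    """logits: 1-D tensor of next-token logits. Keep the top_k highest-scoring
    tokens and/or the smallest set whose cumulative probability exceeds top_p."""
    if top_k > 0:
        kth_best = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_best] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens once the cumulative probability passes the threshold,
        # but always keep at least the single most probable token.
        to_remove = cumulative_probs > top_p
        to_remove[1:] = to_remove[:-1].clone()
        to_remove[0] = False
        logits[sorted_indices[to_remove]] = filter_value
    return logits

# Usage (model and tokenizer as loaded earlier):
# logits = model(input_ids)[0][0, -1]                   # next-token logits
# probs = F.softmax(top_filtering(logits), dim=-1)
# next_token = torch.multinomial(probs, num_samples=1)  # sample instead of argmax
```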
Example using the interactive scripts with default settings — Bot personality: I read twenty books a year. I’m a stunt double as my second job. I only eat kosher. I was raised in a single parent household.

šŸ‘» Conclusion

  • the live demo is at convai.huggingface.co, and
  • the open-sourced code and pretrained models are here.

Thomas Wolf

Natural Language Processing, Deep learning and Computational Linguistics – Science Lead @Huggingface | thomwolf.io