
🦄 How to build a State-of-the-Art Conversational AI with Transfer Learning

Online demo of the pretrained model we’ll build in this tutorial at convai.huggingface.co. The “suggestions” (bottom) are also powered by the model putting itself in the shoes of the user.
Here is what we will learn and play with in this post:
  • How you can use Transfer Learning to build a State-of-the-Art dialog agent based on OpenAI GPT and GPT-2 Transformer language models,
  • How you can reproduce the model we used in the NeurIPS 2018 dialog competition ConvAI2, which won the automatic metrics track,
  • How we distilled 3k+ lines of competition code into fewer than 250 lines of commented training code (with distributed & FP16 options!), and
  • How you can train this model for less than $20 on a cloud instance, or just use our open-sourced pre-trained model.

An AI with a personality 🤠

We will follow the usual two-step transfer-learning recipe:
  • start by pretraining a language model on a very large corpus of text, so that it can generate long stretches of contiguous, coherent text,
  • fine-tune this language model to adapt it to our end-task: dialog.

What would be a good pretrained model for our purpose?

🦄 OpenAI GPT and GPT-2 models

A decoder/causal Transformer attends to the left context to generate next words
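If you want to poke at such a pretrained causal language model yourself, here is a minimal sketch, assuming a recent version of the Hugging Face transformers library; the context string is just an illustrative example, and the model predicts the next word from the left context only:

```python
# Minimal sketch: load a pretrained causal LM and look at its next-word prediction.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

context = "I have two dogs and I like to"          # made-up example context
input_ids = tokenizer.encode(context, return_tensors="pt")

with torch.no_grad():
    logits = model(input_ids).logits               # shape: [1, seq_len, vocab_size]

next_token_id = logits[0, -1].argmax().item()      # greedy pick of the next word
print(tokenizer.decode([next_token_id]))
```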

👻 Adapting a language model to a dialog task

To generate a reply, our model will need to condition on several types of context:
  • one or several persona sentences,
  • the history of the dialog, with at least the last utterance from the user,
  • the tokens of the output sequence that have already been generated, since we generate the output sequence word by word.
Input sequence: a concatenation of persona (blue), history (pink) and reply (green) with delimiters (light pink). Here we generate the word “you” to complete the reply.
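In code, building that single input sequence boils down to concatenating the three contexts with delimiter tokens and keeping a parallel list of segment tokens. Here is a rough sketch; the special-token names (<bos>, <eos>, <speaker1>, <speaker2>) are illustrative placeholders rather than the exact ones used in the released code:

```python
from itertools import chain

# Illustrative special tokens -- the released code may use different names.
BOS, EOS, SPK_USER, SPK_BOT = "<bos>", "<eos>", "<speaker1>", "<speaker2>"

def build_inputs(persona, history, reply):
    """persona: list of tokenized persona sentences,
       history: list of tokenized utterances (the last one is the user's),
       reply:   tokenized bot reply generated so far."""
    sequence = [[BOS] + list(chain(*persona))] + history + [reply + [EOS]]
    # Prepend a speaker token to each utterance; the reply belongs to the bot,
    # so the speakers alternate backwards from SPK_BOT.
    sequence = [sequence[0]] + [
        [SPK_BOT if (len(sequence) - i) % 2 else SPK_USER] + utt
        for i, utt in enumerate(sequence[1:], start=1)
    ]
    words = list(chain(*sequence))                  # token sequence fed to the model
    segments = [SPK_USER] * len(sequence[0]) + [    # persona shares the user segment here
        utt[0] for utt in sequence[1:] for _ in utt # each word tagged with its speaker token
    ]
    positions = list(range(len(words)))             # position indices
    return words, segments, positions
```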
There are two issues with this simple concatenation scheme:
  • Our transformer is color-blind! The delimiter tokens only give it a weak idea of which segment each word belongs to. For example, the word “NYC” is indicated in blue (persona) in our illustration, but our model will have a hard time extracting this information from the delimiters alone: we should add more information about the segments.
  • Our transformer is position-blind! Attention is a symmetrical dot-product, so we should add position information for each token.
Summing three types of input embeddings indicating words (grey), positions (gradient) and segments (blue/pink/green)
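Here is a minimal PyTorch sketch of that input layer. One way to implement the segment embeddings is to add the segment tokens to the vocabulary so they can reuse the word embedding matrix; the sketch below keeps a small separate table for clarity, and all sizes are placeholders rather than the actual GPT hyper-parameters:

```python
import torch
import torch.nn as nn

class InputEmbeddings(nn.Module):
    """Sketch of the input layer: word, position and segment embeddings are summed."""
    def __init__(self, vocab_size=50000, n_positions=512, n_segments=3, n_embd=768):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, n_embd)    # grey in the figure
        self.pos_emb  = nn.Embedding(n_positions, n_embd)   # gradient in the figure
        self.seg_emb  = nn.Embedding(n_segments, n_embd)    # blue/pink/green in the figure

    def forward(self, input_ids, segment_ids):
        positions = torch.arange(input_ids.size(-1), device=input_ids.device)
        return (self.word_emb(input_ids)
                + self.pos_emb(positions)
                + self.seg_emb(segment_ids))
```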

👑 Multi-task losses

We will use a multi-task loss combining language modeling with a next-sentence prediction objective.

Multi-task training objective — the model is provided with two heads for language modeling prediction (orange) and next-sentence classification (blue)
  • Language modeling: we project the hidden states onto the word embedding matrix to get logits and apply a cross-entropy loss on the portion of the target corresponding to the gold reply (green labels in the figure above).
  • Next-sentence prediction: we pass the hidden state of the last token (the end-of-sequence token) through a linear layer to get a score and apply a cross-entropy loss to correctly classify the gold reply among the distractors.
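Combining the two heads then amounts to summing the two cross-entropy losses. A sketch of that computation, with placeholder weighting coefficients and PyTorch's default ignore index (-100) used to mask the tokens that fall outside the gold reply:

```python
import torch.nn.functional as F

def multitask_loss(lm_logits, lm_labels, mc_logits, mc_labels,
                   lm_coef=1.0, mc_coef=1.0):
    """lm_logits: [batch, seq_len, vocab];  lm_labels: [batch, seq_len], with positions
    outside the gold reply set to -100 so they are ignored by the loss.
    mc_logits: [batch, n_candidates];       mc_labels: [batch], index of the gold reply.
    The coefficients are placeholder hyper-parameters."""
    lm_loss = F.cross_entropy(lm_logits.reshape(-1, lm_logits.size(-1)),
                              lm_labels.reshape(-1), ignore_index=-100)
    mc_loss = F.cross_entropy(mc_logits, mc_labels)
    return lm_coef * lm_loss + mc_coef * mc_loss
```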

🦊 Training on a dialog dataset

Organization of the JSON version of PERSONA-CHAT
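To make that organization concrete, here is a rough sketch of walking through the JSON file; the file name is illustrative, and the field names follow the figure above:

```python
import json

# Walk the PERSONA-CHAT JSON dump (illustrative file name).
with open("personachat_self_original.json", encoding="utf-8") as f:
    dataset = json.load(f)                   # two splits: "train" and "valid"

for dialog in dataset["train"][:1]:
    persona = dialog["personality"]          # list of persona sentences for the bot
    for utterance in dialog["utterances"]:
        history = utterance["history"]       # previous turns, ending with the user's message
        candidates = utterance["candidates"] # distractor replies, with the gold reply last
        gold_reply = candidates[-1]
```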

👻 Talking with the Model — the Decoder

Generating a sentence word by word (source)
Left: Probability assigned to tokens generated by humans and beam search using GPT-2 (Note the strong variance in human text not reproduced by beam-search). Right: N-gram distributions in human and machine-generated texts (Note the complete separation between greedy/beam-search and sampling decoding methods).
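These observations are why the decoder samples from the distribution instead of using greedy or beam-search decoding, keeping only the most probable tokens at each step via top-k and nucleus (top-p) filtering. Here is a self-contained sketch of that filtering step for a single distribution; the threshold values are placeholders:

```python
import torch
import torch.nn.functional as F

def top_filtering(logits, top_k=0, top_p=0.9, filter_value=-float("inf")):
    """Filter a 1-D tensor of next-token logits with top-k and/or nucleus (top-p) filtering.
    top_k: keep only the k most probable tokens (0 disables it).
    top_p: keep the smallest set of tokens whose cumulative probability exceeds top_p."""
    assert logits.dim() == 1  # one distribution over the vocabulary (batch size 1)
    if top_k > 0:
        kth_best = torch.topk(logits, top_k)[0][-1]
        logits[logits < kth_best] = filter_value
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is kept.
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
        sorted_indices_to_remove[0] = False
        logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    return logits

# Sampling one word from the filtered distribution:
# probs = F.softmax(top_filtering(next_token_logits, top_k=50, top_p=0.9), dim=-1)
# next_token = torch.multinomial(probs, num_samples=1)
```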
Example using the interactive scripts with default settings — Bot personality: I read twenty books a year. I’m a stunt double as my second job. I only eat kosher. I was raised in a single parent household.

👻 Conclusion

To play with the result yourself:
  • the live demo is at convai.huggingface.co, and
  • the open-sourced code and pretrained models are here.


Thomas Wolf

Natural Language Processing, Deep learning and Computational Linguistics – Science Lead @Huggingface | thomwolf.io