Deploying GPT-2 Models in Custom Applications

Robert Coleman
Nov 10, 2019

Ever since OpenAI showed off their GPT-2 language model earlier this year, the machine learning community has been abuzz with demonstrations of the model’s potential. Where other language models in the past have struggled with issues of coherence and long-term dependencies over passages of text, GPT-2 is capable of generating convincing output throughout multiple paragraphs.

Until recently, OpenAI’s fully trained model was kept under wraps, but lighter-weight versions trained on more limited datasets have been available for some time, making it easy to generate fun examples in auto-completion and Q&A contexts.

The Issue of Integration

GPT-2 training and sampling tutorials are a dime a dozen, but less has been written about how you might actually incorporate a trained model into your own applications or websites. In this post I’ll walk through a sample integration with a Slackbot app (see my previous post for details on how to create your own) which will provide an easy interface for interacting with the model. I’ll be sampling a few different datasets as demos, including OkCupid dating profiles, rap lyrics, Trump speeches, and Ruby source code, but you can supply any training data you want for your own application.

Creating a Training Environment

First off, there are a handful of prerequisites and caveats for any deep-learning project involving sufficiently large data. To keep training times practical, you’ll likely want to make use of GPU acceleration — if you don’t have experience getting set up with CUDA for an ML framework like PyTorch or TensorFlow there are many tutorials available online, but that won’t be the focus of this post.

I will, however, specify the package versions I tested with, along with my CUDA and cuDNN versions, since getting the right combination working can be a bit tricky. These are listed in the “requirements.txt” file of the GitHub repo accompanying this writeup. I use Anaconda to manage my environments, but any Python virtual environment tool should work.

In terms of project structure, I’m keeping the source and util files for the GPT-2 model in their own module, and I’m using separate directories for models, checkpoints, and samples, like so:
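As a rough sketch of that layout, the snippet below creates the three data directories under a project root. The directory names follow the description above; the helper name `create_project_dirs` is my own, hypothetical choice, and the GPT-2 module directory is left out because its name isn't given here.

```python
from pathlib import Path

# Assumption: directory names taken from the description in the text.
PROJECT_DIRS = ("models", "checkpoints", "samples")


def create_project_dirs(root):
    """Create the models/, checkpoints/ and samples/ directories under root."""
    root = Path(root)
    created = []
    for name in PROJECT_DIRS:
        path = root / name
        # exist_ok makes the helper safe to re-run on an existing project
        path.mkdir(parents=True, exist_ok=True)
        created.append(path)
    return created
```

Calling `create_project_dirs(".")` from the project root sets everything up before the first training run.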

The Training Strategy

To get the best results on your custom dataset, which is most likely much, much smaller than the data OpenAI trained their model on, you can use a shortcut known as fine-tuning. Fine-tuning takes a pretrained model and updates its weights via iterative training on a new dataset, rather than initializing the weights for the new model from scratch; this usually reduces the required training time and improves performance on the new data. In other words, fine-tuning can produce high-quality results from much less input data, allowing you to eke out realistic generative models from only a few MB of training data.

A good starting point for the pretrained model in this case would be the 124 million parameter model open-sourced by OpenAI in February 2019 (note: due to an early naming error, this model is sometimes referred to as the 117M parameter model). The model can be downloaded here or via the “download_model.py” script in the accompanying GitHub repo.

Fortunately, once you’ve downloaded the pretrained model, the heavy lifting is done; all you need to supply is your own raw text to use for fine-tuning. Depending on your application, this could be as simple as a single file containing unstructured text data, or it could be a more specialized dataset containing question/answer pairs, stanzas of poetry, code blocks, ASCII art, etc. The input format shouldn’t substantially change the training methodology, however, since GPT-2 is capable of learning the long-term dependencies required for structured text representations.
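If your raw text starts out as many separate documents (one profile, one song, one source file), a minimal sketch for combining them into a single training file might look like this. Separating documents with GPT-2's `<|endoftext|>` marker is a common convention that I'm assuming here rather than something the training script requires, and the helper names are hypothetical.

```python
from pathlib import Path

# Assumption: documents are separated with GPT-2's end-of-text marker.
DELIMITER = "<|endoftext|>"


def build_corpus(documents):
    """Join non-empty documents into one training string."""
    cleaned = [doc.strip() for doc in documents if doc.strip()]
    return f"\n{DELIMITER}\n".join(cleaned)


def write_corpus(documents, out_path):
    """Write the combined corpus to a single UTF-8 text file."""
    out_path = Path(out_path)
    out_path.write_text(build_corpus(documents), encoding="utf-8")
    return out_path
```

The resulting file is what you would pass as the dataset path when training.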

When you’re ready to start training, provide a path to your data and a name for your custom model and kick off the fine-tuning process:

$ python train.py --dataset /path/to/dataset.txt --run_name <your_custom_model_name>

While the model is training, it will intermittently generate samples and save checkpoints so that you can monitor progress and resume training if interrupted. These outputs will show up in the “samples/” and “checkpoints/” directories respectively.

Testing Your Model

Once you have some checkpoints populated in your new model directory, you can generate samples from that model using the “generate_samples.py” script:

$ python generate_samples.py --dataset /path/to/dataset.txt --run_name <your_custom_model_name>

This will print the sample output to the console, but the script should be easy to modify to write to disk if preferred. The script also accepts an optional “sample_length” parameter if you’d like to produce shorter or longer samples.
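For the write-to-disk variant, a small helper along these lines could be called wherever the script currently prints. This is only a sketch: the `save_sample` name, the per-run subdirectory, and the timestamped file name are all my own assumptions, not part of the repo's script.

```python
from datetime import datetime
from pathlib import Path


def save_sample(text, run_name, samples_dir="samples"):
    """Write one generated sample to samples_dir/run_name/ and return its path."""
    out_dir = Path(samples_dir) / run_name
    out_dir.mkdir(parents=True, exist_ok=True)
    # microsecond timestamp keeps successive samples from overwriting each other
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
    out_path = out_dir / f"sample-{stamp}.txt"
    out_path.write_text(text, encoding="utf-8")
    return out_path
```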

Adding Pipelines to Your Application

Now that you have your custom model working, you’ll need to make a few different pieces available to the application you’d like to integrate it with:

  1. The checkpoint directory for your custom model
  2. The model hyperparameters (which are included as a JSON file with the downloaded pretrained model)
  3. Prime text, for the model to use as a starting point when predicting the sequence that follows it
  4. An encoder, to convert the raw prime text string into the byte-pair-encoded (BPE) format required by the model, and then to decode the BPE output back to a human-readable string

Putting it all together, a complete sampling function might look something like the following code block (the GPT-2 helper functions and directory constants are omitted):

import json
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.placeholder)

# get_encoder, default_hparams, sample_sequence, load_dataset, Sampler and the
# MODEL_DIR / CHECKPOINT_DIR / DATASET_DIR constants come from the project's
# GPT-2 module and are omitted here.

def generate_samples(run_name, sample_len=250, prime_text=None):
    # load the BPE encoder and hyperparameters shipped with the pretrained model
    enc = get_encoder('117M')
    hparams = default_hparams()
    with open(os.path.join(MODEL_DIR, '117M', 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    # let TensorFlow grow GPU memory as needed instead of reserving it all
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [1, None])

        tf_sample = sample_sequence(
            hparams=hparams,
            length=sample_len,
            context=context,
            batch_size=1,
            temperature=1.0,
            top_k=40
        )

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        saver = tf.train.Saver(
            var_list=train_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, run_name))

        # restore saved weights
        saver.restore(sess, ckpt)

        # generate sample from model
        if prime_text:
            context_tokens = enc.encode(prime_text)
        else:
            # no prime text: seed generation with a token from the training data
            chunks = load_dataset(enc, DATASET_DIR.format(run_name, run_name), 50000)
            data_sampler = Sampler(chunks)
            context_tokens = data_sampler.sample(1)

        out = sess.run(tf_sample, feed_dict={context: [context_tokens]})
        text = enc.decode(out[0])
    # the session is closed automatically when the with-block exits

    # strip prime text from response (if desired)
    if prime_text and text.startswith(prime_text):
        text = text[len(prime_text):]

    return text

This function can then be referenced elsewhere in the application with the optional length and prime text arguments supplied…

generate_samples("<custom_model_name>", sample_len=100, prime_text=command)

…which will return the decoded text output.
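The `temperature=1.0` and `top_k=40` arguments in the sampling function control how each next token is chosen. As a rough, framework-free illustration of the idea (not the code inside `sample_sequence`), the sketch below keeps only the k highest-scoring candidates and samples among them with a temperature-scaled softmax. The function name and the `rng` parameter are hypothetical.

```python
import math
import random


def top_k_sample(logits, k=40, temperature=1.0, rng=random):
    """Sample one index from the k highest-scoring logits."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # keep only the k largest logits (all of them if k >= len(logits))
    k = min(k, len(logits))
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # temperature-scaled softmax weights over the surviving candidates;
    # subtracting the peak keeps exp() numerically stable
    scaled = [logits[i] / temperature for i in top]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]
```

Lower temperatures concentrate probability on the top candidates and give more predictable text; a smaller k cuts off unlikely tokens entirely.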

Talking to Your Models

Finally, the fun part — chatting with the models. Here are a few random samples using the datasets I chose for this post.

First, a cringe-worthy dating profile learned from the OkCupid profiles dataset (27.7 MB, 3k training steps):

Next, some freestyling based on the rap lyrics dataset (80.9 MB, 5k training steps):

This particular model has a bit of a fresh mouth, so I had to do some censoring.

Then an obvious choice, a model trained on Trump’s public speeches (13.7 MB, 6k training steps):

And then last but not least, a more unusual demonstration of GPT-2’s ability to learn structure in data: a model trained on the Ruby language’s source code downloaded from GitHub (33.4 MB, 2k training steps):

Improvements, Scaling, Optimization

This post described a very quick and dirty way to get a model up and running in a simple app, but there’s tons of low-hanging fruit to optimize how models are served up in your own projects.

First and foremost, speed — in this demo there’s a lot of repetition of object-loading (encoders, parameters, etc) that could likely be cached so that reusing a model doesn’t start the process from scratch.
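One way this caching could look, as a minimal sketch: wrap the expensive loading step in `functools.lru_cache` so repeated requests for the same model reuse the loaded objects. Here `_load_from_disk` is a hypothetical stand-in for loading the encoder and hyperparameters, and the call counter exists only to make the effect visible.

```python
from functools import lru_cache

# Counts how often the expensive loader actually runs (illustration only).
LOAD_COUNT = {"calls": 0}


def _load_from_disk(model_name):
    """Hypothetical stand-in for the expensive encoder/hyperparameter loading."""
    LOAD_COUNT["calls"] += 1
    return {"model_name": model_name}


@lru_cache(maxsize=4)
def get_cached_resources(model_name):
    """Load a model's resources once and reuse them on later calls."""
    return _load_from_disk(model_name)
```

The same idea extends to keeping a restored session alive between requests, though that needs more care around memory and cleanup.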

Then, there are many opportunities to add custom logic and heuristics for controlling the prime text that’s sent to a model, and how its response incorporates that prime text. This could make the interactions themselves much smoother.
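Two simple heuristics of this kind, sketched under my own assumptions: normalize and cap the prime text before it reaches the model, and trim the response back to its last complete sentence so replies don't end mid-word. The function names and the 500-character limit are hypothetical.

```python
import re

MAX_PRIME_CHARS = 500  # hypothetical cap on prompt length


def clean_prime_text(raw):
    """Collapse runs of whitespace and keep only the most recent characters."""
    text = re.sub(r"\s+", " ", raw).strip()
    return text[-MAX_PRIME_CHARS:]


def trim_to_last_sentence(response):
    """Cut a response after its final sentence-ending punctuation mark."""
    matches = list(re.finditer(r"[.!?]", response))
    if not matches:
        # nothing to trim against, so return the text unchanged (minus padding)
        return response.strip()
    return response[: matches[-1].end()].strip()
```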

It also might make sense to create more distinct boundaries between your app and the service that handles the model’s input and output; I’ve kept the scripts packaged together for simplicity, but in a more complex project that might lead to difficulties later on.
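A lightweight first step toward that separation, as a sketch only, is to have the app depend on a thin wrapper rather than on the model code directly. The `GenerationService` class below is hypothetical; it assumes a generation function with the same signature as `generate_samples` from earlier in this post.

```python
from typing import Callable, Optional


class GenerationService:
    """Thin wrapper that hides model details from the calling app."""

    def __init__(self, generate_fn: Callable[..., str]):
        # generate_fn is assumed to accept (run_name, sample_len=..., prime_text=...)
        self._generate_fn = generate_fn

    def reply(self, run_name: str, prompt: Optional[str] = None, length: int = 100) -> str:
        """Return generated text for a prompt using the named model."""
        return self._generate_fn(run_name, sample_len=length, prime_text=prompt)
```

Because the app only sees `reply`, the function passed in could later be swapped for a client that calls a separate model server without touching the app code.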

Finally, in a frequently-used model, you might want to turn new inputs from users into an online learning process, where the model learns continuously from new data. This would require some more substantial modifications, but could provide worthwhile gains for cases where additional training data is valuable.
