Playing with Transformers
“Attention Is All You Need” is the paper that launched large language models. In this article, we’ll try to write some high-level code to implement a transformer model that you can train yourself (even without an expensive GPU).
We’re going to write a transformer architecture for characters. The goal of this model is just to predict the next character. I’m pretty sure it’ll be terrible, but characters are easier than words: there are fewer of them.
We’re not going to try to explain the paper in detail; instead, we’re going to demonstrate that you can glue a few lines of Python together and experience the transformer model. This should give you enough code to start fiddling with parameters like sequence length, the volume of training data and the training process, and just see what happens.
The illustrated transformer describes the internals far better than I can, so I’d encourage you to go and read that.
Encoding a position
One of the key properties of transformer models is that they can process data in parallel. Our first task is to build the input that goes into the model. Imagine we have a stream of data; our job is to break it up into tokens and turn each token into a vector that encodes both its value and its position.
In the example below we’ve used the awesome embedding of alphabetical order and we’ve got a couple of dictionary structures for turning a character into a number and back again. That’s the value part of the equation.
vocab = sorted(set(text))
char_2_idx = {char: idx for idx, char in enumerate(vocab)}
idx_2_char = np.array(vocab)
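As a quick sanity check, assuming text has already been loaded and contains these characters, the lookups round-trip a string back to itself:
ids = [char_2_idx[c] for c in "peck"]
print(ids)                        # the exact numbers depend on the vocabulary
print(''.join(idx_2_char[ids]))   # prints "peck" again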
How do we add positional information? Well, once we’ve converted a token into a vector, we augment that with positioning information like this using the formulas straight from the paper.
import numpy as np
import tensorflow as tf

def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    # apply sin to even indices in the array; cos to odd indices
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)
Why this arrangement with sin / cos? Think of it as a way of capturing the position of any item in the set. The further apart items are, the less impact they have on each other. With a phrase like “peter picked a peck of pickled peppers”, each p has a different vector. We can visualize this - this is a heatmap of the vector encoding for the text given before. The main thing to notice is that each repeated character is encoded differently because it now has its position embedded in it.
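To make that concrete, here’s a small sketch reusing the positional_encoding function above (the phrase comes from earlier, and the d_model of 64 simply matches the embedding size we use later). It prints the first few dimensions of the positional vector for each p, and every one comes out different:
phrase = "peter picked a peck of pickled peppers"
d_model = 64
pos_enc = positional_encoding(len(phrase), d_model)[0]  # shape: (len(phrase), d_model)

for i, c in enumerate(phrase):
    if c == 'p':
        # each 'p' sits at a different position, so its positional vector differs
        print(i, np.round(pos_enc[i, :4].numpy(), 3))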
Building our model
Let’s build our model.
import keras_nlp

def build_model(vocab_size, maxlen, embed_dim, num_heads, ff_dim, num_layers):
    inputs = tf.keras.Input(shape=(maxlen,))
    # Standard embedding layer
    x = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True)(inputs)
    # Add positional encoding (you can also use a built-in positional embedding if you prefer)
    pos_encoding = positional_encoding(maxlen, embed_dim)
    x = x + pos_encoding[:, :maxlen, :]
    # The TransformerEncoder block internally handles multi-head attention, skip connections,
    # layer normalization, and the feedforward network.
    for _ in range(num_layers):
        x = keras_nlp.layers.TransformerEncoder(
            num_heads=num_heads,
            intermediate_dim=ff_dim,
            dropout=0.1
        )(x)
    outputs = tf.keras.layers.Dense(vocab_size)(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)
What does all of this mean?
- vocab_size - the number of unique tokens; in our case, the number of unique characters (about 150 odd).
- maxlen - the maximum length of sequence the model will process.
- embed_dim - the size of the embeddings (64 in our case).
- num_heads - the number of attention heads in the multi-head attention block (I picked 4).
- ff_dim - the number of dimensions in the feed-forward network’s hidden layer (I picked 128).
- num_layers - how many layers of transformer blocks we use (I picked 2).
We can put that all together and see how many parameters we get:
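For example (the MAXLEN value here is just a placeholder, and 150 is a rough stand-in for the real character count):
MAXLEN = 100  # assumed sequence length for this sketch

model = build_model(
    vocab_size=150,   # roughly the number of unique characters we found
    maxlen=MAXLEN,
    embed_dim=64,
    num_heads=4,
    ff_dim=128,
    num_layers=2,
)
model.summary()  # prints each layer and the total parameter count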
Preparing some text data
What shall we use for training? Well, since we’re just using characters, we’ll use Project Gutenberg and grab a whole bunch of text files.
def split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = chunk[1:]
    return input_text, target_text

def get_training_set(pathToFile: str):
    # Load the text file
    with open(pathToFile, "r", encoding="utf-8") as f:
        text = f.read()
    text = ' '.join(text.split())
    # Get sorted unique characters in the text
    vocab = sorted(set(text))
    vocab_size = len(vocab)
    # Create lookup dictionaries
    CHAR_2_IDX = {char: idx for idx, char in enumerate(vocab)}
    IDX_2_CHAR = np.array(vocab)
    text_as_int = np.array([CHAR_2_IDX[c] for c in text])
    seq_length = MAXLEN + 1  # extra character for the target
    # Create a dataset of individual characters
    char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
    sequences = char_dataset.batch(seq_length, drop_remainder=True)
    dataset = sequences.map(split_input_target)
    return dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True), vocab_size, CHAR_2_IDX, IDX_2_CHAR
What does this mess do? Load a bunch of text, collapse repeated whitespace, chop it into fixed-length sequences, and turn each sequence into an (input, target) pair where the target is the input shifted along by one character - so at every position the model learns to predict the character that follows. It’s the job of training to optimize for that!
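Hypothetical usage - the constants and the filename are just example values, and the function expects MAXLEN, BUFFER_SIZE and BATCH_SIZE to be defined as globals:
MAXLEN, BUFFER_SIZE, BATCH_SIZE = 100, 10000, 64

trainingSet, vocab_size, CHAR_2_IDX, IDX_2_CHAR = get_training_set("shakespeare.txt")

# Each batch is a pair of (input, target) tensors, the target shifted by one character
for inputs, targets in trainingSet.take(1):
    print(inputs.shape, targets.shape)  # (64, 100) and (64, 100)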
Running training
What does training do? Throw the training set through the model and see what it predicts. If it’s wrong, nudge the parameters in the right direction (back propagation) and repeat. We continue training until the loss (a measure of how far off the predictions are) becomes acceptable. If you want a more in-depth view, try this.
We use an optimizer known as adam (adaptive moment estimation) and we run for a certain number of epochs. An epoch is a complete run through the training set.
model.compile(optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.summary()

checkpoint_filepath = './checkpoint/foo.weights.h5'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='loss',
    mode='min',
    save_best_only=True
)

with tf.device('/GPU:0'):
    model.fit(trainingSet, epochs=EPOCHS, callbacks=[checkpoint_callback])
When this runs, you’ll see the loss reported after each epoch. As the model fits the data better, the loss drops. In this case, the model converged very quickly and showed only small improvements after a handful of epochs.
Doing predictions
And for our final piece, we want to feed some tokens to the model and get its suggestion on what comes next.
def predict_next_tokens(model, input_text, char2idx, idx2char, num_tokens=10, temperature=1.0):
    # Convert input text to integer indices (default to index 0 for unknown characters)
    input_ids = [char2idx.get(c, 0) for c in input_text]
    generated_ids = list(input_ids)
    for _ in range(num_tokens):
        # Ensure input length equals maxlen. If too short, left-pad with 0
        if len(generated_ids) < MAXLEN:
            padded_input = [0] * (MAXLEN - len(generated_ids)) + generated_ids
        else:
            padded_input = generated_ids[-MAXLEN:]
        # Create tensor input with batch dimension
        input_tensor = tf.expand_dims(padded_input, 0)  # Shape: (1, maxlen)
        # Get model predictions. The output shape is (1, maxlen, vocab_size)
        predictions = model(input_tensor, training=False)
        # Sample the next token (you can use greedy: tf.argmax(last_logits) if you prefer)
        last_logits = predictions[0, -1, :]
        scaled_logits = last_logits / temperature
        predicted_id = tf.random.categorical(tf.expand_dims(scaled_logits, 0), num_samples=1)[0, 0].numpy()
        generated_ids.append(predicted_id)
    # Convert the sequence of token indices back to text
    generated_text = ''.join([idx2char[i] for i in generated_ids])
    return generated_text
How does this work? Turn our text into a sequence of token ids, feed it into the model, sample the next token, decode it and repeat until we’ve generated enough tokens.
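And a hypothetical call - the prompt, token count and temperature are just example values:
seed = "To be, or not to be"
print(predict_next_tokens(model, seed, CHAR_2_IDX, IDX_2_CHAR,
                          num_tokens=100, temperature=0.8))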
Does it work?
Well, work is a funny word.
- Does it work in the sense of being remotely useful? No.
- Does it work in the sense of predicting the next character any better than you could do with just about any other method? No.
- Does it serve as a handy way to heat your house if you put in the wrong numbers? Yes.
- Does it allow me to play with a transformer model? Yes, and that’s all that matters 🙂
I trained it on the complete works of Shakespeare. Here’s some examples:
Hamlet, beet the she hous live buls, Afteeve, Crid To she. Dout a will layitio’s be as is wists beffer the
Macbether them of thy sook fto the vombect folk. Coulst, I shall thy comperved to his uport py won stole
Einstein? Veelf Wou ame, it best gent of Trom I your the we you welm. You tart my bll live raty evame ever m
What next?
I’ve got a system to play with now, which is super cool! I can run a training pass over 20MB of text in about 5 minutes on my GPU, which means I can experiment: change the shape of the model, change the parameters and see what happens. Hopefully this means I can put some of what I intended to read in Designing ML Systems and AI Engineering into practice!