RPC — A New Way to Build Language Models

Relevant Precedence Compression

Pedro Magalhães

Note: This article assumes a foundational understanding of language models and their training processes. Detailed explanations of these concepts are not provided.

The primary goal of this work is to create a very small language model (SLM), with fewer than 5 million parameters, that can generate reasonable text completions. Traditional methods use a neural network to predict the next-token distribution. RPC instead uses a neural network to compress the prompt into a vector, and this vector is then used to search a vector database for the next token.

This concept is similar to the encoders used in RAG (retrieval-augmented generation), where a query embedding is used to search for relevant contexts. Here, the encoder produces an embedding of the prompt. We then search the database for the most similar stored prompt embeddings and retrieve the token that followed them.
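The retrieval idea can be made concrete with a minimal pure-Python sketch; this is not the author's code. It builds a toy "database" of (prompt embedding, next token) pairs, and prediction returns the token stored with the nearest embedding. The 2-D vectors and tokens are invented for illustration. In RPC, the vectors come from the trained encoder and the search is done by a vector index.

```python
import math

# Toy database: each entry pairs a prompt embedding with the token that followed that prompt.
# The vectors and tokens are invented for illustration only.
database = [
    ([0.9, 0.1], "red"),
    ([0.1, 0.9], "five"),
    ([0.5, 0.5], "dog"),
]


def predict_next_token(query, db):
    """Return the next token stored with the embedding closest (Euclidean) to `query`."""
    _, best_tok = min(db, key=lambda entry: math.dist(query, entry[0]))
    return best_tok


print(predict_next_token([0.2, 0.8], database))  # five (nearest entry is [0.1, 0.9])
```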

Below is a visual explanation of the training and inference pipelines:

Training Data

I won't detail the training data in depth, but it is a collection of several well-known datasets: SQuAD, CoQA, Amazon Topical Chat, Facebook Empathetic Dialogues, and the test split of AllenAI SODA. These datasets were chosen to provide a mixture of question-answer pairs and conversational content, exposing the model to a variety of language patterns and topics. Altogether, they amount to 176 MB of text.

Training Pipeline

The specific encoder architecture matters little, as long as it produces an embedding that encodes the prompt well enough to predict the next token. Here's the simple architecture used in this experiment; the encoder alone has 4.3 million parameters:

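```python
import keras_nlp
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Embedding, Input, LayerNormalization, Lambda

# `tokenizer` is an already-loaded Hugging Face tokenizer, and `Attention`
# is the custom layer defined in the next snippet.
input_size = 512  # context length in tokens
embed_size = 128  # embedding width
vocab_size = len(tokenizer.get_vocab().keys()) + 1

# Encoder: token embeddings + learned positions, followed by 4 residual attention blocks
inputs_enc = Input(shape=(input_size,), dtype=tf.int32)
emb_layer = Embedding(vocab_size, embed_size)
pos_layer = keras_nlp.layers.PositionEmbedding(input_size)

x = LayerNormalization()(emb_layer(inputs_enc))
pos = pos_layer(x)  # one learned vector per position, used by Attention

b = 4
for _ in range(b):
    # Scaled residual update: x <- x + LN(Attention(x, pos)) / sqrt(b)
    x += b**-0.5 * LayerNormalization()(Attention()(x, pos))

encoder = keras.Model(inputs=inputs_enc, outputs=x)

# Decoder (only needed during training)
inputs = Input(shape=(input_size,), dtype=tf.int32)
x = encoder(inputs)

b = 4
for _ in range(b):
    x1 = Dense(embed_size, activation="gelu")(x)
    x1 = Dense(embed_size, activation="gelu")(x1)
    x += b**-0.5 * LayerNormalization()(x1)

# LM head tied to the input embedding matrix
lm_head = Lambda(lambda x: tf.nn.softmax(tf.matmul(x, emb_layer.embeddings, transpose_b=True), axis=-1))
x = lm_head(x)

model = keras.Model(inputs=inputs, outputs=x)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=tokenizer.pad_token_id),
    optimizer=keras.optimizers.AdamW(learning_rate=0.001),
    metrics=["accuracy", keras_nlp.metrics.Perplexity(mask_token_id=tokenizer.pad_token_id)],
)

encoder.summary()
```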
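As a rough check on the 4.3 million figure, the encoder's parameters can be counted by hand from the snippet above. It has a token embedding table, a position table (one vector per position), one LayerNormalization (scale and offset), and four blocks. Each block holds Q/K/V matrices plus a LayerNormalization. The article does not state the tokenizer's vocabulary size. The sketch below assumes about 32,000 tokens purely for illustration, which lands close to the quoted count.

```python
def encoder_param_count(vocab_size, input_size=512, embed_size=128, blocks=4):
    """Count trainable parameters of the encoder described above."""
    token_emb = vocab_size * embed_size      # Embedding(vocab_size, embed_size)
    pos_emb = input_size * embed_size        # PositionEmbedding(input_size)
    layer_norm = 2 * embed_size              # gamma + beta
    attention = 3 * embed_size * embed_size  # Q, K and V kernels (no biases)
    return token_emb + pos_emb + layer_norm + blocks * (attention + layer_norm)


# Assumption: a 32,000-token vocabulary, plus the extra slot added in the code.
print(encoder_param_count(32_000 + 1))  # 4359552
```

The token embedding table dominates the total, so the vocabulary size is the main lever on model size.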

The Attention layer implements single-head attention (no split into multiple heads) with learned relative positional encodings. Here's the code:

```python
import math

import tensorflow as tf
from tensorflow import keras


class Attention(keras.layers.Layer):
    """Single-head causal self-attention with learned relative position encodings."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.embed_size = input_shape[-1]
        seq_len = input_shape[-2]
        # Causal mask: 0 on and below the diagonal, -inf above it
        self.mask = tf.where(tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0) == 1.0, 0.0, float("-inf"))
        # Per-row shifts: absolute -> relative position scores, and back
        self.range_do = -tf.range(seq_len) - 1
        self.range_undo = tf.range(seq_len) + 1
        self.Q = self.add_weight(name="kernelQ", shape=(self.embed_size, self.embed_size), initializer="uniform", trainable=True)
        self.K = self.add_weight(name="kernelK", shape=(self.embed_size, self.embed_size), initializer="uniform", trainable=True)
        self.V = self.add_weight(name="kernelV", shape=(self.embed_size, self.embed_size), initializer="uniform", trainable=True)
        super().build(input_shape)

    def roll_embeddings(self, tensor, shift_values):
        # out[b, i, j] = tensor[b, i, (j + shift_values[i]) % T]
        # tf.shape is used so the roll also works when the static batch size is unknown.
        size = tf.shape(tensor)[-1]
        cols = tf.range(size)
        new_indices = (cols[tf.newaxis, :] + shift_values[:, tf.newaxis]) % size
        new_indices = tf.broadcast_to(new_indices, tf.shape(tensor))
        return tf.gather(tensor, new_indices, batch_dims=2)

    def call(self, x, pos):
        q = x @ self.Q
        k = x @ self.K
        v = x @ self.V
        atti = tf.matmul(q, k, transpose_b=True)          # content scores
        attp = tf.matmul(q, pos, transpose_b=True)        # query-vs-position scores
        attp = self.roll_embeddings(attp, self.range_do)  # absolute -> relative positions
        att = atti + attp
        att = tf.nn.softmax((att / math.sqrt(self.embed_size)) + self.mask, axis=-1)
        outi = att @ v
        attp = self.roll_embeddings(att, self.range_undo)  # relative -> absolute positions
        outp = attp @ pos
        return outi + outp
```
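The index arithmetic of roll_embeddings can be replayed in plain Python to see why the two rolls turn absolute position scores into relative ones. After row i is shifted by -(i + 1), the score for query i and key j is read from position column (j - i - 1) mod T. That value depends only on the offset i - j. The sketch is an illustrative check of the indexing (the helper name is hypothetical), not part of the model.

```python
def rolled_column(i, j, seq_len):
    """Position column read for query i, key j after shifting row i by -(i + 1)."""
    return (j - i - 1) % seq_len


seq_len = 6
for i in range(seq_len):
    # Only keys j <= i survive the causal mask
    print(i, [rolled_column(i, j, seq_len) for j in range(i + 1)])
```

For every query, the current token (j = i) uses the last position vector, the previous token uses the one before it, and so on. Shifting by +(i + 1), as range_undo does, inverts the mapping.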

To train the model, we use tokenized texts, as in any typical language model training process. In addition, each token gets a sample weight. If the token is a noun, number, adjective, adverb, or proper noun and has already appeared earlier in the input prompt, its weight is 1.0; otherwise, it is 0.6. This isn't strictly necessary, but it helps the model learn to copy tokens from earlier in the context with less training time. That improves accuracy when answering factual questions and tracking conversation topics.
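The weighting rule can be sketched as a hypothetical reimplementation; it is not the author's get_train_session. It assumes each token already comes with a Universal POS tag (NOUN, NUM, ADJ, ADV, PROPN) from some tagger. It also reads "appeared in the input prompt" as "occurred earlier in the same sequence".

```python
# Assumption: Universal POS tags produced by an external tagger.
KEY_TAGS = {"NOUN", "NUM", "ADJ", "ADV", "PROPN"}


def token_weights(tokens, pos_tags, key_weight=1.0, other_weight=0.6):
    """Weight key_weight for key-POS tokens already seen earlier in the sequence, else other_weight."""
    seen = set()
    weights = []
    for tok, tag in zip(tokens, pos_tags):
        weights.append(key_weight if tag in KEY_TAGS and tok in seen else other_weight)
        seen.add(tok)
    return weights


tokens = ["the", "dog", "is", "red", ".", "color", "?", "red"]
tags = ["DET", "NOUN", "AUX", "ADJ", "PUNCT", "NOUN", "PUNCT", "ADJ"]
print(token_weights(tokens, tags))  # only the second "red" gets 1.0
```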

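```python
# get_train_session is the author's data helper (not shown here); it returns token ids `x`
# and per-token sample weights `w` for a batch of sequences of length input_size + 1.
for i in range(70):
    x, w = get_train_session(4096 * 8, input_size + 1)
    if 10 < i < 25:
        # Sessions 11-24: keep weight 1.0 for copyable tokens, drop the rest to 0.05
        w = tf.where(w < 0.9, 0.05, 1.0)
    # Next-token prediction: targets are the inputs shifted left by one position
    model.fit(x=x[:, :-1], y=x[:, 1:], shuffle=True, epochs=1, batch_size=16, sample_weight=w[:, 1:])
    model.save("model_slm.hdf5")  # checkpoint after each session
```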
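The training loop has two details that are easy to miss. First, targets are the inputs shifted one position to the left. Second, sessions 11 to 24 (0-indexed) sharpen the weights: tokens with weight 1.0 keep it, while the rest drop to 0.05. The sketch below mirrors both steps on plain lists; the helper names are hypothetical.

```python
def shift_for_next_token(seq, weights):
    """Split one sequence of length n + 1 into n inputs, n targets and n target weights."""
    return seq[:-1], seq[1:], weights[1:]


def session_weights(session, weights):
    """Mirror tf.where(w < 0.9, 0.05, 1.0) for sessions 11-24; identity otherwise."""
    if 10 < session < 25:
        return [0.05 if w < 0.9 else 1.0 for w in weights]
    return weights


seq = [1, 42, 7, 42]
w = [0.6, 0.6, 0.6, 1.0]
inputs, targets, target_w = shift_for_next_token(seq, session_weights(12, w))
print(inputs, targets, target_w)  # [1, 42, 7] [42, 7, 42] [0.05, 0.05, 1.0]
```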

Inference Pipeline

The first step in using the SLM is to extract the encoder from the full (encoder + decoder) ensemble:

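```python
# Only the encoder is needed for inference, so the model is loaded without
# compiling (no loss or metric objects need to be restored).
model = keras.models.load_model(
    "model_slm.hdf5",
    custom_objects={"Attention": Attention},
    compile=False,
    safe_mode=False,  # allows deserializing the Lambda LM head
)
# Extract encoder: layers[0] is the Input layer, layers[1] is the encoder sub-model
encoder = model.layers[1]
```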

Next, we create a function that vectorizes a list of tokenized texts. It is used both to build the vector database and during inference:

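```python
def vectorize_texts(all_texts, batch_size=128, pad=tokenizer.pad_token_id):
    """Encode lists of token ids (each at most input_size long); return one embedding per token."""
    vects = []
    for start in range(0, len(all_texts), batch_size):
        texts = all_texts[start:start + batch_size]
        # Right-pad to input_size; the causal encoder keeps padding from
        # affecting the embeddings of the real tokens.
        toks = [text + [pad] * (input_size - len(text)) for text in texts]
        toks = tf.constant(toks, shape=(len(toks), input_size))
        vect = encoder(toks)
        for v, t in zip(vect, texts):
            vects.append(v[:len(t), :])  # drop padding positions
    return tf.concat(vects, axis=0)
```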
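Right-padding is safe because the encoder is causal: the embedding at position i only attends to positions 0..i. Pad tokens appended after a text therefore cannot change the vectors that are kept. The toy sketch below illustrates that property only; it is not the encoder itself. It uses scalar "embeddings" and softmax attention with a causal mask, with no learned weights.

```python
import math


def causal_attention(xs):
    """Toy causal self-attention over scalars: output i is a softmax-weighted mean of xs[0..i]."""
    out = []
    for i in range(len(xs)):
        scores = [xs[i] * xs[j] for j in range(i + 1)]  # keys after i are masked out
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        out.append(sum(e * xs[j] for j, e in enumerate(exps)) / total)
    return out


text = [0.3, -1.2, 0.8]
padded = text + [0.0] * 5  # right-pad, as vectorize_texts does
print(causal_attention(padded)[:len(text)] == causal_attention(text))  # True
```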

Now we encode every text in the dataset, storing one embedding per prompt prefix together with the token that follows it:

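```python
import numpy as np

# `data` is the tokenized dataset: a list of token-id lists.
all_toks = []       # all_toks[n] is the token that followed prompt n
prompt_embeds = []  # one embedding per prompt prefix

batch_size = 128
batch = []

for j, text in enumerate(data):
    text_size = min(len(text), input_size + 1)
    all_toks += text[1:text_size]       # next-token targets
    batch.append(text[:text_size - 1])  # prompts; position i encodes the prefix text[:i + 1]

    if len(batch) >= batch_size:
        prompt_embeds.append(vectorize_texts(batch))
        batch = []
        print(j)  # progress

if batch:  # also encode the final partial batch
    prompt_embeds.append(vectorize_texts(batch))

prompt_embeds = np.vstack(prompt_embeds).astype("float32")  # (num_prompts, embed_size)
```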
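Each text of n tokens (capped at input_size + 1) contributes n - 1 database entries, one per prefix. The embedding of position i summarizes text[:i + 1], and all_toks stores text[i + 1]. The sketch below lists those pairs explicitly for a toy token sequence, with a deliberately small input_size, to show how the two arrays line up. prefix_pairs is a hypothetical helper.

```python
def prefix_pairs(text, input_size):
    """List (prefix, next_token) pairs the way prompt_embeds and all_toks are aligned."""
    text_size = min(len(text), input_size + 1)
    prompts = text[:text_size - 1]
    next_toks = text[1:text_size]
    return [(prompts[:i + 1], next_toks[i]) for i in range(len(prompts))]


for prefix, nxt in prefix_pairs([1, 15, 27, 9, 4], input_size=3):
    print(prefix, "->", nxt)
```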

After obtaining all the embeddings in prompt_embeds and the corresponding next tokens in all_toks, we create the vector database. This example uses the faiss library to build a flat (exact, brute-force) index. An HNSW index gives faster searches at the cost of a longer build time and approximate results.

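```python
import faiss
import numpy as np

# Exact L2 search; faiss.IndexHNSWFlat(embed_size, 32) is a faster, approximate alternative
index = faiss.IndexFlat(embed_size)
index.add(np.ascontiguousarray(prompt_embeds, dtype="float32"))
```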
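The exact search that IndexFlat performs can be sketched in plain Python. The query is compared with every stored vector, and the k smallest distances are returned with their row ids. Like faiss's L2 flat index, the sketch reports squared Euclidean distances. The function name and data are illustrative only.

```python
def flat_search(database, query, k):
    """Exact k-nearest-neighbour search; returns (squared L2 distances, row ids)."""
    dists = [
        (sum((a - b) ** 2 for a, b in zip(vec, query)), row)
        for row, vec in enumerate(database)
    ]
    dists.sort()
    top = dists[:k]
    return [d for d, _ in top], [row for _, row in top]


db = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
D, I = flat_search(db, [0.9, 0.1], k=2)
print(D, I)  # nearest rows are 1, then 0
```

This scan costs time proportional to the database size for every query, which is why an HNSW index becomes attractive as the database grows.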

Finally, we can use the encoder and the vector index to generate text. Here are the two example prompts used (not cherry-picked):

```python
text1 = """<s>Peter: Hello there!\n"""

text2 = """<s>The dog is red and has five legs.
User: What color is the dog?
Assistant: red
User: How many legs does the dog have?
Assistant:"""

k = 10       # number of nearest neighbours to retrieve
temp = 0.01  # softmax temperature over distances
text = text2
size = 1     # number of tokens to generate

enc_text = tokenizer.encode(text, add_special_tokens=False)
text = tokenizer.decode(enc_text)
print(text, end="")

for t in range(size):
    # The last token's embedding summarizes the prompt (truncated to input_size tokens)
    xq = vectorize_texts([enc_text[-input_size:]])[-1]
    xq = np.array(xq, dtype="float32").reshape((1, embed_size))
    D, I = index.search(xq, k)          # distances and ids of the k nearest prompts
    toks = [all_toks[i] for i in I[0]]  # tokens that followed those prompts
    # Sample a neighbour with probability softmax(-D / temp)
    c = int(tf.random.categorical([-D[0] / temp], num_samples=1)[0][0])
    tok = toks[c]

    enc_text += [tok]
    new_text = tokenizer.decode(enc_text)
    print(new_text[len(text):], end="")
    text = new_text
```
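The sampling step turns the k distances into a probability distribution using a temperature-scaled softmax of the negated distances. Closer neighbours are therefore exponentially more likely to be chosen. Lowering temp concentrates probability on the nearest neighbours, and raising it flattens the distribution. A stdlib-only sketch of the same idea follows; random.choices stands in for tf.random.categorical, and the helper names are hypothetical.

```python
import math
import random


def softmax_neg_dist(distances, temp):
    """Probabilities proportional to exp(-d / temp), computed stably."""
    logits = [-d / temp for d in distances]
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(tokens, distances, temp, rng=random):
    """Pick one retrieved token, favouring neighbours with smaller distances."""
    probs = softmax_neg_dist(distances, temp)
    return rng.choices(tokens, weights=probs, k=1)[0]


dists = [0.50, 0.52, 0.90]
print([round(p, 3) for p in softmax_neg_dist(dists, temp=0.01)])  # sharp
print([round(p, 3) for p in softmax_neg_dist(dists, temp=1.0)])   # flatter
print(sample_token(["five", "5", "four"], dists, temp=0.01, rng=random.Random(0)))
```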

Results

Text 1 example output:

<s> Peter: Hello there!
Mia: Hello there, do you follow baseball or MLB much?
Peter: I do, although I haven't followed it much the past couple of years.
Mia: I am the opposite, I have been following more closely these last few, especially this last season as my team almost made the world series.
Peter: Their arrow means they cover everything from the old days.
Mia: Nice, yeah I like it. Did you know that the cubs were the first team to win back to back World Series.
Peter: Women were not allowed to wear a baseball uniform as to be able to play for their teams if the need arises
Mia: Maybe it is his

Text 2 example output:

<s>The dog is red and has five legs.
User: What color is the dog?
Assistant: red
User: How many legs does the dog have?
Assistant: five

Some Notes

The full code is available on GitHub.
I think the idea of using encoders and vector search for classification problems in general is interesting and deserves further exploration.

I encourage the community to experiment with this approach and try to scale it up. It would be valuable to see how it performs when scaled in terms of data volume, parameter count, and the size of the vector database. If you have any questions or want to know more about this small experiment, please leave a comment or message me on LinkedIn.
