Understanding the GPT-2 Source Code Part 2
Hi! This is the next post in my series on understanding GPT-2’s source code, and hopefully learning a thing or two along the way. Part 1 can be found here. If you find any problems or unclear spots, or have any feedback, please don’t hesitate to mention them in the comments!
In this part, I will go through encoder.py and encode.py.
What is encoding?
One of the most important things to understand is that a model cannot work with raw text directly. Before training, the machine has no conception of what an “apple” or a “pear” is, or how they might relate to each other. In fact, being handed the words “apple” or “pear” is outright confusing for it. It would much rather see numbers, like 1 and 2, representing them. And that is what encoding does! It converts text into numbers!
How does OpenAI do it?
First, let us look at encode.py. Its contents are as follows.
#!/usr/bin/env python3
# Usage:
#  PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/output.npz
#  PYTHONPATH=src ./train --dataset /path/to/output.npz
import argparse
import numpy as np
import encoder
from load_dataset import load_dataset
parser = argparse.ArgumentParser(
    description='Pre-encode text files into tokenized training set.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_name', metavar='MODEL', type=str, default='117M', help='Pretrained model name')
parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate files with <|endoftext|> separator into chunks of this minimum size')
parser.add_argument('in_text', metavar='PATH', type=str, help='Input file, directory, or glob pattern (utf-8 text).')
parser.add_argument('out_npz', metavar='OUT.npz', type=str, help='Output file path')
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    print('Reading files')
    chunks = load_dataset(enc, args.in_text, args.combine)
    print('Writing', args.out_npz)
    np.savez_compressed(args.out_npz, *chunks)
if __name__ == '__main__':
    main()
encode.py takes 4 arguments, as defined in the parser. I’m not sure why they used argparse here instead of the fire library, so if anyone knows, please tell me!
The 4 arguments are
- model_name: Currently, only 117M and 345M are available, which I understand.
- combine: The help text says “Concatenate files with <|endoftext|> separator into chunks of this minimum size”, which I don’t understand yet. So I plan to dig further into the source code to find out what this parameter does specifically.
- in_text: The input, which can be a UTF-8 text file, a directory, or a glob pattern.
- out_npz: The output file, in .npz format.
The next line of interest is
enc = encoder.get_encoder(args.model_name)
So, let us look at encoder.py to see what the get_encoder function does.
def get_encoder(model_name):
    with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
        encoder = json.load(f)
    with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
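The function does 3 things.
- Read encoder.json
- Read vocab.bpe, split it at the newlines, drop the first and last lines (the #version header and the empty string left after the final newline), and turn each remaining line into a pair of symbols
- Return an Encoder instance initialized with the contents of encoder.json and the merges from vocab.bpe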
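To see what the [1:-1] slice removes, here is a small sketch that runs get_encoder's parsing line on a few lines in vocab.bpe's format. It assumes the file ends with a newline, which is what the slice implies. The last step borrows the bpe_ranks line from Encoder.__init__, which comes up later in this post.

```python
# The first few lines of vocab.bpe as shown in this post (the real file is much longer).
bpe_data = "#version: 0.2\nĠ t\nĠ a\nh e\ni n\nr e\no n\nĠt he\n"

lines = bpe_data.split('\n')
print(repr(lines[0]))   # '#version: 0.2'  -> the header, dropped by [1:]
print(repr(lines[-1]))  # ''  -> empty string after the final newline, dropped by [:-1]

# Same expression as in get_encoder: each line becomes a (first, second) pair.
bpe_merges = [tuple(merge_str.split()) for merge_str in lines[1:-1]]
print(bpe_merges[:3])  # [('Ġ', 't'), ('Ġ', 'a'), ('h', 'e')]

# Encoder.__init__ turns the list into ranks: an earlier line gets a lower rank.
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
print(bpe_ranks[('Ġt', 'he')])  # 6
```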
encoder.json and vocab.bpe were the same for both the 117M and the 345M models, so the model name doesn’t matter much here. When you open encoder.json, you see
{"!": 0, "\"": 1, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12,
and so on until
"\u0120Collider": 50253, "\u0120informants": 50254, "\u0120gazed": 50255, "<|endoftext|>": 50256}
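So encoder.json is clearly a dictionary mapping each token (a word, a piece of a word, or a symbol) to the number it gets encoded as. \u0120 is the character Ġ, which GPT-2’s encoder uses in place of a space, so “\u0120gazed” is “ gazed” with a leading space.
For vocab.bpe, however, this is what I saw when I opened it: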
#version: 0.2
Ġ t
Ġ a
h e
i n
r e
o n
Ġt he
and I was not sure what that was about.
What does vocab.bpe do?
It turns out this is something called byte pair encoding (BPE). According to Wikipedia, BPE is a compression technique. To use the example from there, given the string
aaabdaaabac
the pair aa occurs most often, so we can replace it with an unused byte, Z, compressing the string to
ZabdZabac
Z=aa
Next, ab repeats, so it can be replaced with Y:
ZYdZYac
Y=ab
Z=aa
and so on, until no pair of bytes occurs more than once. However, judging from the file, the format seems to be at least slightly different. A character such as “h”, which is surely not unused, appears to stand for the single letter “e”, rather than a new symbol standing for a pair of characters as the algorithm seems to suggest.
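The Wikipedia-style procedure fits in a few lines of Python. This is a minimal illustration of BPE as a compression scheme, not OpenAI’s or Sennrich’s code, and compress and its spare_symbols argument are made-up names. One behaviour to flag: after the first step, the pairs Za and ab are tied at two occurrences each. The sketch breaks the tie by taking the pair that appears first in the text, so its second merge is Za rather than Wikipedia’s ab, although it ends at the same XdXac.

```python
from collections import Counter

def compress(text, spare_symbols="ZYXWV"):
    """Repeatedly replace the most frequent adjacent pair with an unused symbol."""
    table = {}
    for symbol in spare_symbols:
        # Count every adjacent pair of characters (overlapping pairs included).
        pairs = Counter(zip(text, text[1:]))
        if not pairs:
            break
        (a, b), count = pairs.most_common(1)[0]  # ties go to the pair seen first
        if count < 2:
            break  # no pair repeats, so there is nothing left to compress
        table[symbol] = a + b
        text = text.replace(a + b, symbol)  # non-overlapping, left to right
    return text, table

print(compress("aaabdaaabac", "Z"))  # ('ZabdZabac', {'Z': 'aa'})
print(compress("aaabdaaabac"))       # ('XdXac', {'Z': 'aa', 'Y': 'Za', 'X': 'Yb'})
```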
Since the code that produced vocab.bpe isn’t included in the GPT-2 repository, I decided to search the web! The first thing I found was this paper. TL;DR: it is about how byte pair encoding can be used to handle new, unseen words by splitting them into known subwords.
To take an example from the paper, the German word “Abwasser|behandlungs|anlange” (‘sewage water treatment plant’) can be segmented by byte pair encoding into 3 subwords. If instead the whole word were encoded into a single vector from the start, there would be no way to tell what it is about upon first encountering it.
Still, I was puzzled about how a sequence like
Ġ t
Ġ a
h e
i n
r e
o n
Ġt he
makes sense. So I went looking for an implementation on the web, which can be found here. Thanks, Rico Sennrich! After looking into the code, I realized that Ġ does not correspond to t. Instead, each line is a byte pair: “Ġ t” means that the symbols Ġ and t get merged into the single symbol “Ġt”.
At the risk of being a bit boring, I’ll explain how I got there. If you are not interested, please skip to the next section!
Basically, I searched for every mention of the output file and found two, both in the learn_bpe function. The first was
outfile.write('#version: 0.2\n')
and the second was
outfile.write('{0} {1}\n'.format(*most_frequent))
The matching version line strongly suggests that OpenAI used this very Python script, or something derived from it! most_frequent holds the most frequent pair at each merge step, and each pair is written to the output file in turn. So vocab.bpe lists the merges in the order they were learned.
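To make the learning side concrete, here is a toy merge learner that writes its output in the same format as the two outfile.write calls above. This is a simplified sketch, not Sennrich’s learn_bpe, which has much more machinery. learn_merges is a hypothetical name, and the word counts are made up, with Ġ marking a word that followed a space. With these counts, the first three merges happen to match the “Ġ t”, “h e”, and “Ġt he” lines seen in vocab.bpe.

```python
import collections
import io

def learn_merges(word_freqs, num_merges, outfile):
    """Toy BPE learner that writes merges in vocab.bpe's format (not Sennrich's actual code)."""
    # Each word starts out as a tuple of single characters.
    vocab = {tuple(word): freq for word, freq in word_freqs.items()}
    outfile.write('#version: 0.2\n')
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by how often each word occurs.
        stats = collections.Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                stats[pair] += freq
        if not stats:
            break  # every word is already a single symbol
        most_frequent = max(stats, key=stats.get)
        outfile.write('{0} {1}\n'.format(*most_frequent))
        # Replace every occurrence of the chosen pair with one merged symbol.
        first, second = most_frequent
        new_vocab = {}
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == first and symbols[i + 1] == second:
                    merged.append(first + second)
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] = freq
        vocab = new_vocab

out = io.StringIO()
learn_merges({'Ġthe': 5, 'Ġto': 3, 'he': 2}, 3, out)
print(out.getvalue())
# #version: 0.2
# Ġ t
# h e
# Ġt he
```

Because each line records the single most frequent pair at that step, reading the file top to bottom replays the merges in order. That ordering is what the encoder later relies on.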
Encoding dataset
The next line of interest in encode.py is
chunks = load_dataset(enc, args.in_text, args.combine)
enc here is the Encoder instance returned previously. Now, let us look at the load_dataset function. The first part is
def load_dataset(enc, path, combine):
    paths = []
    if os.path.isfile(path):
        # Simple file
        paths.append(path)
    elif os.path.isdir(path):
        # Directory
        for (dirpath, _, fnames) in os.walk(path):
            for fname in fnames:
                paths.append(os.path.join(dirpath, fname))
    else:
        # Assume glob
        paths = glob.glob(path)
This collects the path of a single text file, or of every file inside a directory, into a list called paths. os.walk walks a directory tree recursively, yielding the directory path, its subdirectories, and its file names for each directory it visits. A glob is a path pattern containing wildcard characters. For example, *.txt can match any file ending in .txt, such as a.txt or adfdj.txt, because * is a wildcard that matches any sequence of characters. glob.glob returns every path matching such a pattern. If I’m wrong, please tell me in the comments!
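The difference between the directory branch and the glob branch is easy to check on a throwaway directory. The sketch below copies the path-collection logic into a hypothetical collect_paths helper and builds a temporary directory with made-up file names.

```python
import glob
import os
import tempfile

def collect_paths(path):
    """Hypothetical helper mirroring the path-collection part of load_dataset."""
    if os.path.isfile(path):
        return [path]
    if os.path.isdir(path):
        paths = []
        for dirpath, _, fnames in os.walk(path):  # recurses into subdirectories
            for fname in fnames:
                paths.append(os.path.join(dirpath, fname))
        return paths
    return glob.glob(path)  # anything else is treated as a glob pattern

with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, 'sub'))
    for name in ['a.txt', 'adfdj.txt', 'notes.md', os.path.join('sub', 'b.txt')]:
        with open(os.path.join(root, name), 'w') as f:
            f.write('hello')
    walk_count = len(collect_paths(root))
    glob_names = sorted(os.path.basename(p) for p in collect_paths(os.path.join(root, '*.txt')))

print(walk_count)  # 4: every file, including sub/b.txt
print(glob_names)  # ['a.txt', 'adfdj.txt']: * does not descend into sub/
```

Note that a directory picks up every file regardless of extension, including notes.md here, so it is worth keeping the input directory clean.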
The rest of the function is:
    token_chunks = []
    raw_text = ''
    for path in tqdm.tqdm(paths):
        if path.endswith('.npz'):
            # Pre-encoded
            with np.load(path) as npz:
                for item in npz.files:
                    token_chunks.append(npz[item])
        else:
            # Plain text
            with open(path, 'r') as fp:
                raw_text += fp.read()
            if len(raw_text) >= combine:
                tokens = np.stack(enc.encode(raw_text))
                token_chunks.append(tokens)
                raw_text = ''
            else:
                raw_text += '<|endoftext|>'
    if raw_text:
        tokens = np.stack(enc.encode(raw_text))
        token_chunks.append(tokens)
    return token_chunks
The first portion,
if path.endswith('.npz'):
    # Pre-encoded
    with np.load(path) as npz:
        for item in npz.files:
            token_chunks.append(npz[item])
handles files that are already encoded (.npz files). It loads every array stored in the file and appends each one to token_chunks as it is. Pre-encoded data is therefore simply combined with the newly encoded text; nothing gets overwritten.
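else:
    # Plain text
    with open(path, 'r') as fp:
        raw_text += fp.read()
    if len(raw_text) >= combine:
        tokens = np.stack(enc.encode(raw_text))
        token_chunks.append(tokens)
        raw_text = ''
    else:
        raw_text += '<|endoftext|>'
Here, I finally understood what the combine parameter means. Text from consecutive files is accumulated in raw_text. While the accumulated text is shorter than combine characters, it is not encoded yet. Instead, <|endoftext|> is appended as a separator and the next file’s text is added on. Once raw_text reaches at least combine characters, it is encoded as one chunk and raw_text is reset. So short files are not ignored, just grouped together, and whatever is left over at the end is encoded as a final chunk by the if raw_text: block. Now, let us look at what enc.encode(raw_text) does by looking at the Encoder.encode method.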
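Before moving on to encode, the grouping behaviour is easier to see in a small sketch. combine_texts is a hypothetical function that copies the plain-text branch of load_dataset but returns the raw strings instead of encoding them. The documents and the combine value of 30 are made up to keep the output readable.

```python
def combine_texts(texts, combine):
    """Group documents the way load_dataset does, returning raw strings instead of token arrays."""
    chunks = []
    raw_text = ''
    for text in texts:
        raw_text += text
        if len(raw_text) >= combine:
            chunks.append(raw_text)  # load_dataset would encode this chunk here
            raw_text = ''
        else:
            raw_text += '<|endoftext|>'  # separator only while the chunk is still short
    if raw_text:
        chunks.append(raw_text)  # leftover text is kept, not dropped
    return chunks

print(combine_texts(['short one.', 'another short one.', 'tiny'], combine=30))
# ['short one.<|endoftext|>another short one.', 'tiny<|endoftext|>']
```

Notice that the first chunk does not end with <|endoftext|>, because the file that pushed it past combine gets no separator.

Back to enc.encode, the method is: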
def encode(self, text):
    bpe_tokens = []
    for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
        bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
    return bpe_tokens
TL;DR
The basic steps taking place here are:
- Split the text into pieces (words, numbers, runs of punctuation, and whitespace) with the regular expression self.pat, treating each piece as a token
- Encode each token into UTF-8 bytes, map every byte to a printable character with byte_encoder, and join the results into a single string, again named token
- Split that string into byte pairs (subwords) with self.bpe, look up each subword’s number in self.encoder, and extend the bpe_tokens list with those numbers
For those who are curious, let us go through the function from the beginning. The first line of interest is the following.
for token in re.findall(self.pat, text):
where self.pat is
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
An explanation of this regular expression
re.compile precompiles the pattern so it does not have to be parsed again on every search. But let us now look at the string itself. This string is called a regular expression, and it describes patterns in text. For example, if the pattern is just “a” and we find all of its matches in the text “I had a lot of pie”, there are only two: inside “had”, and the word “a”.
The first thing to notice about this string is that there are a lot of “|” characters in it. In a regular expression, “|” means or: the pattern matches if any one of the alternatives between the bars matches, and the alternatives are tried from left to right.
The question mark means 0 or 1 repetitions of the preceding character. So “ ?” matches an optional single space, which lets a word or number keep the space in front of it (for example, “ world”).
According to here,
\p{L} matches a single code point in the category "letter".
\p{N} matches any kind of numeric character in any script.
However, it is important to note that \p{L} and \p{N} are not supported by Python’s built-in re module; they are only available in the third-party regex library. That is why the OpenAI team starts off by writing
import regex as re
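And since “+” means one or more, \p{L}+ matches a run of letters (a word) and \p{N}+ matches a run of numeric characters (a number).
\s matches any Unicode whitespace character, including \n and the like, while \S matches any non-whitespace character. Inside square brackets, ^ means “not”, so [^\s\p{L}\p{N}]+ matches a run of characters that are neither whitespace, letters, nor numbers, i.e. punctuation and symbols. (?!\S) is a negative lookahead: the pattern before it only matches if it is not followed by a non-whitespace character. So \s+(?!\S) grabs a run of whitespace but leaves the last space in front of a word, which the word then picks up through “ ?”.
So, overall, the expression splits text into words and numbers (each with at most one leading space), runs of punctuation, and whitespace. It also splits off common English contractions, so “they're” becomes “they” and “'re”.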
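Running the pattern on a short sentence shows these pieces directly. The sketch uses the regex package as encoder.py does. In case that package isn’t installed, it falls back to an assumed approximation with the built-in re module: [^\W\d_] stands in for \p{L}, \d for \p{N}, and [^\s\w] for the punctuation class. That fallback is my own substitution. It agrees with the real pattern on plain ASCII text like this, but not on every script. It also differs on underscores.

```python
try:
    import regex as re  # third-party package, as in encoder.py; supports \p{L} and \p{N}
    pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
except ImportError:
    import re
    # Assumed fallback: an approximation of the pattern above for the built-in re module.
    pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?[^\W\d_]+| ?\d+| ?[^\s\w]+|\s+(?!\S)|\s+""")

tokens = re.findall(pat, "Hello world, they're   here 2019!")
print(tokens)
# ['Hello', ' world', ',', ' they', "'re", '  ', ' here', ' 2019', '!']

separator_pieces = re.findall(pat, "<|endoftext|>")
print(separator_pieces)  # ['<|', 'endoftext', '|>']
```

Two details stand out. Of the three spaces before “here”, two become their own token and the third sticks to the word. Also, by this pattern alone, the literal string <|endoftext|> is split into three pieces rather than kept whole. So it would not map straight to entry 50256 of encoder.json, although I haven’t checked which IDs it ends up as.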
Explanation for byte_encoder
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
Here, the token is encoded into UTF-8 bytes. Iterating over a bytes object in the for b in loop gives integers from 0 to 255, the same numbers the function ord gives for each byte.
For example, ord("a".encode('utf-8')) gives 97, while
[b for b in "a".encode('utf-8')]
gives [97]. byte_encoder then maps each of these byte values to a Unicode character. Printable bytes map to themselves, while the remaining bytes (whitespace, control characters, and a few others) are mapped to unused characters starting at code point 2⁸ = 256. This is how a space becomes Ġ (U+0120). self.byte_encoder is initialized by the function bytes_to_unicode, which is the following
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
lru_cache, if I understand correctly, is a decorator from the functools module that caches a function’s results, so calling the function again with the same arguments skips the needless recomputation. But since this function is only called once, when an Encoder is created, I couldn’t think of much reason for the decorator. If anyone can explain, please tell me!
As I think the code is quite self-explanatory, I’ll skip most of it, except to mention that chr() is the opposite of ord(): give it a number and it returns the corresponding character. So the function returns a dictionary whose keys are the byte values 0 to 255 (2⁸ − 1) and whose values are the corresponding characters.
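A quick way to see the effect is to run the same logic and inspect a few entries. The function body below is copied from above without the lru_cache decorator. The sample string “ café” is my own choice; it shows a space and a two-byte UTF-8 character going through the same mapping encode uses.

```python
def bytes_to_unicode():
    # Same logic as encoder.py's version, minus the lru_cache decorator.
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:          # whitespace, control and a few other bytes
            bs.append(b)
            cs.append(2**8 + n)  # give them unused code points from 256 upward
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
print(len(byte_encoder), len(set(byte_encoder.values())))  # 256 256: one distinct character per byte
print(byte_encoder[ord('a')], byte_encoder[ord(' ')], byte_encoder[ord('\n')])  # a Ġ Ċ

# What encode does to one token: UTF-8 bytes -> printable characters.
token = ''.join(byte_encoder[b] for b in ' café'.encode('utf-8'))
print(token)  # ĠcafÃ©
```

The space becomes Ġ, and é, which is two bytes in UTF-8, becomes the two characters Ã and ©. Because the mapping is one-to-one, it can be inverted when decoding.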
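I couldn’t make much sense of the docstring, but I think part of the reason for the shift past 2⁸ is given by the following line.
And avoids mapping to whitespace/control characters the bpe code barfs on.
Frankly, I do not understand it completely. One concrete reason is visible in the code, though: bpe joins the subwords of a token with spaces, and encode splits them apart again with split(' '), so a raw space inside a token would break that step. The overall effect is that every possible byte gets a printable, non-whitespace character, so any text can be converted into a format the BPE step can handle.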
An Explanation for Byte Pair Encoding Tokenization
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
First, let us look at the self.bpe function. It starts off as
if token in self.cache:
    return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
get_pairs returns the set of every pair of adjacent symbols in word.
while True:
    bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
    if bigram not in self.bpe_ranks:
        break
    first, second = bigram
    new_word = []
    i = 0
    while i < len(word):
        try:
            j = word.index(first, i)
            new_word.extend(word[i:j])
            i = j
        except:
            new_word.extend(word[i:])
            break
        if word[i] == first and i < len(word)-1 and word[i+1] == second:
            new_word.append(first+second)
            i += 2
        else:
            new_word.append(word[i])
            i += 1
Here, the first thing to note is self.bpe_ranks. This is defined by
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
in the __init__ method, where bpe_merges is the list of pairs read from vocab.bpe. Since vocab.bpe lists the merges in the order they were learned, the most frequent byte pair gets the lowest rank and the least frequent gets the highest. So, when we look at the following line,
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
it becomes evident that bigram is the pair in the current word with the lowest rank, i.e. the pair that was merged earliest when the vocabulary was built. float('inf') is the default for pairs that are not in bpe_ranks, i.e. not in vocab.bpe. Infinity never beats a real rank, so such a pair is only picked when none of the word’s pairs can be merged, and the check right after the min catches exactly that case.
try:
    j = word.index(first, i)
    new_word.extend(word[i:j])
    i = j
except:
    new_word.extend(word[i:])
    break
In word.index(first, i), the first argument is the value searched for in the tuple word, and the second is the index to start searching from. If first occurs at or after position i, j is the position of that occurrence. Everything from i up to (but not including) j is copied to new_word unchanged, and i jumps to j.
If first does not occur anywhere from i onward, index raises a ValueError. The except clause catches it, copies all the remaining symbols from i to the end into new_word, and breaks out of the inner loop.
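Since the try/except step is the trickiest part, here it is isolated in a hypothetical helper, copy_until, with one call for each branch. Unlike the original, it catches ValueError explicitly rather than using a bare except. Instead of breaking, it returns i = len(word) to signal that the scan has finished.

```python
def copy_until(word, first, i):
    """Hypothetical helper isolating the try/except step of Encoder.bpe."""
    new_word = []
    try:
        j = word.index(first, i)    # search for `first` starting at position i
        new_word.extend(word[i:j])  # copy the symbols before it unchanged
        i = j
    except ValueError:              # `first` does not occur at or after i
        new_word.extend(word[i:])
        i = len(word)
    return new_word, i

word = ('Ġ', 't', 'h', 'e', 'r', 'e')
print(copy_until(word, 'r', 0))  # (['Ġ', 't', 'h', 'e'], 4): found at index 4
print(copy_until(word, 'h', 3))  # (['e', 'r', 'e'], 6): no 'h' from index 3 on
```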
Now, more importantly,
if word[i] == first and i < len(word)-1 and word[i+1] == second:
    new_word.append(first+second)
    i += 2
else:
    new_word.append(word[i])
    i += 1
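Here, if the symbol at i is first and the next symbol is second, the two are merged into a single symbol first+second and added to new_word, and i skips past both. Otherwise, the symbol is copied as it is and i moves on by one. And finally,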
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
    break
else:
    pairs = get_pairs(word)
Here, word is replaced by new_word, and the loop continues until word has been merged into a single symbol. On the next pass, pairs is recomputed. If we remember what byte pair encoding is all about, this means merged two-character symbols can now pair up with each other or with single characters. In later passes, even longer subwords can be built, and at every step the pair that was learned earliest is applied first, because of the following line!
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
Of course, it is also quite possible that word never shrinks to a single symbol. However, there is one other break statement in the while True loop, and that is
if bigram not in self.bpe_ranks:
    break
So once none of the remaining pairs in word appears in vocab.bpe, meaning no more merges can be applied, the loop terminates.
Then,
word = ' '.join(word)
self.cache[token] = word
return word
The remaining symbols are joined with spaces, stored in the cache, and returned.
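Putting the pieces together, here is a standalone sketch of the merge loop, applied to the first few lines of vocab.bpe shown earlier. The real file has far more merges, so real outputs will differ. To keep it short, it drops the cache and replaces the index()/try/except search with a plain left-to-right scan. The scan produces the same merges but is easier to follow.

```python
def get_pairs(word):
    """Set of adjacent symbol pairs in a tuple of symbols."""
    return set(zip(word, word[1:]))

def bpe(token, bpe_ranks):
    """Apply learned merges to one byte-encoded token, earliest-learned merge first."""
    word = tuple(token)
    pairs = get_pairs(word)
    if not pairs:
        return token
    while True:
        # Pick the pair learned earliest; pairs not in the table rank as infinity.
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))
        if bigram not in bpe_ranks:
            break  # no mergeable pair left
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)  # merge the pair into one symbol
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)
    return ' '.join(word)

# Ranks from the first lines of vocab.bpe shown in this post (header line skipped).
ranks = {('Ġ', 't'): 0, ('Ġ', 'a'): 1, ('h', 'e'): 2, ('i', 'n'): 3,
         ('r', 'e'): 4, ('o', 'n'): 5, ('Ġt', 'he'): 6}
print(bpe('Ġthe', ranks))    # Ġthe
print(bpe('Ġthere', ranks))  # Ġthe re
print(bpe('Ġinto', ranks))   # Ġ in t o
```

“Ġthe” collapses to one symbol through Ġ t, then h e, then Ġt he. “Ġthere” stops at two symbols because “Ġthe re” is not in this tiny table. encode would then look up each space-separated piece in encoder.json.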
Back to the encode function,
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
each subword produced by self.bpe is converted into its number by self.encoder, the dictionary loaded from encoder.json.
Sorry if the above explanation became a bit complicated. I’m not sure I fully understood all of the code here, especially the try/except block, so if anyone can call me out on my mistakes or lack of clarity, please do.
Save to output
Back in load_dataset.py,
if raw_text:
    tokens = np.stack(enc.encode(raw_text))
    token_chunks.append(tokens)
return token_chunks
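any text left over after the last file is encoded as well. Each chunk of token IDs is turned into a NumPy array with np.stack and appended to token_chunks, and the list is returned. Finally, in encode.py,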
np.savez_compressed(args.out_npz, *chunks)
each chunk is saved as a separate array in a single compressed .npz file.
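Here is a small round trip through np.savez_compressed and np.load showing how the chunks are named inside the .npz file. The reading side is what the pre-encoded branch of load_dataset iterates over. The two chunks are made-up arrays standing in for real token IDs.

```python
import os
import tempfile

import numpy as np

# Two made-up chunks of token IDs, standing in for the output of load_dataset.
chunks = [np.array([1, 2, 3]), np.array([4, 5])]

with tempfile.TemporaryDirectory() as tmp:
    out_npz = os.path.join(tmp, 'output.npz')
    # Arrays passed positionally are stored under the names arr_0, arr_1, ...
    np.savez_compressed(out_npz, *chunks)
    # Reading back mirrors the pre-encoded (.npz) branch of load_dataset.
    with np.load(out_npz) as npz:
        names = list(npz.files)
        loaded = [npz[item] for item in npz.files]

print(names)  # ['arr_0', 'arr_1']
print(all(np.array_equal(a, b) for a, b in zip(chunks, loaded)))  # True
```

Because the chunks keep their own arrays rather than being concatenated, a pre-encoded .npz file fed back into encode.py contributes its chunks unchanged.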
Things we need to be careful about with the data
As we can see from the encoding process, <|endoftext|> is only added automatically between files, and only while the accumulated text is still shorter than combine. It is never inserted inside a file, nor after a file that pushes a chunk past combine. So when fine-tuning on your own custom dataset, it is best to put the end token at the end of each text yourself, especially if the texts are short!
Next
In the next story, I’ll try to go into sample.py and model.py! If you are interested, please read it here!