Mastering large language models – Part XII: byte-pair encoding

On our way to actually coding and training a transformer-based model, there is one last challenge that we have to master – encoding our input by splitting it in a meaningful way into subwords. Today, we will learn how to do this using an algorithm known as byte-pair encoding (BPE).

Let us first quickly recall what the problem is that subword encoding tries to solve. To a large extent, the recent success of large language models is due to the tremendous amount of data used to train them. GPT-3, for instance, has been trained on roughly 500 billion tokens (which, as we will see soon, are not exactly words) [3]. But even much smaller datasets tend to contain a large number of unique words. The WikiText-103 dataset [4], for instance, contains a bit more than 100 million words, representing a vocabulary of 267,735 unique words. With standard word-level encoding, each of these words would become a dimension in the input and output layers of our neural network. It is not hard to imagine that with hundreds of billions of tokens as input, this number would be even larger, and we clearly have a problem.

The traditional way to control the growth of the vocabulary is to introduce a minimum frequency. In other words, we only include those tokens in the vocabulary that occur more than a given number of times – the minimum frequency – in the training data, and tune this threshold to cap the vocabulary size. This, however, has the disadvantage that during training and later during inference, we will face unknown words. Of course we can reserve a special token for these unknown words, but a translator, say, would still have to know how to handle them.
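
To make this concrete, here is a minimal, hypothetical sketch of such a frequency cutoff – this is not code from this series, and the function name, the threshold and the <unk> token are purely illustrative choices.

import collections

def build_word_vocabulary(words, min_frequency=2):
    # Keep only words that occur at least min_frequency times
    counter = collections.Counter(words)
    vocab = {word for word, count in counter.items() if count >= min_frequency}
    # Reserve a special token for all other (unknown) words
    vocab.add("<unk>")
    return vocab

words = ["low", "low", "lower", "newest", "newest", "widest"]
vocab = build_word_vocabulary(words)
print([word if word in vocab else "<unk>" for word in words])
# ['low', 'low', '<unk>', 'newest', 'newest', '<unk>']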

Character-level encoding solves the problem of unknown words, but introduces a new problem – the model first has to learn words before it can start to learn relations between words. In addition, if our model is capable of learning relations within a certain range L, it does of course make a difference whether the unit in which we measure this range is a character or a word.

This is the point where subword tokenization comes into play. Here, the general idea is to build up a vocabulary that consists of subword units, not full words – but also more than just individual characters. Of course, to make this efficient, we want to include those subwords in our vocabulary that occur often, and we want to include all individual characters. If we need to encode a word that consists of known subwords, we can use those, and in the worst case, we can still go down to the character level during encoding so that we will not face unknown words.

Byte-pair encoding, which was first applied to machine translation in [1], is one way to do this. Here, we build our vocabulary in two phases. In the first phase, we go through our corpus and add every character that we find to our vocabulary. We then encode our data using this preliminary vocabulary.

In the second phase, we iterate through a process known as a merge. During each merge, we go through our data and identify the pair of tokens in the vocabulary that occurs most frequently. We then add a new token to our vocabulary that is simply the concatenation of these two existing tokens and update our existing encoding. Thus every merge adds one more token to the vocabulary. We continue this process until the vocabulary has reached the desired size.

Note that in a real implementation, the process of counting token pairs to determine their frequencies is usually done in two steps. First, we count how often each individual word occurs in the text and save this for later reference. Next, given a pair of tokens, we go through all words that contain this pair and add up the frequencies of these words, once for every occurrence of the pair within the word, to determine the frequency of the token pair.

Let us look at an example to see how this works in practice. Suppose that the text we want to encode is

low, lower, newest, widest

In the first phase, we would build an initial vocabulary that consists of all characters that occur in this text. However, we want to make sure that we respect word boundaries, so we need to be able to distinguish between characters that appear within a word and those that appear at the end of a word. We do this by appending a special marker </w> to a token that is located at the end of a word. In our case, this applies for instance to the character w, which therefore gives rise to two entries in our vocabulary – w and w</w>. With this modification, our initial vocabulary is

'o', 'r</w>', 's', 'w</w>', 'n', 'i', 'e', 'l', 'd', 'w', 't</w>'

We now go through all possible combinations of two tokens and see how often they appear in the text. Here is what we would find in our example.

Token pair    Frequency
l + o         2
o + w</w>     1
o + w         1
w + e         2
e + r</w>     1
n + e         1
e + w         1
e + s         2
s + t</w>     2
w + i         1
i + d         1
d + e         1

We now pick the pair of tokens (usually called a byte pair, even though this is strictly speaking not correct when we work with Unicode code points) that occurs most frequently. In our case, several pairs occur twice, and we pick one of them, say w and e. We now add an additional token “we” to our vocabulary and re-encode our text. With this vocabulary, our text, which previously was represented by the sequence of tokens

l, o, w</w>, l, o, w, e, r</w>, n, e, w, e, s, t</w>, w, i, d, e, s, t</w>

would now be encoded as

l, o, w</w>, l, o, we, r</w>, n, e, we, s, t</w>, w, i, d, e, s, t</w>

Note the new token we that appears at the two positions where we previously had the combination of the tokens w and e. This concludes our first merge.

We can now continue and conduct the next merge. Each merge will add one token to our vocabulary, so that controlling the number of merges allows us to create a vocabulary of the desired size. Note that each merge results in two outputs – an updated vocabulary and a rule. In our case, the first merge resulted in the rule w, e ==> we.

To encode a piece of text that was not part of our training data when running the merges, we now simply replay these rules, i.e. we start with our text, break it down into characters corresponding to the tokens in the initial vocabulary, and apply the rules that we have derived during the merges, in the order in which the merges took place. Thus to encode text, we need access to the vocabulary and to the rules.
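
Just to illustrate the idea, here is a minimal sketch of how this replay could look for a single word. This is not the implementation from BPE.py – the function name is purely illustrative – and it assumes that the rules are stored as a list of token pairs in the order in which the merges were learned.

def encode_word(word, rules):
    # Break the word down into characters, marking the end of the word
    symbols = list(word[:-1]) + [word[-1] + "</w>"]
    # Replay the merge rules in the order in which they were learned
    for first, second in rules:
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == (first, second):
                merged.append(first + second)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

# With the rule from our first merge, "newest" becomes ['n', 'e', 'we', 's', 't</w>']
print(encode_word("newest", [("w", "e")]))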

Let us now turn to implementing this in Python. The original paper [1] already contains a few code snippets that make an implementation based on them rather easy. As usual, this blog post comes with a notebook that, essentially based on the code in [1], guides you through a simple implementation using the example discussed above. Here are a few code snippets to illustrate the process.

The first step is to build a dictionary that contains the frequencies of the individual words, which we will later use to easily calculate the frequency of a byte pair. Note that the keys in this dictionary are the words in the input text, but already broken down into a sequence of tokens separated by spaces, so that we need to update them as we merge and add new tokens.

import collections

def get_word_frequencies(pre_tokenized_text):
    # Count how often each word occurs in the (pre-tokenized) input text
    counter = collections.Counter(pre_tokenized_text)
    # Keys are the words, split into space-separated characters, with </w> appended to the last one
    word_frequencies = {" ".join(word) + "</w>" : frequency for word, frequency in counter.items() if len(word) > 0}
    return word_frequencies
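
For our example text, and assuming that the pre-tokenized input is simply the list of words, this yields the expected keys:

word_frequencies = get_word_frequencies(["low", "lower", "newest", "widest"])
print(word_frequencies)
# {'l o w</w>': 1, 'l o w e r</w>': 1, 'n e w e s t</w>': 1, 'w i d e s t</w>': 1}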

As the keys in the dictionary are already broken down into sequences of tokens, we can now build our initial vocabulary based on these keys.

def build_vocabulary(word_frequencies):
    vocab = set()
    for word in word_frequencies.keys():
        # Each key is a space-separated sequence of tokens - add every token to the vocabulary
        for c in word.split():
            vocab.add(c)
    return vocab
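
Applied to the word frequencies from above, this reproduces the initial vocabulary from our example (up to ordering, as we use a set):

vocab = build_vocabulary(word_frequencies)
print(sorted(vocab))
# ['d', 'e', 'i', 'l', 'n', 'o', 'r</w>', 's', 't</w>', 'w', 'w</w>']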

Next, we need to be able to identify the pair of tokens that occurs most frequently. Again, Python's collections module can be utilized to do this – we simply go through all words, split them into symbols, iterate through all adjacent pairs, and increase the count for each pair in a dictionary.

def get_stats(word_frequencies):
    pairs = collections.defaultdict(int)
    for word, freq in word_frequencies.items():
        symbols = word.split()
        # Every adjacent pair of tokens contributes the frequency of the containing word
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs
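
Running this on our example reproduces the counts from the table above, for instance:

stats = get_stats(word_frequencies)
print(stats[("w", "e")])      # 2
print(stats[("s", "t</w>")])  # 2
print(stats[("o", "w")])      # 1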

Finally, we need a function that executes an actual merge. During each merge, we use a regular expression (more on the exact expression that we use here can be found in my notebook) to replace each occurrence of the pair by the new token.

import re

def do_merge(best_pair, word_frequencies, vocab):
    new_frequencies = dict()
    new_token = "".join(best_pair)
    # Only match the pair where both parts are complete, space-separated tokens
    pattern = r"(?<!\S)" + re.escape(" ".join(best_pair)) + r"(?!\S)"
    vocab.add(new_token)
    for word, freq in word_frequencies.items():
        # Replace every occurrence of the pair by the newly created token
        new_word = re.sub(pattern, new_token, word)
        new_frequencies[new_word] = freq
    return new_frequencies, vocab
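
To see what the look-behind and look-ahead in the pattern achieve, here is a quick illustration (not part of the actual implementation): the pair is only replaced where both parts appear as complete, space-separated tokens.

pattern = r"(?<!\S)" + re.escape("w e") + r"(?!\S)"
print(re.sub(pattern, "we", "l o w e r</w>"))      # l o we r</w>

pattern = r"(?<!\S)" + re.escape("s t") + r"(?!\S)"
print(re.sub(pattern, "st", "n e w e s t</w>"))    # unchanged - the last token is t</w>, not t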

We can now combine the functions above to run a merge. Essentially, during a merge, we need to collect the statistics to identify the most frequent pair, call our merge function to update the dictionary containing the word frequencies and append the rule that we have found to a list of rules that we will save later.

# Determine the most frequent pair (ties are broken by comparing the pairs themselves)
stats = get_stats(word_frequencies)
best_pair = max(stats, key=lambda x: (stats[x], x))
print(f"Best pair: {best_pair}")
# Merge the pair and remember the rule for later use during encoding
word_frequencies, vocab = do_merge(best_pair, word_frequencies, vocab)
rules.append(best_pair)
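
For our example, this prints Best pair: ('w', 'e') – the same merge that we performed by hand earlier. To build a complete vocabulary, we simply repeat this step until we have performed the desired number of merges. Here is a minimal sketch of such a loop; the number of merges and the stopping criterion are illustrative.

rules = []
num_merges = 10   # illustrative value - the number of merges controls the vocabulary size

word_frequencies = get_word_frequencies(["low", "lower", "newest", "widest"])
vocab = build_vocabulary(word_frequencies)

for _ in range(num_merges):
    stats = get_stats(word_frequencies)
    if not stats:
        # Every word has been merged into a single token - nothing left to do
        break
    best_pair = max(stats, key=lambda x: (stats[x], x))
    word_frequencies, vocab = do_merge(best_pair, word_frequencies, vocab)
    rules.append(best_pair)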

This code, however, is awfully slow, as it basically repeats the process of counting once with every merge. The authors of the original paper also provide a reference implementation [2] that applies some tricks like caching and the use of indices to speed up the process significantly. In my repository for this series, I have assembled an implementation that follows this reference implementation to a large extent, but simplifies a few steps (in particular the incremental updates of the frequency counts) and is therefore hopefully a bit easier to read. The code consists of the main file BPE.py and a test script test_bpe.py.
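
To give at least a flavour of the indexing idea, here is a simplified sketch – this is not the code from the reference implementation or from BPE.py, and build_pair_index is an illustrative name. The idea is to remember, for every pair, the set of words in which it occurs, so that after a merge only those words need to be revisited instead of the entire corpus.

def build_pair_index(word_frequencies):
    # For every adjacent pair of tokens, remember the words in which it occurs
    index = collections.defaultdict(set)
    for word in word_frequencies:
        symbols = word.split()
        for i in range(len(symbols) - 1):
            index[symbols[i], symbols[i + 1]].add(word)
    return index

# After merging best_pair, only the words in index[best_pair] have to be
# re-examined to update the pair statistics incrementally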

The test script can also be used to compare the output of our implementation with the reference implementation [2]. Let us quickly do this using a short snippet from “War and Peace” that I have added to my repository.

#
# Clone the reference implementation
#
git clone https://github.com/rsennrich/subword-nmt.git
cd subword-nmt
#
# Get files from my repository
#
wget https://raw.githubusercontent.com/christianb93/MLLM/main/bpe/BPE.py
wget https://raw.githubusercontent.com/christianb93/MLLM/main/bpe/test_bpe.py
wget https://raw.githubusercontent.com/christianb93/MLLM/main/bpe/test.in
wget https://raw.githubusercontent.com/christianb93/MLLM/main/bpe/test.rules
#
# Run unit tests
# 
python3 test_bpe.py
#
# Run reference implementation
#
cat test.in | python3 subword_nmt/learn_bpe.py -s=50 > rules.ref
#
# Run our implementation
#
python3 test_bpe.py --infile=test.in --outfile=rules.dat
#
# Diff outputs
#
diff rules.dat rules.ref
diff rules.ref test.rules

The only difference that the first diff should show you is a header line (containing the version number) that the reference implementation adds and that we do not include in our output; all remaining parts of the output files, which contain the actual rules, should be identical. The second diff verifies the reference file that is part of the repository and is used for the unit tests.

Over time, several other subword tokenization methods have been developed. Two popular methods are Google's WordPiece model ([6], [7]) and the unigram tokenizer [8]. While WordPiece is very similar to BPE and mainly differs in how the pair to be merged next is selected, the unigram tokenizer starts with a very large vocabulary, containing all characters and the most commonly found substrings, and then iteratively removes items from the vocabulary based on a statistical model.

With BPE, we now have a tokenization method at our disposal which will allow us to decouple vocabulary size from the size of the training data, so that we can train a model even on large datasets without having to increase the model size. In the next post, we will put everything together and train a transformer-based model on the WikiText dataset.

References:

[1] R. Sennrich et al., Neural Machine Translation of Rare Words with Subword Units
[2] https://github.com/rsennrich/subword-nmt
[3] T. Brown et al., Language Models are Few-Shot Learners
[4] https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/
[5] K. Bostrom, G. Durrett, Byte Pair Encoding is Suboptimal for Language Model Pretraining
[6] Y. Wu et al., Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation
[7] M. Schuster, K. Nakajima, Japanese and Korean voice search
[8] T. Kudo, Subword Regularization: Improving Neural Network Translation Models with Multiple Subword Candidates
