Mastering large language models – Part VI: sampling

Today, we will take a closer look at the process of using a trained LSTM or RNN to actually generate new content, i.e. to predict words. To set the scene, recall that the objective on which we have trained our network is to model the probability

P(w | w_1, \dots, w_n)

for each word in the vocabulary. More precisely, assume that our vocabulary has length V with elements v_0, …, v_{V-1}. Then the model is trained to predict for each i the probability p_i that the next word is v_i:

p_i = P(w = v_i | w_1, \dots, w_n)

As the network operates in time steps, it produces one output vector per time step, from which we obtain the probability distribution for the next word given the input seen so far. The output of our model therefore has shape (L, V), where L is the length of the input sequence, but we are only interested in the last output. In addition, recall that in practice, the softmax layer is usually not part of the model but is contained in the loss function. Thus, to obtain the vector p of length V holding these probabilities, we have to proceed as follows.

First, we take the sentence that we want to complete, the so-called prompt. We then tokenize this sentence and encode it as a tensor of shape L, where L is the number of tokens in the prompt. We feed this input tensor x into the model and obtain the output (along with the values of the hidden layer, which we ignore for the moment). We then take the last element of the output and apply a softmax to obtain our probability distribution p. The corresponding code looks similar to the following snippet.

# x contains encoded prompt
f, _ = model(x)
p = torch.softmax(f[-1], dim = 0)

We have already seen this snippet in our toy model that was trained to complete a sequence of numbers. In this toy model, we chose the most straightforward approach to determine the next token from this probability distribution – take the index with the highest probability, i.e.

idx = torch.argmax(p).item()

This is again an index in our vocabulary. We can now look up the corresponding token in the vocabulary and append this token to our prompt. At this point, we have successfully extended the prompt by one generated token. We can now repeat the process to obtain a second token and so forth, until we have reached a specified target length. Note that in all steps but the first one, it is more efficient to feed the previously obtained hidden state back into the model, so that the model does not have to go through the entire sequence again. If we do this, however, we need to make sure that we feed only the last (just sampled) token as input, as the information on the previous part of the sequence is already encoded in the hidden state. Thus a complete function to sample could look as follows.

# Tokenize and encode prompt 
input_ids = [vocab[t] for t in tokenize(prompt)]
hidden = None
#
# Sample and append indices 
#
while (len(input_ids) < length):
    x = torch.tensor(input_ids, dtype = torch.long)
    x = x.to(device)
    #
    # Feed input ids into model
    #
    if hidden is None:
        f, hidden = model(x)
    else:
        f, hidden = model(x[-1].unsqueeze(dim = 0), hidden)
    #
    # f has shape (L, V) or (1,V) 
    # Take last element and apply softmax
    #
    p = torch.softmax(f[-1], dim = 0)
    #
    # Sample
    #
    idx = torch.argmax(p).item()
    #
    # and append   
    #
    input_ids.append(idx)

This will produce a list of indices which we still need to convert back into a string using the vocabulary.
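
Assuming that the vocabulary is, as in the snippet above, a dictionary mapping tokens to indices, a minimal sketch of this conversion could look like the following lines (the name inv_vocab is of course my own choice, and joining with blanks is a simplification that only makes sense for word-level tokens).

#
# Build the inverse mapping index -> token and decode
# (assumes that vocab is a dict mapping token -> index)
#
inv_vocab = {index: token for token, index in vocab.items()}
text = " ".join(inv_vocab[i] for i in input_ids)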

The sampling method that we have applied here is sometimes called greedy sampling, because it always greedily selects the token with the highest probability. This is easy to implement (and fast), but has one major disadvantage – it is fully deterministic. Therefore the model easily gets stuck in loops during sampling and starts to repeat itself. This is a problem in particular if we use a short prompt, for instance “. ” to represent the start of a sentence. What we would actually want is a method that returns a reasonable sentence, but with some built-in randomness, so that we do not always get the same sentence.

One way to do this is to actually draw a real sample from the probability distribution given by the vector p. PyTorch comes with a few helper classes to sample from various types of distributions, among them the categorical distribution, which is nothing but a multinomial distribution with a single trial. So instead of taking the argmax to determine the next index, we can use the following lines

idx = torch.distributions.categorical.Categorical(probs = p).sample()
idx = idx.item()

to draw an actual sample. Note that we first create a distribution object and then call its sample method to perform the sampling. As the result is a tensor, we then use the item method to obtain a number that we can use as an index into the vocabulary.
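
As a side remark, the same draw could also be expressed with torch.multinomial, which samples indices with probabilities proportional to the given weights, so that the two lines above collapse into one.

# Equivalent draw using torch.multinomial
idx = torch.multinomial(p, num_samples = 1).item()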

This sampling method is often applied with an additional parameter called the temperature. To understand this, let us discuss the impact of scaling the model output by some factor before applying the softmax. The softmax function is of course not linear, and due to the exponential function in the numerator, scaling by a large number will have a higher impact on those dimensions where the model output is already large. Thus scaling by a large number will amplify existing spikes in the probability distribution. In the limit where the scaling factor tends to infinity, only the highest spike survives and our sampling becomes deterministic, so that we effectively recover greedy sampling. Conversely, if the scaling factor is very small, the spikes will be softened, and eventually, in the limit where the scaling factor goes to zero, the resulting distribution will be the uniform distribution.

Traditionally, the parameter which is actually adjusted is the inverse of the scaling factor and is called the temperature. So the updated code including temperature looks like this.

# x contains encoded prompt
f, _ = model(x)
p = torch.softmax(f[-1] / temperature, dim = 0)

The discussion above shows that a small temperature corresponds to a large scaling factor, so that our sampling becomes more and more deterministic, while a higher temperature makes the output more random (this is why the parameter is called the temperature, as this is the behaviour we also observe in statistical mechanics). Thus low temperatures are helpful if we want the model to stick as closely as possible to the training data, while higher temperatures make the model more creative. It is instructive to plot the probability distributions that different temperatures produce; I have done this in this notebook.
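
If you do not want to open the notebook, the following little toy example (the logit values are made up purely for illustration) already shows the effect: for a low temperature, almost the entire probability mass ends up on the most likely token, while for a high temperature, the distribution approaches the uniform distribution.

import torch

# Made-up model outputs (logits) for a vocabulary of four tokens
f = torch.tensor([2.0, 1.0, 0.5, 0.1])
for temperature in [0.1, 0.5, 1.0, 2.0, 10.0]:
    p = torch.softmax(f / temperature, dim = 0)
    print(f"T = {temperature:5.1f}: {[round(x, 3) for x in p.tolist()]}")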

An extension of this sampling approach is known as top-k sampling (this method appears in Hierarchical Neural Story Generation by Fan et al., however, I am not sure whether this is really the first time it was proposed). In top-k sampling, we first pick the k indices with the highest probability weights, where k is a parameter, rescale these weights to form a probability distribution again and sample from this modified distribution. The idea behind this is to avoid the tail of the distribution, i.e. to avoid accidentally sampling very uncommon continuations, while still being less deterministic than greedy sampling. With PyTorch, this can be implemented as follows.

#
# Sort and remove all indices after the k-th index
#
_, indices = torch.sort(p, descending = True)
keep = indices[:k_val]
#
# Sample over the items that are left
#
_p = p[keep]
idx = torch.distributions.categorical.Categorical(probs = _p).sample()
idx = idx.item()
idx = keep[idx].item()

Here we first use torch.sort to sort the probability distribution vector p in descending order, and then pick the k_val largest values, where k_val is the value of the parameter k. So at this point, our tensor keep contains the indices that we want to sample from. We then collect the corresponding probabilities into a new vector _p and build a new categorical distribution from it (whose weights PyTorch will normalize automatically), from which we sample. As the output is the position within keep, we still have to look up the actual vocabulary index in this tensor.
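
As an aside, the sorting and selection above could also be written more compactly using torch.topk, which returns the k largest values together with their indices in a single call, so that the computation of _p and keep collapses into one line; the rest of the sampling then proceeds exactly as before.

# Select the k most likely tokens and their indices in one call
_p, keep = torch.topk(p, k_val)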

Nucleus sampling or top-p sampling was proposed by Holtzman et al. in The curious case of neural text degeneration and extends this approach. Instead of using the k indices with the highest probability weights for a fixed value of k, we use the smallest set of indices whose total probability mass reaches a certain threshold p_val, typically 0.9 or 0.95, i.e. we disregard the tail carrying the last 5% or 10% of the probability mass. The code is very similar, except that we first sum up the probabilities to determine the cut-off k and then proceed as for top-k sampling.

items, indices = torch.sort(p, descending = True)
items = iter(items.tolist())
_sum = 0
_k = 0
while _sum <= p_val:
    _sum, _k = _sum + next(items), _k + 1
keep = indices[:_k]
_p = p[keep]
idx = torch.distributions.categorical.Categorical(probs = _p).sample()
idx = idx.item()
idx = keep[idx].item()
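
The explicit loop that determines the cut-off _k is easy to read, but the same value can also be obtained in a vectorized way using torch.cumsum; here is a short sketch of this alternative, using the same variable names as above.

#
# Vectorized cut-off: smallest _k such that the cumulative
# probability of the first _k items exceeds p_val
#
items, indices = torch.sort(p, descending = True)
cum = torch.cumsum(items, dim = 0)
_k = int((cum <= p_val).sum().item()) + 1
keep = indices[:_k]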

Of course, these methods can both be combined with a temperature parameter to control how creative the model can become. In practice, it is worth playing with different sampling methods and different values for the respective parameters (temperature, k, p) to see which combination gives the best results, depending on the use case.
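
To make this experimentation a bit more convenient, it can help to wrap the different strategies into a single helper function. Here is a sketch of what such a helper could look like (the name sample_next and its signature are my own choice, not something we have used so far); it combines temperature scaling with an optional top-k cut-off.

def sample_next(f_last, temperature = 1.0, k = None):
    # f_last is the raw model output for the last position, shape (V,)
    p = torch.softmax(f_last / temperature, dim = 0)
    if k is not None:
        # Restrict to the k most likely tokens - Categorical normalizes the weights
        top_p, top_idx = torch.topk(p, k)
        idx = torch.distributions.categorical.Categorical(probs = top_p).sample()
        return top_idx[idx].item()
    return torch.distributions.categorical.Categorical(probs = p).sample().item()

# Usage inside the sampling loop from above
idx = sample_next(f[-1], temperature = 0.7, k = 5)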

We note that there is a family of sampling methods that we have not explained, known as beam search, where instead of single candidate tokens, one tracks and scores candidates for entire sequences. Beam search is extensively described in the literature, for instance in Dive into Deep Learning, section 10.8, or chapter 10 of Speech and Language Processing. Beam search can also be combined with sampling to obtain stochastic beam search. Finding the best method for sampling from large language models continues to be an active area of research, and I encourage you to take a look at the papers linked above, which provide a good overview of the various approaches to realize sampling and to measure the quality of the outcome.
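
To at least give a flavour of how this works, here is a heavily simplified sketch of beam search for our model (my own toy version, not an implementation taken from the references above). It keeps a fixed number of candidate sequences, scores each of them by its accumulated log-probability and, for the sake of simplicity, re-feeds the full candidate sequence into the model at every step instead of caching hidden states.

def beam_search(model, input_ids, length, beam_width = 5):
    # Each beam is a pair (token ids, accumulated log-probability)
    beams = [(list(input_ids), 0.0)]
    while len(beams[0][0]) < length:
        candidates = []
        for ids, score in beams:
            x = torch.tensor(ids, dtype = torch.long).to(device)
            f, _ = model(x)
            log_p = torch.log_softmax(f[-1], dim = 0)
            top_logp, top_idx = torch.topk(log_p, beam_width)
            for lp, i in zip(top_logp.tolist(), top_idx.tolist()):
                candidates.append((ids + [i], score + lp))
        # Keep only the beam_width highest-scoring candidates
        beams = sorted(candidates, key = lambda c: c[1], reverse = True)[:beam_width]
    return beams[0][0]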

This is a good point in time to look back at what we have discussed in the previous posts. We have learned how to tokenize a text, how to build a vocabulary and how to encode a text. We have then seen how various types of RNNs are implemented and trained, and finally we have looked at different methods to sample from the trained models. In the next post, we will put all of this together and train a model on Tolstoy's novel “War and Peace”.
