Mastering large language models – Part X: Transformer blocks

Today, we will start to look at the transformer architecture in greater detail. At the end of the day, a transformer is built out of individual layers, the so-called transformer blocks, stacked on top of each other, and our objective for today is to understand how these blocks are implemented.

Learning about transformer architectures can be confusing. If you read the original paper Attention is all you need by Vaswani et al. (and I highly recommend actually taking a look at the paper), you will, already on page 3, stumble upon a rather complicated diagram which can be a bit intimidating if all of this is new to you. The paper talks about encoders and decoders, a language that we already understand thanks to a previous post. On the other hand, you will hear the statement that popular models like GPT-2 are in fact decoder-only models, without having a clear definition at hand of what this actually means. To make things worse, decoder-only models are built out of PyTorch modules which, confusingly enough, are called torch.nn.TransformerEncoder. So before getting into too many details, let us first try to clean up that mess a bit.

The most general transformer models, like the one described in the original paper or popular models like T5, are in fact encoder-decoder models. The general architecture is not so different from the encoder-decoder models that we have already seen in the context of RNNs. There is a first network, the encoder, which translates a source sentence into a sequence of embeddings. Then, there is a second network, the decoder, that uses the output of the encoder as a context – typically called the memory in transformer-based architectures – to generate the target sequence. A crucial difference, though, is that the decoder reads the output of the encoder via an attention layer which gives the decoder network access to all time steps of the encoder at the same time. So very high-level, the picture is as follows.
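
To make this flow a bit more tangible, here is a small sketch in PyTorch of how the encoder output, the memory, is passed into the decoder. The dimensions are made up for illustration, no masks are applied and the weights are random, so this only demonstrates the shapes involved.

import torch
#
# Model dimension and sequence lengths - chosen arbitrarily
#
d_model = 16
src = torch.randn(5, d_model)   # source sequence of length 5
tgt = torch.randn(3, d_model)   # target sequence of length 3
#
# An encoder stack and a decoder stack, each built from transformer blocks
#
encoder = torch.nn.TransformerEncoder(
    torch.nn.TransformerEncoderLayer(d_model = d_model, nhead = 4),
    num_layers = 2)
decoder = torch.nn.TransformerDecoder(
    torch.nn.TransformerDecoderLayer(d_model = d_model, nhead = 4),
    num_layers = 2)
#
# The encoder turns the source into the memory ...
#
memory = encoder(src)
#
# ... which the decoder attends to while processing the target
#
out = decoder(tgt, memory)
print(out.shape)    # torch.Size([3, 16])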

Let us now dig a bit deeper. In fact, both the encoder and the decoder are built out of layers called transformer blocks. For the encoder, every transformer block has one input and one output. For the decoder, every transformer block has two inputs and one output. One input reads, via an attention mechanism, data from the encoder (more precisely, keys and values are taken from the encoder, while queries are taken from the decoder). The second input is taken from the previous transformer block of the decoder. So on the next level of detail, the diagram looks like this.

This explains the first (and most important) architectural difference between an encoder transformer block (torch.nn.TransformerEncoderLayer in PyTorch) and a decoder transformer block (torch.nn.TransformerDecoderLayer) – a decoder block has two inputs, an encoder block only one.
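
In code, this difference simply shows up in the number of inputs that the forward method of each block expects. Here is a tiny sketch with random data and arbitrary dimensions.

#
# An encoder block is called with a single input ...
#
enc_block = torch.nn.TransformerEncoderLayer(d_model = 16, nhead = 4)
h = enc_block(torch.randn(5, 16))
#
# ... while a decoder block takes the target sequence plus the memory
#
dec_block = torch.nn.TransformerDecoderLayer(d_model = 16, nhead = 4)
out = dec_block(torch.randn(3, 16), h)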

The second difference is the attention mask. Attention is actually applied at all inputs of a transformer block. For an encoder, masking is typically not done, so that the encoder can look at all tokens in the input. For the decoder, masking is also usually not done for the attention layer connecting the encoder and the decoder, so that the decoder has access to all time steps of the encoder output. For the second attention layer of a decoder, however, which is connected to the decoder input or to the output of the previous block, a causal self-attention mask is usually applied, i.e. the decoder is not allowed to peek ahead, as discussed in the last post. This difference, though, is not hardcoded into the PyTorch classes, as the attention mask in PyTorch is a parameter.
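
In PyTorch, such a causal mask is an additive float mask with minus infinity above the diagonal (there is also the helper torch.nn.Transformer.generate_square_subsequent_mask to build it). Here is a small sketch, again with made-up dimensions, of how it would be passed to a decoder block via the tgt_mask parameter, while the cross-attention to the memory remains unmasked.

#
# Build a causal attention mask: entries above the diagonal are -inf,
# so that a token cannot attend to later positions
#
tgt_len = 5
causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal = 1)
#
# Apply it to the self-attention of a decoder block via tgt_mask
#
dec_block = torch.nn.TransformerDecoderLayer(d_model = 16, nhead = 4)
memory = torch.randn(7, 16)
out = dec_block(torch.randn(tgt_len, 16), memory, tgt_mask = causal_mask)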

Finally, there are decoder-only transformers, and this is the class of transformers on which we will focus in the next few posts. These are transformers built entirely from blocks with a single input (i.e. encoder blocks in the PyTorch terminology), so there is no separate encoder. However, the attention mask applied to the input is the causal mask otherwise used in a decoder, so that we can apply teacher forcing to train such a model. Thus a decoder-only transformer receives a sequence of length L as input, encoded as a matrix of shape (L, D) where D is the embedding dimension, and produces an output of the same shape. If we add an embedding layer and an output layer, we can use such a model for language generation very much in the same manner as an RNN, and this is the use case we will study first.
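
To see how these pieces could fit together, here is a minimal sketch of such a decoder-only model. The class name and all hyperparameters are made up for illustration, and note that this sketch does not yet contain positional embeddings, a point we will come back to at the end of this post.

#
# A minimal sketch of a decoder-only model: an embedding layer, a stack of
# encoder blocks (in the PyTorch sense) with a causal mask, and an output layer
#
class DecoderOnlyModel(torch.nn.Module):

    def __init__(self, vocab_size, d_model = 64, nhead = 4, num_layers = 2):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, d_model)
        block = torch.nn.TransformerEncoderLayer(d_model = d_model, nhead = nhead)
        self.blocks = torch.nn.TransformerEncoder(block, num_layers = num_layers)
        self.out = torch.nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x has shape (L,) and contains token indices
        L = x.shape[0]
        # causal mask so that position i only attends to positions up to i
        mask = torch.triu(torch.full((L, L), float('-inf')), diagonal = 1)
        h = self.embedding(x)              # shape (L, d_model)
        h = self.blocks(h, mask = mask)    # shape (L, d_model)
        return self.out(h)                 # shape (L, vocab_size)

model = DecoderOnlyModel(vocab_size = 100)
logits = model(torch.randint(0, 100, (10,)))
print(logits.shape)    # torch.Size([10, 100])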

So let us talk about transformer blocks (more specifically, transformer blocks with only one input, i.e. encoder blocks in the PyTorch terminology) in more detail. Instead of jumping directly to the paper, let us take a different approach and look at the source code of the corresponding PyTorch module. The key lines in the forward method (for the default setting norm_first = False) are

 x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
 x = self.norm2(x + self._ff_block(x))

Let us see what this does. First, we feed the input through a self-attention block. Note, however, that it is on us to specify a causal attention mask, as the model will not do this automatically.

Next, we apply a residual connection, i.e. we take the output of the self-attention and add the original input to it. This helps to avoid vanishing gradients, as it allows a part of the gradient to flow directly back to the previous layer during backpropagation. Finally, we apply a layer normalization to the output, i.e. to the sum.

We then repeat the same pattern. This time, however, we run the result of the previous step through a feed-forward network before we again add the input to establish a residual connection and finally apply a normalization. Thus our full flow looks as follows.

So we repeat the pattern

\textrm{LayerNorm}(x + \textrm{Sublayer}(x))

twice, where first the sublayer is a self-attention layer and then the sublayer is a feed-forward network. This is, in fact, exactly the expression that we find in section 3.1 of the original paper.
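
To make this even more explicit, here is how the pattern could be spelled out by hand with basic PyTorch modules. This is only a sketch, it leaves out the dropout layers which the real torch.nn.TransformerEncoderLayer adds (more on this below) and uses freshly initialized weights, but the structure is the same.

#
# The pattern LayerNorm(x + Sublayer(x)), spelled out by hand and without dropout
#
d_model, nhead, d_ff = 16, 4, 32
self_attn = torch.nn.MultiheadAttention(embed_dim = d_model, num_heads = nhead)
ff = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_ff),
    torch.nn.ReLU(),
    torch.nn.Linear(d_ff, d_model))
norm1 = torch.nn.LayerNorm(d_model)
norm2 = torch.nn.LayerNorm(d_model)
x = torch.randn(5, d_model)
#
# First sublayer: self-attention, followed by residual connection and norm
#
x = norm1(x + self_attn(x, x, x)[0])
#
# Second sublayer: feed-forward network, again with residual connection and norm
#
x = norm2(x + ff(x))
print(x.shape)    # torch.Size([5, 16])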

Before we proceed, let us talk about dimensions. When working with transformers, it is common to present the input in the format (L, B, D), where L is the sequence length, B is the batch size and D is the dimension of the model. Each sublayer produces an output of the same shape, and so does the layer norm, so that the entire transformer block again produces an output of that shape.
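
This is also what the PyTorch modules expect by default (batch_first = False). A quick shape check:

#
# A batch of 2 sequences of length 5 with model dimension 16
#
block = torch.nn.TransformerEncoderLayer(d_model = 16, nhead = 4)
X = torch.randn(5, 2, 16)    # (L, B, D)
print(block(X).shape)        # torch.Size([5, 2, 16])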

In addition, a more careful look at the code shows that at the output of each sublayer, there is a dropout layer to avoid overfitting. As usual, the dropout is deactivated during inference and evaluation, but active during training.
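
We can easily verify this: in training mode, two forward passes with the same input give different results, as the dropout masks are sampled randomly, while in evaluation mode the output is deterministic.

block = torch.nn.TransformerEncoderLayer(d_model = 16, nhead = 4)
X = torch.randn(5, 16)
#
# In training mode, dropout is active, so two passes differ
#
block.train()
print(torch.allclose(block(X), block(X)))    # False
#
# In eval mode, dropout is switched off and the output is deterministic
#
block.eval()
print(torch.allclose(block(X), block(X)))    # True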

So far, so good, but we are still missing an important point in our overall approach. To spot this, let us go ahead and quickly try out a transformer block in PyTorch, then permute the tokens in the input, i.e. change the order of the input vectors, and run the new input through the transformer block again. Here is a short piece of code doing this (you can also have a look at the notebook that I prepared for this post).

# Model dimension
D = 8
# Length of encoder input
L = 4
#
# Create random input
#
X = torch.randn(L, D)
#
# and feed through an encoder block
#
block = torch.nn.TransformerEncoderLayer(d_model = D, nhead = 4)
block.eval()
Y = block(X).detach()
#
# Now permute the input, recompute
#
Xp = X.flip(dims = [0]).detach()
Yp = block(Xp).detach()
#
# Verify that Yp is simply the permutation of Y
#
print(f"{torch.allclose(Yp, Y.flip(dims = [0]))}")

We find that permuting the input to a transformer simply results in the same permutation applied to the output. This is not surprising after a closer look at the formula for the attention, but it does not really match our intuition of a good encoding of a sentence. After all, the two sentences

The dog chases the cat

and

The cat chases the dog

are permutations of each other, but have completely different meanings, so we would expect their encodings to differ by more than just a permutation. The problem becomes even more apparent if we now simulate the first step of processing that encoder output in a decoder, i.e. passing it through the attention layer connecting encoder and decoder. Recall that this attention layer takes keys and values from the encoder output, but the queries from the decoder input. Let us simulate this.

# length of target sequence, i.e. decoder input
T = 3 
queries = torch.randn(T, D)
attn = torch.nn.MultiheadAttention(embed_dim = D, num_heads = 4)
#
# Put into eval mode to avoid dropout
#
attn.eval()
#
# Values and keys are both the encoder output
#
out = attn(queries, Y, Y)[0].detach()
outp = attn(queries, Yp, Yp)[0].detach()
#
# Compare
#
print(f"{torch.allclose(out, outp)}")

This is even worse – the output is exactly the same, even after permuting the encoder input. That means that a transformer trying to translate the two sentences above would most likely produce the same target sentence, which we can also verify directly by using PyTorch's transformer module instead of its individual building blocks.

transformer = torch.nn.Transformer(d_model = D, nhead = 1)
transformer.eval()
tgt = torch.randn(T, D)
src = torch.randn(L, D)
src_permuted = src.flip(dims = [0])
out = transformer(src, tgt)
_out = transformer(src_permuted, tgt)
print(out)
print(_out)
print(f"{torch.allclose(out, _out)}")

So transformer models as discussed so far are not sensitive to the order of words in the input. To use them for NLP tasks, we therefore need an additional trick – positional embeddings – which we will discuss in the next post.