Show Me The Code: Transformers
Published:
This blog post explains the Transformer architecture, along with code implementations.
Show Me The Code: Transformers¶
This notebook is a walk-through implementation of the Transformer model.
import torch
import numpy as np
from IPython.display import Image, display
Embedding¶
We first implement the embedding introduced in the original transformer paper [1]. Given an input sequence of tokens $(x_1, x_2, \ldots, x_N)$ where each token is from a vocabulary of size $V$, the embedding is a combination of the token embedding and the positional embedding:
$$ Embedding(x_n, n) = TokenEmbedding(x_n) + PositionalEmbedding(n) $$
Token Embedding¶
Denote $d_{model}$ as the model dimension of the transformer. Each layer's input and output are both $d_{model}$-dimensional, the same as the embedding dimension. The token embedding is retrieved from a lookup table in $\mathbb{R}^{V \times d_{model}}$. In the transformer paper, the token embeddings are multiplied by $\sqrt{d_{model}}$.
Positional Embedding¶
The positional embedding has a sinusoidal formulation:
$$ PositionalEmbedding(n)_{2i} = \sin\left(\frac{n}{10000^{2i/d_{model}}}\right), \quad i=0,\ldots,\frac{d_{model}}{2}-1 $$
$$ PositionalEmbedding(n)_{2i+1} = \cos\left(\frac{n}{10000^{2i/d_{model}}}\right), \quad i=0,\ldots,\frac{d_{model}}{2}-1 $$
In this formulation, each dimension corresponds to a sinusoidal signal with a different frequency, capturing both short-term and long-term relationships.
Dropout¶
Finally, we apply dropout to the output embedding for additional regularization.
Computational Cost¶
Given an input sequence of length $N$, the computational cost of the embedding layer is $\mathcal{O}(N d_{model})$. The memory cost is $\mathcal{O}(V \times d_{model} + N \times d_{model})$, where the vocabulary size usually dominates the sequence length.
class Embedding(torch.nn.Module):
def __init__(self, d_model, vocab_size, max_seq_len=4096, dropout=0.1):
super(Embedding, self).__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.embedding = torch.nn.Embedding(vocab_size, d_model)
self.position_embedding = SinusoidalPositionalEmbedding(d_model, max_seq_len)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, x):
x = self.embedding(x) * np.sqrt(self.d_model)
x = x + self.position_embedding(x)
return self.dropout(x)
class SinusoidalPositionalEmbedding(torch.nn.Module):
def __init__(self, d_model, max_seq_len=4096):
super(SinusoidalPositionalEmbedding, self).__init__()
# dimension of the embedding
self.d_model = d_model
# max_seq_len is used to compute the PE in advance
self.max_seq_len = max_seq_len
# this is the fixed positional embedding [max_seq_len, d_model]
        # precompute the fixed positional embedding table [max_seq_len, d_model]
        position_embedding = torch.zeros(max_seq_len, d_model)
        # [max_seq_len, 1]
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        # [d_model / 2]: 10000^(2i / d_model)
        denominator = torch.exp(torch.arange(0, d_model, 2).float() / d_model * np.log(10000.0))
        # [max_seq_len, d_model]
        position_embedding[:, 0::2] = torch.sin(position / denominator)
        position_embedding[:, 1::2] = torch.cos(position / denominator)
        # register as a buffer so it moves with the module but is not a learnable parameter
        self.register_buffer('position_embedding', position_embedding)
def forward(self, x):
        # x: [..., seq_len, d_model]; return the first seq_len rows: [seq_len, d_model]
        return self.position_embedding[:x.size(-2)]
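As a quick sanity check with toy sizes (hypothetical, not from the paper), the embedding maps a batch of token ids to $d_{model}$-dimensional vectors:
embedding = Embedding(d_model=64, vocab_size=1000, max_seq_len=128)
tokens = torch.randint(0, 1000, (2, 12))
# token ids [2, 12] -> embeddings [2, 12, 64]
print(embedding(tokens).shape)  # torch.Size([2, 12, 64])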
Multi-Head Attention¶
The key advancement of the transformer architecture is the extensive use of attention, specifically multi-head attention. Assume a key matrix $K \in \mathbb{R}^{N \times d_k}$, a value matrix $V \in \mathbb{R}^{N \times d_v}$, and a query matrix $Q \in \mathbb{R}^{M \times d_k}$. The attention is computed as $$Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right) V \in \mathbb{R}^{M \times d_v}.$$
To increase the capacity of the attention layer, the transformer employs multi-head attention. Intuitively, the queries, keys, and values are projected into $h$ different subspaces and attention is performed in each projected subspace. In the transformer architecture, all layers' inputs and outputs are $d_{model}$-dimensional, so all queries, keys, and values are $d_{model}$-dimensional as well. We adopt the linear projection matrices $W^Q_i \in \mathbb{R}^{d_{model} \times d_{k}}, W^K_i \in \mathbb{R}^{d_{model} \times d_{k}}, W^V_i \in \mathbb{R}^{d_{model} \times d_{v}}$, for $i=1,\ldots,h$. To ensure the output is still $d_{model}$-dimensional, we linearly project it again using an output matrix $W^O \in \mathbb{R}^{h d_{v} \times d_{model}}$. Therefore, the formulation of the multi-head attention layer is $$ MultiHeadAttention(Q, K, V) = Concat(Head_1, \ldots, Head_h) W^O ,$$ $$ Head_i = Attention(Q W^Q_i, K W^K_i, V W^V_i), \quad i=1,\ldots,h $$
In the implementation, we can use one matrix $W^K \in \mathbb{R}^{d_{model} \times h d_k}$ to store the concatenation of $W^K_1, \ldots, W^K_h$, and similarly for $W^Q$ and $W^V$. The original transformer paper uses $d_k = d_v = d_{model} / h$ and $h = 8$. Finally, dropout is applied to the output of the attention layer.

Computational Cost¶
Given an input sequence of length $N$, with $d_k = d_v = d_{model} / h$, computing the multi-head attention incurs computational cost $\mathcal{O}(N^2 d_{model})$. The memory cost is $\mathcal{O}(N d_{model} + N^2)$.
def attention(Q, K, V, mask=None):
# Q: [..., M, d_k]
# K: [..., N, d_k]
# V: [..., N, d_v]
    # mask: [..., M, N] (broadcastable) or None. True marks positions that may be attended to;
    # positions where the mask is False are filled with -inf before the softmax.
    # return: [..., M, d_v]
    # [..., M, N]
    QKt = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(Q.size(-1))
    if mask is not None:
        QKt = QKt.masked_fill(~mask, float('-inf'))
# [..., M, N]
weights = torch.nn.functional.softmax(QKt, dim=-1)
return torch.matmul(weights, V)
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model, d_k, d_v, num_heads, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_k
self.d_v = d_v
self.WQ = torch.nn.Linear(d_model, num_heads * self.d_k, bias=False)
self.WK = torch.nn.Linear(d_model, num_heads * self.d_k, bias=False)
self.WV = torch.nn.Linear(d_model, num_heads * self.d_v, bias=False)
self.WO = torch.nn.Linear(num_heads * self.d_v, d_model, bias=False)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
# Q: [..., M, d_model]
# K: [..., N, d_model]
# V: [..., N, d_model]
        # mask: [..., M, N] (broadcastable) or None; True marks positions that may be attended to.
        # return: [..., M, d_model]
        if mask is not None:
            # add a head dimension so the mask broadcasts against [..., num_heads, M, N]
            mask = mask.unsqueeze(-3)
        # [..., num_heads, M, d_k]
Q = self.WQ(Q).view(*Q.size()[:-2], Q.size()[-2], self.num_heads, self.d_k).transpose(-2, -3)
# [..., num_heads, N, d_k]
K = self.WK(K).view(*K.size()[:-2], K.size()[-2], self.num_heads, self.d_k).transpose(-2, -3)
# [..., num_heads, N, d_v]
V = self.WV(V).view(*V.size()[:-2], V.size()[-2], self.num_heads, self.d_v).transpose(-2, -3)
# [..., num_heads, M, d_v]
x = attention(Q, K, V, mask=mask)
# [..., M, num_heads * d_v]
        x = x.transpose(-2, -3).contiguous().view(*x.size()[:-3], -1, self.num_heads * self.d_v)
# [..., M, d_model]
x = self.WO(x)
return self.dropout(x)
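A quick shape check with toy sizes (hypothetical): the layer maps inputs of shape [batch, seq, d_model] back to the same shape.
mha = MultiHeadAttention(d_model=64, d_k=16, d_v=16, num_heads=4)
x = torch.randn(2, 10, 64)
# self-attention: queries, keys, and values all come from x
print(mha(x, x, x).shape)  # torch.Size([2, 10, 64])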
Transformer¶

Encoder¶
In the encoder, each layer stacks a multi-head self-attention sublayer and a feedforward sublayer. Around each sublayer, a residual connection is added and layer normalization is applied: $LayerNorm(x + Sublayer(x))$. The encoder is a stack of $N$ such layers.
The feedforward layer implements a two-layer ReLU network, with the following formulation $$FFN(x) = max(0, x W_1 + b_1) W_2 + b_2.$$
Computational cost¶
Given an input sequence of length $N$, the computational cost of the feedforward layer is $\mathcal{O}(N \times d_{model} \times d_{ff})$. Existing LLMs usually use a large $d_{ff}$ (e.g., Nemotron-4-340B [6] uses 73728) while the pretraining sequence length is usually smaller (e.g., Nemotron-4-340B uses 4096). Thus the feedforward layer is often the main computational bottleneck in the transformer architecture. For longer contexts, e.g., $N = 128k$, the attention layer becomes the main computational bottleneck.
class FeedforwardLayer(torch.nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super(FeedforwardLayer, self).__init__()
self.ff_1 = torch.nn.Linear(d_model, d_ff)
self.ff_2 = torch.nn.Linear(d_ff, d_model)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, x):
x = self.ff_1(x)
x = torch.nn.functional.relu(x)
x = self.ff_2(x)
return self.dropout(x)
class EncoderLayer(torch.nn.Module):
def __init__(self, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1):
super(EncoderLayer, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_k
self.d_v = d_v
self.d_ff = d_ff
self.multi_head_attention = MultiHeadAttention(d_model, d_k, d_v, num_heads, dropout)
self.feedforward = FeedforwardLayer(d_model, d_ff, dropout)
self.norm_1 = torch.nn.LayerNorm(d_model)
self.norm_2 = torch.nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# x: [..., M, d_model]
# mask: [..., M, M] or None. In the encoder, the mask is used to mask out the padding tokens.
x = x + self.multi_head_attention(x, x, x, mask=mask)
x = self.norm_1(x)
x = x + self.feedforward(x)
return self.norm_2(x)
class Encoder(torch.nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1):
super(Encoder, self).__init__()
self.layers = torch.nn.ModuleList([EncoderLayer(d_model, num_heads, d_k, d_v, d_ff, dropout)
for _ in range(num_layers)])
def forward(self, x, mask=None):
for layer in self.layers:
x = layer(x, mask=mask)
return x
Decoder¶
In the decoder, each layer is a stack of a masked multi-head self-attention, a multi-head cross-attention, and a feedforward layer. In the multi-head cross-attention, the keys $K$ and the values $V$ come from the output of the last layer of the encoder.
class DecoderLayer(torch.nn.Module):
def __init__(self, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1):
super(DecoderLayer, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_k
self.d_v = d_v
self.d_ff = d_ff
self.self_attention = MultiHeadAttention(d_model, d_k, d_v, num_heads, dropout)
self.encoder_attention = MultiHeadAttention(d_model, d_k, d_v, num_heads, dropout)
self.feedforward = FeedforwardLayer(d_model, d_ff, dropout)
self.norm_1 = torch.nn.LayerNorm(d_model)
self.norm_2 = torch.nn.LayerNorm(d_model)
self.norm_3 = torch.nn.LayerNorm(d_model)
def forward(self, x, encoder_output, encoder_mask=None, decoder_mask=None):
        # causal mask: True where attention is allowed (the current and previous tokens)
        mask = torch.tril(torch.ones(x.size(-2), x.size(-2), device=x.device)).bool()
if decoder_mask is not None:
# decoder_mask is to mask out the padding tokens in the decoder input
mask = mask & decoder_mask
x = x + self.self_attention(x, x, x, mask=mask)
x = self.norm_1(x)
# the encoder_mask is used to mask out the padding tokens in the encoder input.
x = x + self.encoder_attention(x, encoder_output, encoder_output, mask=encoder_mask)
x = self.norm_2(x)
x = x + self.feedforward(x)
return self.norm_3(x)
class Decoder(torch.nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1):
super(Decoder, self).__init__()
self.layers = torch.nn.ModuleList([DecoderLayer(d_model, num_heads, d_k, d_v, d_ff, dropout)
for _ in range(num_layers)])
    def forward(self, x, encoder_output, encoder_mask=None, decoder_mask=None):
        for layer in self.layers:
            x = layer(x, encoder_output, encoder_mask=encoder_mask, decoder_mask=decoder_mask)
return x
Output Embedding and Softmax¶
Given the output of the last decoder layer, we use an output embedding matrix of size $\mathbb{R}^{d_{model} \times V}$ followed by a softmax operation to predict the next token.
The Full Encoder-Decoder Transformer¶
class EncoderDecoderTransformer(torch.nn.Module):
def __init__(self, num_encoder_layers, num_decoder_layers, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1, encoder_vocab_size=25600, decoder_vocab_size=25600, max_seq_len=4096):
super(EncoderDecoderTransformer, self).__init__()
self.encoder = Encoder(num_encoder_layers, d_model, num_heads, d_k, d_v, d_ff, dropout)
self.decoder = Decoder(num_decoder_layers, d_model, num_heads, d_k, d_v, d_ff, dropout)
self.encoder_embedding = Embedding(d_model, encoder_vocab_size, max_seq_len, dropout)
self.decoder_embedding = Embedding(d_model, decoder_vocab_size, max_seq_len, dropout)
self.output_layer = torch.nn.Linear(d_model, decoder_vocab_size, bias=False)
def forward(self, encoder_input, encoder_mask, decoder_input, decoder_mask):
encoder_output = self.encoder(self.encoder_embedding(encoder_input), mask=encoder_mask)
decoder_output = self.decoder(self.decoder_embedding(decoder_input), encoder_output, encoder_mask=encoder_mask, decoder_mask=decoder_mask)
output = self.output_layer(decoder_output)
return output
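A quick forward pass with toy hyperparameters (hypothetical) to check the output shape:
model = EncoderDecoderTransformer(num_encoder_layers=2, num_decoder_layers=2, d_model=64,
                                  num_heads=4, d_k=16, d_v=16, d_ff=128,
                                  encoder_vocab_size=1000, decoder_vocab_size=1000, max_seq_len=128)
src, tgt = torch.randint(2, 1000, (2, 12)), torch.randint(2, 1000, (2, 9))
src_mask, tgt_mask = (src != 0).unsqueeze(-2), (tgt != 0).unsqueeze(-2)
# per-position logits over the decoder vocabulary
print(model(src, src_mask, tgt, tgt_mask).shape)  # torch.Size([2, 9, 1000])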
Objectives and Trainer¶
The transformer is trained with the next-token-prediction loss, i.e., the cross entropy between the predicted distribution over the vocabulary and the ground-truth target token.
def loss_fn(logits, targets):
    # logits: [B, M, V] log-probabilities, targets: [B, M]; returns a per-token loss of shape [B, M]
    # nll_loss expects the class dimension right after the batch dimension
    return torch.nn.functional.nll_loss(logits.transpose(-1, -2), targets, reduction='none')
def run_one_epoch(data_iter, model, optimizer, pad_token=0):
for batch in data_iter:
# [B, N], [B, M], [B, M]
encoder_input, decoder_input, targets = batch
encoder_mask = (encoder_input != pad_token).unsqueeze(-2)
decoder_mask = (decoder_input != pad_token).unsqueeze(-2)
# [B, M, V]
model_outputs = model(encoder_input, encoder_mask, decoder_input, decoder_mask)
logits = torch.nn.functional.log_softmax(model_outputs, dim=-1)
        # mask out the loss at padding positions in the targets
        target_mask = (targets != pad_token).float()
        loss = loss_fn(logits, targets).mul(target_mask).sum() / target_mask.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
BERT: The Encoder-only Model¶
BERT [2] is an encoder-only model trained by predicting randomly masked tokens. Large-scale unsupervised pretraining enables BERT to learn rich and representative features that excel in various NLP tasks such as sentiment classification, summarization, and named entity recognition.
class BERTEncoderModel(torch.nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1, vocab_size=25600, max_seq_len=4096):
super(BERTEncoderModel, self).__init__()
self.num_layers = num_layers
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_k
self.d_v = d_v
self.d_ff = d_ff
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.embedding = Embedding(d_model, vocab_size, max_seq_len, dropout)
self.encoder = Encoder(num_layers, d_model, num_heads, d_k, d_v, d_ff, dropout)
self.output_layer = torch.nn.Linear(d_model, vocab_size, bias=False)
def forward(self, x, mask=None):
x = self.embedding(x)
x = self.encoder(x, mask=mask)
return self.output_layer(x)
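A quick shape check with toy hyperparameters (hypothetical):
bert = BERTEncoderModel(num_layers=2, d_model=64, num_heads=4, d_k=16, d_v=16, d_ff=128,
                        vocab_size=1000, max_seq_len=128)
tokens = torch.randint(2, 1000, (2, 12))
# per-position logits over the vocabulary, used to predict the masked tokens
print(bert(tokens, mask=(tokens != 0).unsqueeze(-2)).shape)  # torch.Size([2, 12, 1000])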
def loss_fn(logits, targets):
    # logits: [B, N, V] log-probabilities, targets: [B, N]; returns a per-token loss of shape [B, N]
    return torch.nn.functional.nll_loss(logits.transpose(-1, -2), targets, reduction='none')
def run_one_epoch(data_iter, model, optimizer, pad_token=0, mask_token=1):
    for batch in data_iter:
        # the iterator yields (masked_input, targets): the corrupted tokens and the original tokens
        masked_input, targets = batch
        # [B, N]: attend to all non-padding tokens
        attention_mask = (masked_input != pad_token)
        # [B, N, V]
        model_outputs = model(masked_input, attention_mask.unsqueeze(-2))
        logits = torch.nn.functional.log_softmax(model_outputs, dim=-1)
        # the masked-language-modeling loss is computed only at the masked positions
        loss_mask = (masked_input == mask_token).float()
        loss = loss_fn(logits, targets).mul(loss_mask).sum() / loss_mask.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
GPT: The Decoder-only Model¶
GPT [3] uses only the decoder to predict the next token in a sequence. Since the decoder respects the causal temporal structure, a decoder-only model is also called a causal language model. The GPT models have powered the developments of large language models, such as GPT-3 [4] and GPT-4 [5].
Different from the decoder layer in the encoder-decoder transformer, a GPT decoder layer contains only a masked multi-head self-attention layer and a feedforward layer.
class GPTDecoderLayer(torch.nn.Module):
def __init__(self, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1):
        super(GPTDecoderLayer, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_k
self.d_v = d_v
self.d_ff = d_ff
self.self_attention = MultiHeadAttention(d_model, d_k, d_v, num_heads, dropout)
self.feedforward = FeedforwardLayer(d_model, d_ff, dropout)
self.norm_1 = torch.nn.LayerNorm(d_model)
self.norm_2 = torch.nn.LayerNorm(d_model)
def forward(self, x, decoder_mask=None):
        # causal mask: True where attention is allowed (the current and previous tokens)
        mask = torch.tril(torch.ones(x.size(-2), x.size(-2), device=x.device)).bool()
if decoder_mask is not None:
# decoder_mask is to mask out the padding tokens in the decoder input
mask = mask & decoder_mask
x = x + self.self_attention(x, x, x, mask=mask)
x = self.norm_1(x)
x = x + self.feedforward(x)
return self.norm_2(x)
class GPTDecoder(torch.nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1):
        super(GPTDecoder, self).__init__()
self.layers = torch.nn.ModuleList([GPTDecoderLayer(d_model, num_heads, d_k, d_v, d_ff, dropout)
for _ in range(num_layers)])
def forward(self, x, decoder_mask=None):
for layer in self.layers:
x = layer(x, decoder_mask=decoder_mask)
return x
class GPTDecoderModel(torch.nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_k, d_v, d_ff, dropout=0.1, vocab_size=25600, max_seq_len=4096):
super(GPTDecoderModel, self).__init__()
self.num_layers = num_layers
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_k
self.d_v = d_v
self.d_ff = d_ff
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.embedding = Embedding(d_model, vocab_size, max_seq_len, dropout)
self.decoder = GPTDecoder(num_layers, d_model, num_heads, d_k, d_v, d_ff, dropout)
self.output_layer = torch.nn.Linear(d_model, vocab_size, bias=False)
def log_likelihood(self, x, y, decoder_mask=None):
logits = self.forward(x, decoder_mask=decoder_mask)
# [..., seq_len, vocab_size]
probs = torch.nn.functional.log_softmax(logits, dim=-1)
return torch.gather(probs, -1, y.unsqueeze(-1)).squeeze(-1)
def forward(self, x, decoder_mask=None):
x = self.embedding(x)
x = self.decoder(x, decoder_mask=decoder_mask)
return self.output_layer(x)
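A quick shape check with toy hyperparameters (hypothetical):
gpt = GPTDecoderModel(num_layers=2, d_model=64, num_heads=4, d_k=16, d_v=16, d_ff=128,
                      vocab_size=1000, max_seq_len=128)
tokens = torch.randint(2, 1000, (2, 12))
print(gpt(tokens).shape)  # torch.Size([2, 12, 1000])
# per-token log likelihood of the next-token targets
print(gpt.log_likelihood(tokens[..., :-1], tokens[..., 1:]).shape)  # torch.Size([2, 11])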
def run_one_epoch(data_iter, model, optimizer, pad_token=0):
for batch in data_iter:
        # [B, N-1]: mask out the loss at padding positions in the targets
        mask = (batch[..., 1:] != pad_token).float()
        # next-token prediction: predict batch[..., 1:] from batch[..., :-1]
        loss = -model.log_likelihood(batch[..., :-1], batch[..., 1:]).mul(mask).sum(-1).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
Multi Query Attention (MQA) and Grouped Query Attention (GQA)¶
In the multi-head attention, the size of the keys is $batch\_size \times seq\_len \times num\_heads \times d_k = batch\_size \times seq\_len \times d_{model}$, and similarly for the values. Taking Nemotron-4-340B [6] as an example, the key and value cache requires $$ batch\_size \times 4096 \text{ tokens} \times 18432 \text{ dims} \times 96 \text{ layers} \times 2 \text{ (K and V)} \times 2 \text{ bytes (bfloat16)} \approx batch\_size \times 27 \text{ GB}. $$
In decoding, a single A100 can therefore only support $batch\_size = 1$, since its memory is around 40 GB. In fact, one A100 is still not enough, since other operations require memory as well. Beyond capacity, moving the KV cache between memory and the compute units is severely limited by memory bandwidth because of the size of the cache.
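The same back-of-the-envelope number in code, assuming full multi-head attention so that the keys and values per token each have $d_{model}$ dimensions:
seq_len, d_model, num_layers = 4096, 18432, 96
bytes_per_bf16 = 2
# 2 tensors (K and V) per layer, each [seq_len, d_model], in bfloat16
kv_cache_bytes = seq_len * d_model * num_layers * 2 * bytes_per_bf16
print(kv_cache_bytes / 1024**3)  # ~27 GiB per sequence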
To reduce the memory overhead, multi-query attention (MQA) [8] proposes to share a single key-value head across all query heads. In consequence, the size of the keys becomes $batch\_size \times seq\_len \times d_k$, which is $num\_heads$ times smaller. However, MQA usually incurs quality degradation.
To mitigate the quality impact, grouped-query attention (GQA) [9] splits the query heads into groups and shares one key-value head within each group. GQA is thus an intermediate form between MQA and MHA: it reduces to MQA when num_groups=1 and to MHA when num_groups=num_heads. Experiments showed that GQA achieves competitive quality with relatively small num_groups (e.g., Nemotron-4-340B uses num_groups=8 with num_heads=96). GQA has thus become the (almost) default attention method in recent large language models, including Nemotron-4-340B [6] and the Llama 3 series [7].

class GroupQueryAttention(torch.nn.Module):
def __init__(self, d_model, d_k, d_v, num_heads, num_kv_heads, dropout=0.1):
        super(GroupQueryAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.d_k = d_k
self.d_v = d_v
self.WQ = torch.nn.Linear(d_model, num_heads * self.d_k)
self.WK = torch.nn.Linear(d_model, num_kv_heads * self.d_k)
self.WV = torch.nn.Linear(d_model, num_kv_heads * self.d_v)
self.WO = torch.nn.Linear(num_heads * self.d_v, d_model)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
# Q: [..., M, d_model]
# K: [..., N, d_model]
# V: [..., N, d_model]
# mask: [..., M, N] or None.
# return: [..., M, d_model]
        num_groups = self.num_heads // self.num_kv_heads
        batch_dims = Q.size()[:-2]
        if mask is not None:
            # add a kv-head dimension so the mask broadcasts against [..., num_kv_heads, M, N]
            mask = mask.unsqueeze(-3)
        # [..., M, num_groups, num_kv_heads, d_k]
        Q = self.WQ(Q).view(*batch_dims, Q.size(-2), num_groups, self.num_kv_heads, self.d_k)
        # [num_groups, ..., num_kv_heads, M, d_k]
        Q = Q.permute([-3] + list(range(len(batch_dims))) + [-2, -4, -1])
        # [..., num_kv_heads, N, d_k]
        K = self.WK(K).view(*K.size()[:-2], K.size(-2), self.num_kv_heads, self.d_k).transpose(-2, -3)
        # [..., num_kv_heads, N, d_v]
        V = self.WV(V).view(*V.size()[:-2], V.size(-2), self.num_kv_heads, self.d_v).transpose(-2, -3)
        # the shared keys/values broadcast over the leading group dimension
        # [num_groups, ..., num_kv_heads, M, d_v]
        x = attention(Q, K, V, mask=mask)
        # [..., M, num_groups, num_kv_heads, d_v]
        x = x.permute(list(range(1, len(batch_dims) + 1)) + [-2, 0, -3, -1])
        # [..., M, num_heads * d_v]
        x = x.contiguous().view(*batch_dims, -1, self.num_heads * self.d_v)
# [..., M, d_model]
x = self.WO(x)
return self.dropout(x)
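A quick shape check with toy sizes (hypothetical); note that the key and value projections are num_kv_heads / num_heads times smaller than in the standard multi-head attention:
gqa = GroupQueryAttention(d_model=64, d_k=16, d_v=16, num_heads=4, num_kv_heads=2)
x = torch.randn(2, 10, 64)
print(gqa(x, x, x).shape)  # torch.Size([2, 10, 64])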
Rotary Positional Embedding¶
The sinusoidal positional embedding tracks positional information using sinusoidal signals. Consider two tokens $x_m$ and $x_n$; their interaction in the multi-head attention is determined by their embeddings. Specifically, for the $i$-th head, the attention weight is $$ ( Embedding(x_m, m)^t W^Q_i ) ( Embedding(x_n, n)^t W^K_i )^t = Embedding(x_m, m)^t W^Q_i (W^K_i)^t Embedding(x_n, n) $$ We would like this interaction to be relative, i.e., to depend only on the tokens $x_m, x_n$ and their relative position $m-n$. However, the sinusoidal positional embedding does not have this property.
The Rotary Positional Embedding (RoPE) [10] considers the fused operation of the embedding and the attention product. It proposes to modify the attention weights as
$$ \left[ R(m; d, \theta_{1:d/2}) W^Q_i TokenEmbedding(x_m) \right]^{t} \left[ R(n; d, \theta_{1:d/2}) W^K_i TokenEmbedding(x_n) \right] $$
Here $R(m; d, \theta_{1:d/2})$ is a block-diagonal matrix of $2 \times 2$ blocks, where each block is a rotation matrix with angle $m \theta_i$. The nice property of rotation matrices is that $R(m; d, \theta_{1:d/2})^t R(n; d, \theta_{1:d/2})$ depends only on $m - n$: it can be thought of as rotating by $m \theta_i$ in one direction and then by $n \theta_i$ in the opposite direction.

In consequence, RoPE ensures that inter-token interactions depend only on the tokens and their relative positions. The RoPE paper suggests setting $\theta_i = 10000^{-2(i-1)/d}, i=1,\ldots,d/2$, similar to the sinusoidal embedding in the original transformer.
While RoPE is named a positional embedding method, it is better viewed as a modified attention mechanism. Unlike the sinusoidal positional embedding, which is applied once at the input, RoPE is usually applied in every attention layer.
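Before the full implementation, here is a small numerical check (a sketch with made-up vectors; the rotate helper is only for illustration) of the relative-position property: rotating $q$ by $m\theta$ and $k$ by $n\theta$ leaves their inner product unchanged when both positions are shifted by the same amount.
def rotate(x, angle):
    # x: [..., 2]; rotate each 2-D pair by the given angle
    cos, sin = np.cos(angle), np.sin(angle)
    return torch.stack([x[..., 0] * cos - x[..., 1] * sin,
                        x[..., 0] * sin + x[..., 1] * cos], dim=-1)

q, k, theta = torch.randn(2), torch.randn(2), 0.3
for shift in [0, 5, 11]:
    m, n = 7 + shift, 3 + shift
    # m - n = 4 in every case, so all three values agree up to floating point error
    print(torch.dot(rotate(q, m * theta), rotate(k, n * theta)).item())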
def rope_attention(Q, K, V, mask=None):
# Q: [..., M, d_k]
# K: [..., N, d_k]
# V: [..., N, d_v]
    # mask: [..., M, N] (broadcastable) or None. True marks positions that may be attended to.
# return: [..., M, d_v]
    # [d_k]: theta_j = 10000^(-2j/d_k) for j = 0, ..., d_k/2 - 1; each frequency is repeated
    # for the (2j, 2j+1) dimension pair
    thetas = torch.zeros(Q.size(-1), dtype=Q.dtype, device=Q.device)
    freqs = 10000.0 ** (-torch.arange(0, Q.size(-1), 2, dtype=Q.dtype, device=Q.device) / Q.size(-1))
    thetas[::2] = freqs
    thetas[1::2] = freqs
    # [M, d_k] and [N, d_k]: rotation angles m * theta_j for each position and dimension pair
    q_theta = torch.arange(Q.size(-2), dtype=Q.dtype, device=Q.device).unsqueeze(-1) * thetas
    k_theta = torch.arange(K.size(-2), dtype=Q.dtype, device=Q.device).unsqueeze(-1) * thetas
flip_Q = torch.zeros_like(Q)
flip_Q[..., ::2] = -Q[..., 1::2]
flip_Q[..., 1::2] = Q[..., ::2]
rotate_Q = Q * torch.cos(q_theta) + flip_Q * torch.sin(q_theta)
flip_K = torch.zeros_like(K)
flip_K[..., ::2] = -K[..., 1::2]
flip_K[..., 1::2] = K[..., ::2]
rotate_K = K * torch.cos(k_theta) + flip_K * torch.sin(k_theta)
QKt = torch.matmul(rotate_Q, rotate_K.transpose(-2, -1)) / np.sqrt(rotate_Q.size(-1))
if mask is not None:
        QKt = QKt.masked_fill(~mask, -np.inf)
# [..., M, N]
weights = torch.nn.functional.softmax(QKt, dim=-1)
return torch.matmul(weights, V)
class RoPEGroupQueryAttention(torch.nn.Module):
def __init__(self, d_model, d_k, d_v, num_heads, num_kv_heads, dropout=0.1, use_rope_attention=True):
super(RoPEGroupQueryAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.d_k = d_k
self.d_v = d_v
        # For cross-attention, we fall back to the standard multi-head attention by setting this to False.
        self.use_rope_attention = use_rope_attention
self.attention_fn = rope_attention if use_rope_attention else attention
self.WQ = torch.nn.Linear(d_model, num_heads * self.d_k)
self.WK = torch.nn.Linear(d_model, num_kv_heads * self.d_k)
self.WV = torch.nn.Linear(d_model, num_kv_heads * self.d_v)
self.WO = torch.nn.Linear(num_heads * self.d_v, d_model)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
# Q: [..., M, d_model]
# K: [..., N, d_model]
# V: [..., N, d_model]
# mask: [..., M, N] or None.
# return: [..., M, d_model]
        num_groups = self.num_heads // self.num_kv_heads
        batch_dims = Q.size()[:-2]
        if mask is not None:
            # add a kv-head dimension so the mask broadcasts against [..., num_kv_heads, M, N]
            mask = mask.unsqueeze(-3)
        # [..., M, num_groups, num_kv_heads, d_k]
        Q = self.WQ(Q).view(*batch_dims, Q.size(-2), num_groups, self.num_kv_heads, self.d_k)
        # [num_groups, ..., num_kv_heads, M, d_k]
        Q = Q.permute([-3] + list(range(len(batch_dims))) + [-2, -4, -1])
        # [..., num_kv_heads, N, d_k]
        K = self.WK(K).view(*K.size()[:-2], K.size(-2), self.num_kv_heads, self.d_k).transpose(-2, -3)
        # [..., num_kv_heads, N, d_v]
        V = self.WV(V).view(*V.size()[:-2], V.size(-2), self.num_kv_heads, self.d_v).transpose(-2, -3)
        # [num_groups, ..., num_kv_heads, M, d_v]
        x = self.attention_fn(Q, K, V, mask=mask)
        # [..., M, num_groups, num_kv_heads, d_v]
        x = x.permute(list(range(1, len(batch_dims) + 1)) + [-2, 0, -3, -1])
        # [..., M, num_heads * d_v]
        x = x.contiguous().view(*batch_dims, -1, self.num_heads * self.d_v)
# [..., M, d_model]
x = self.WO(x)
return self.dropout(x)
Tokenization¶
Decoding¶
Given a trained model, we can use it to generate language responses, i.e., decoding. For simplicity, we consider a decoder-only GPT model; the decoding mechanism for encoder-decoder transformer models is the same. Various decoding methods exist, including random sampling, top-k sampling, top-p sampling, greedy decoding, and beam search, which keeps multiple decoding trajectories at the same time.
Greedy Sampling¶
Decoding proceeds token by token: at each step one token is generated, and the next prediction is conditioned on the previously generated tokens. Greedy sampling generates tokens by greedily selecting the most probable token at the current step. In this way, greedy sampling is myopic and usually not optimal in the long run.
Random Sampling, Top-K Sampling, Top-P Sampling¶
While greedy sampling is named "sampling", it is a deterministic approach. The language model defines a probability distribution over possible decoding sequences, so it naturally supports "true sampling" methods. Random sampling respects the LM's predicted probabilities and samples the next token accordingly. Further analysis showed that random sampling might pick very unlikely tokens at some positions, which can harm the whole sequence catastrophically. Top-k and top-p sampling mitigate this issue by restricting sampling to the top tokens, eliminating the possibility of sampling very unlikely tokens. As their names suggest, top-k sampling samples from the top $k$ tokens; top-p sampling samples from the top tokens up to a cumulative probability of $p$.
def decode(model: GPTDecoderModel, input, greedy=True, temperature=1.0, top_k=0, top_p=1., pad_token=0, eos_token=1, max_length=2048):
    # input: [N], a 1-D sequence of prompt token ids
    input = input.unsqueeze(0)
for i in range(max_length):
        mask = (input != pad_token).unsqueeze(-2)
output = model(input, decoder_mask=mask)[..., -1, :]
# [1, V]
logprobs = torch.nn.functional.log_softmax(output / temperature, dim=-1)
if top_p < 1.0:
sorted_logprobs, sorted_indices = torch.sort(logprobs, descending=True)
cumulative_probs = torch.cumsum(sorted_logprobs.exp(), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logprobs[:, indices_to_remove] = -float('Inf')
if top_k > 0:
top_k = min(top_k, logprobs.size(-1))
            indices_to_remove = logprobs < torch.topk(logprobs, top_k)[0][..., -1, None]
logprobs[indices_to_remove] = -float('Inf')
        # [1, 1]
        if greedy:
            next_token = logprobs.argmax(dim=-1, keepdim=True)
        else:
            next_token = torch.multinomial(logprobs.exp(), 1)
        input = torch.cat([input, next_token], dim=-1)
        if next_token.item() == eos_token:
break
return input
KV Cache¶
The decoding process involves hundreds or thousands of model forwards (model.forward). Running the model forward from scratch at each step wastes resources, because many computations in the attention layers have already been performed: the keys and values of the existing tokens were computed in previous iterations, and they are exactly what is needed to compute attention in the current decoding iteration. These keys and values can be stored in a cache to avoid repeated computation. This leads to the KV cache.
Consider an input sequence of length $N$ and decoding a single token: the computational cost of one attention layer's output without the KV cache is $\mathcal{O}(N^2 \times d_{model})$, while with the KV cache it is $\mathcal{O}(N \times d_{model})$.
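To make this concrete, here is a minimal sketch of a per-layer KV cache; the KVCache class is hypothetical and not wired into the modules above, but it reuses the attention function defined earlier.
class KVCache:
    def __init__(self):
        self.K, self.V = None, None
    def update(self, k_new, v_new):
        # k_new: [..., 1, d_k], v_new: [..., 1, d_v] for the newly decoded token
        self.K = k_new if self.K is None else torch.cat([self.K, k_new], dim=-2)
        self.V = v_new if self.V is None else torch.cat([self.V, v_new], dim=-2)
        return self.K, self.V

cache, d = KVCache(), 64
for step in range(5):
    # in a real decoder these come from projecting the newly generated token
    q_new, k_new, v_new = torch.randn(1, 1, d), torch.randn(1, 1, d), torch.randn(1, 1, d)
    K, V = cache.update(k_new, v_new)
    # attention for the single new query over all cached keys and values: O(N d) per step
    out = attention(q_new, K, V)
print(out.shape)  # torch.Size([1, 1, 64])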
In-flight Batching¶
In the implementation above, each input contains only one example. When GPU memory allows, decoding multiple examples at the same time is more efficient. However, using a fixed, predefined batch is not ideal, because different examples may stop generation (i.e., hit the EOS token) at different times. If one example stops early, waiting for the other examples to finish wastes resources.
In-flight batching resolves this issue by forming the batch dynamically at each generation step instead of once per batch. When one example in the batch finishes, it is popped out of the batch and a new example is added, so the model forward always runs at the same, efficient batch size.
Selecting which example to add to the batch matters as well. If the examples currently in the batch are long sequences (prompt + generated response), we would like to add an example with a long prompt to minimize padding, which is also wasted compute; if the current examples are short, we would like to add an example with a short prompt. Overall, the less padding in the batch, the less computation is wasted in the model forward. A minimal scheduling sketch is given below.
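Here is a minimal scheduling sketch; inflight_batching and next_token_fn are hypothetical names, and next_token_fn stands for a callable that maps the list of active sequences to one new token per sequence (in practice it would pad the batch and call the model).
from collections import deque

def inflight_batching(prompts, next_token_fn, batch_size, eos_token=1, max_length=2048):
    pending = deque(prompts)   # requests waiting to enter the batch
    active, finished = [], []  # sequences currently decoding / already done
    while pending or active:
        # refill the batch as soon as a slot frees up
        while pending and len(active) < batch_size:
            active.append(list(pending.popleft()))
        # one decoding step for the whole batch
        new_tokens = next_token_fn(active)
        for seq, token in zip(active, new_tokens):
            seq.append(token)
        # pop finished sequences out of the batch and keep the rest
        finished.extend(seq for seq in active if seq[-1] == eos_token or len(seq) >= max_length)
        active = [seq for seq in active if seq[-1] != eos_token and len(seq) < max_length]
    return finished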
Beam Search¶
When computation and memory allow, we can keep multiple partial sequences during decoding, which has the potential of finding better responses than greedy or random sampling. Beam search is such a technique: it keeps $B$ sequences, and in each step every sequence proposes its top-$B$ next tokens, forming $B^2$ candidates. The top $B$ of these $B^2$ candidates are kept for the next iteration. In beam search, the average log likelihood is usually used to score a sequence; to take the sequence length into account, a length penalty can be applied. Overall, the score of each sequence can be written as $$ \frac{\log \pi(y_{1:i} | x)}{i^{l}}$$ with $l=1$ by default. Smaller $l$ favors shorter sequences; for example, $l=0$ corresponds to the sum of log likelihoods.
def beam_search(model: GPTDecoderModel, input, beam, length_penalty=1., temperature=1.0, pad_token=0, eos_token=1, max_length=2048):
    # input: [N], a 1-D sequence of prompt token ids
    # [beam, N]: every beam starts from the same prompt
    input = input.unsqueeze(0).expand(beam, -1)
    # [beam]: cumulative log likelihood of each beam. Initializing all but the first beam
    # to -inf makes the first step pick `beam` distinct tokens instead of `beam` copies.
    logprobs = torch.full((beam,), float('-inf'), device=input.device)
    logprobs[0] = 0.
    for _ in range(max_length):
        mask = (input != pad_token).unsqueeze(-2)
        # Only the last position is used for the next-token prediction
        # [beam, V]
        output = model(input, decoder_mask=mask)[:, -1, :]
        next_logprobs = torch.nn.functional.log_softmax(output / temperature, dim=-1)
        # [beam, beam]: each beam proposes its top-beam next tokens
        top_k_logprobs, top_k_indices = torch.topk(next_logprobs, beam)
        # [beam, 1]: beams that already produced EOS are frozen: they accumulate no further
        # log likelihood and only ever append the EOS token again
        finished = (input == eos_token).any(-1, keepdim=True)
        top_k_logprobs = top_k_logprobs.masked_fill(finished, 0.)
        top_k_indices = top_k_indices.masked_fill(finished, eos_token)
        # [beam * beam]: cumulative log likelihood of every candidate continuation
        candidate_logprobs = (logprobs.unsqueeze(-1) + top_k_logprobs).view(-1)
        # [beam * beam, current_len + 1]: existing tokens concatenated with the new tokens
        candidate_tokens = torch.cat(
            [input.unsqueeze(1).expand(-1, beam, -1).reshape(beam * beam, -1),
             top_k_indices.view(-1, 1)], dim=-1)
        # [beam * beam]: number of tokens before the first EOS (the full length if no EOS yet)
        eos_seen = torch.cumsum((candidate_tokens == eos_token).long(), dim=-1) > 0
        length = (~eos_seen).float().sum(-1).clamp(min=1.)
        # keep the top-beam candidates ranked by length-penalized log likelihood
        _, top_indices = torch.topk(candidate_logprobs / length.pow(length_penalty), beam)
        logprobs = candidate_logprobs[top_indices]
        input = candidate_tokens[top_indices]
        # Check whether all beams have produced the EOS token. If so, stop decoding.
        if (input == eos_token).any(-1).all():
            break
    return input
References¶
[1] Vaswani, A. "Attention is all you need." Advances in Neural Information Processing Systems (2017).
[2] Devlin, Jacob. "BERT: Pre-training of deep bidirectional transformers for language understanding." arXiv preprint arXiv:1810.04805 (2018).
[3] Radford, Alec. "Improving language understanding by generative pre-training." (2018).
[4] Brown, Tom, et al. "Language models are few-shot learners." Advances in neural information processing systems 33 (2020): 1877-1901.
[5] Achiam, Josh, et al. "GPT-4 technical report." arXiv preprint arXiv:2303.08774 (2023).
[6] Adler, Bo, et al. "Nemotron-4 340B Technical Report." arXiv preprint arXiv:2406.11704 (2024).
[7] Dubey, Abhimanyu, et al. "The Llama 3 herd of models." arXiv preprint arXiv:2407.21783 (2024).
[8] Shazeer, Noam. "Fast transformer decoding: One write-head is all you need." arXiv preprint arXiv:1911.02150 (2019).
[9] Ainslie, Joshua, et al. "GQA: Training generalized multi-query transformer models from multi-head checkpoints." arXiv preprint arXiv:2305.13245 (2023).
[10] Su, Jianlin, et al. "Roformer: Enhanced transformer with rotary position embedding." Neurocomputing 568 (2024): 127063.
Cite¶
@article{sun2025transformer,
title = "Show Me The Code: Transformers",
author = "Sun, Shengyang",
journal = "ssydasheng.github.io",
year = "2025",
url = "https://ssydasheng.github.io/jekyll/transformer/"
}