Transformers Inference Optimization Toolset
Large Language Models are pushing the boundaries of artificial intelligence, but their immense size poses significant computational challenges. As these models grow, so does the need for smart optimization techniques to keep them running efficiently on modern hardware.
In this post, we’ll explore key optimization strategies that are making LLMs faster and more memory-efficient. We’ll start with a brief look at the GPU memory hierarchy, which forms the foundation for many of these techniques. Then, we’ll move on to algorithms that allow LLMs to process information more quickly and handle longer contexts. Understanding these techniques offers valuable insight into how to unlock the full potential of Large Language Models.
The idea of this post is not just to discuss transformer-specific optimizations, since there are plenty of resources where one can examine every inch of the transformer to make it faster (my favourite one is “Let’s reproduce GPT-2” by Andrej Karpathy). The main goal is to lower the entry barrier for curious researchers who are currently unable to piece together the huge number of articles and papers into one picture.
Many optimization techniques will be left out, for example quantization methods, which are relatively diverse and deserve a separate post. We’ll also mostly discuss transformer inference and won’t mention training tricks such as mixed-precision training, gradient checkpointing or sequence packing. Even so, many of the optimizations in this post can be applied to training as well.
GPU architecture overview
To tackle the problem of speeding up language models, we first need to understand the hardware we work on. While Google’s TPUs and Apple silicon chips are gaining ground, NVIDIA’s GPUs still dominate the market, so they’ll be the subject of our in-depth look.
A graphics processing unit (GPU) performs all of its computations on multiple streaming multiprocessors (SMs), which are similar to the cores of a CPU. An SM is the basic GPU building block: it has its own instruction schedulers and various instruction execution pipelines. Modern data-center GPUs are also equipped with special off-chip memory called high bandwidth memory (HBM), where data is initially stored and ultimately written back. Unlike system dynamic random access memory (DRAM), which is controlled by the CPU and typically optimized for low-latency access, HBM is physically bonded to the GPU in stacked layers with thousands of pins and is designed for massively parallel data throughput.
Streaming multiprocessors access data and code from HBM via the L2 cache. It acts as an intermediate level between off-chip and on-chip memory and caches data that can be shared among multiple SMs. It is also situated in the path of data moving between devices. Finally, each SM has its own L1 cache and shared memory (SRAM), low-latency on-chip memories that are an order of magnitude faster than HBM but many orders of magnitude smaller in size. The L1 cache is managed by the GPU hardware, while shared memory can be explicitly managed by the programmer, e.g. in CUDA kernels.
GPUs can communicate with each other over a high-bandwidth interconnect called NVLink, and they can talk to the outside world over a PCIe bus (a high-speed bus standard, common on motherboards for transferring data) or a special Ethernet alternative called InfiniBand. Usually, 8 GPUs are packed into a single node. Feel free to check out my post on parallelization strategies to learn more about multi-device training.
Schematic example of memory architecture with 2 GPU devices.
Now, when it comes to GPU capability, we must look at three things:
- Compute performance, measured in trillions of floating-point operations per second (TFLOPS).
- GPU memory required to store model parameters, hidden activations and cache values, measured in GB. For instance, GPT-3 has 175 billion parameters, so we need 350 GB of memory just to keep them on device in fp16.
- Memory bandwidth, measured in GB/s: the speed at which bytes move from GPU memory to the processing units.
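To make the GPT-3 figure concrete, here is a tiny Python helper (the function name and the dtype table are illustrative, not from any library). It counts only the weights, ignoring activations and the KV cache, and uses 1 GB = 10^9 bytes.

```python
# Rough memory needed just to hold model weights on the GPU.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}


def weights_memory_gb(num_params: float, dtype: str = "fp16") -> float:
    """Return weight storage in GB (1 GB = 10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


print(weights_memory_gb(175e9))          # GPT-3 in fp16 -> 350.0
print(weights_memory_gb(175e9, "fp32"))  # same model in fp32 -> 700.0
```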
GPU capabilities grow exponentially fast. According to NVIDIA documentation, the T4 graphics card released in 2018 had 65 TFLOPS, 40 SMs with 64KB of L1 cache each, 4MB of L2 cache with 1.3TB/s bandwidth and 16GB of GDDR6 memory with 320 GB/s bandwidth. Just 2 years later, the A100 was released with 312 TFLOPS and 108 SMs with 192KB of L1 each, 40MB of L2 cache and 80 GB of HBM with 1.55 TB/s bandwidth. Compare those numbers to the latest B100 card, which can perform 1.8 PFLOPS and whose HBM has a capacity of 192 GB and a throughput of 8 TB/s.
Memory bandwidth vs compute performance rapid growth. Note that both axes are log-scale.1
Time required for memory accesses can vary depending on the device, its modifications and the infrastructure setup. But the main point to remember is that, if we compare access times, some of them differ by orders of magnitude:
- To read data from L1 cache / SRAM: $x$ ns.
- To read data from L2 cache: $2x$-$3x$ ns.
- To read data from HBM memory: $10x$ ns.
- To share data between GPUs with NVLink (both ways): $50x$-$60x$ ns.
- To load data from CPU to GPU memory through the PCIe bus: $\sim 300x$ ns.
These numbers show us that while the number of operations per second matters, operand placement can be even more important when we optimize for inference speed. Keep in mind that the slowest memory involved usually dominates performance bottlenecks.
Arithmetic intensity vs ops:byte
Depending on the balance of computation and memory accesses, operations can be classified as follows:
- Compute-bound: the time spent on arithmetic operations exceeds the time spent on other operations such as memory accesses. Typical examples are a linear layer with a large inner dimension or a convolutional layer with a large number of channels.
- Memory-bound: the time taken by memory accesses exceeds computation time. Most operations are memory-bound, e.g. elementwise operations (activation functions, dropout) or reductions (sum, softmax, normalization).
- Overhead-bound: everything else, such as communication-bound, interpreter-bound, etc. We won’t discuss it in this post, but I strongly advise taking a look at the Making Deep Learning Go Brrrr From First Principles blog post to understand GPU mechanisms and why, in many cases, our bottleneck may not be related to them at all.
The balance between the first two is commonly measured by arithmetic intensity, which is the number of arithmetic operations per byte of memory access required to run the computation. For example, if we apply ReLU activation to an input tensor $x$ (assuming it’s in half precision), we need to
- read 2 bytes
- make 1 comparison
- write 2 bytes
for each element of the tensor. Regardless of the size of $x$, the arithmetic intensity of ReLU is equal to $\frac{\# \operatorname{flops}}{\# \operatorname{bytes}}=\frac{1}{4}$. In other words, every arithmetic operation requires moving 4 bytes to or from memory.
Arithmetic intensity is commonly compared to a hardware-specific ops:byte ratio to determine whether we are in a compute- or memory-bound scenario. To explain how it works, let’s take a linear layer forward pass on an A100 GPU as an example. Given an input batch $x \in \mathbb{R}^{B \times d}$ and weight matrix $\mathbf{W} \in \mathbb{R}^{d \times d}$ (here $B$ is the batch size and $d$ is the embedding dimension), a linear layer essentially computes the matrix product $x\mathbf{W}$, which requires $2Bd^2$ flops.2 Hence the compute time on an A100 will be
\[T_{\operatorname{compute}} = \frac{\# \operatorname{flops}}{\operatorname{compute performance}} = \frac{2Bd^2}{312 \cdot 10^{12}} s.\]
At the same time, we need to read $2d^2$ bytes from memory to load the weight matrix $\mathbf{W}$ (again assuming we work with fp16/bf16). Also, for simplicity, let’s say that $B \ll d$, so we can neglect the loading time of $x$ compared to the weight matrix $\mathbf{W}$. Model parameters are usually stored in HBM, therefore
\[T_{\operatorname{memory}} = \frac{\# \operatorname{bytes}}{\operatorname{memory bandwidth}} = \frac{2d^2}{1.55 \cdot 10^{12}} s.\]Recall that arithmetic intensity is equal to $\frac{\# \operatorname{flops}}{\# \operatorname{bytes}}$, while the ops:byte ratio is given by $\frac{\operatorname{compute performance}}{\operatorname{memory bandwidth}}$. To find the bottleneck for our model, we look at the ratio of these two terms, which is
\[\frac{\text{arithmetic intensity}}{\text{ops:byte}} = \frac{2Bd^2 / (2d^2)}{312 \cdot 10^{12} / (1.55 \cdot 10^{12})} = \frac{T_{\operatorname{compute}}}{T_{\operatorname{memory}}} \approx \frac{B}{200}.\]This means that as long as our batch size is smaller than $200$, our system is memory-bound. Increasing the batch size beyond $200$ increases the computation time while keeping the memory transfer time constant, which brings us into the compute-bound scenario.
The ops:byte ratio analysis is useful, but keep in mind that it assumes the GPU workload is large enough to saturate the compute and memory pipelines. If the workload is not large enough, or does not have sufficient parallelism, the processor will be under-utilized and performance will be limited by latency.
High-level algorithmic optimizations
Now we are ready to delve into the specifics of transformer optimization. We’ve defined the transformer architecture in previous blog posts. Let’s briefly recall that the scaled dot-product attention operation takes a set of queries $\mathbf{Q}$, keys $\mathbf{K}$ and values $\mathbf{V}$ as input and outputs
\[\operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \operatorname{softmax} \Big( \frac{\mathbf{QK}^T}{\sqrt{d}} \Big) \cdot \mathbf{V},\]where $d$ is the hidden dimensionality of queries and keys. When we work with GPT-based models, we use masked attention, where the $\operatorname{softmax}$ input is modified with a $\text{mask}$ tensor that sets attention values to $-\infty$ for tokens we don’t want to attend to. Input tensors are $\mathbf{Q} \in \mathbb{R}^{L \times d}$ and $\mathbf{K}, \mathbf{V} \in \mathbb{R}^{M \times d}$, where $L$ and $M$ are sequence lengths3.
Let’s also recap the definition of the multi-head attention (MHA) layer:
\[\operatorname{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=[\operatorname{head}_1; \dots; \operatorname{head}_h] \cdot \mathbf{W}^O,\]where
\[\operatorname{head}_i = \operatorname{Attention}(\mathbf{QW}_i^Q, \mathbf{KW}_i^K, \mathbf{VW}_i^V), \quad i = 1, \dots, h.\]with learnable parameters $\mathbf{W}^Q_{1 \dots h}, \mathbf{W}^K_{1 \dots h}, \mathbf{W}^V_{1 \dots h}$ and $\mathbf{W}^O$. If MHA receives $\mathbf{Q} = \mathbf{K}$ (and normally $\mathbf{K} = \mathbf{V}$), we call it multi-head self-attention; otherwise it is called multi-head cross-attention. We’ll focus on the self-attention mechanism, as it’s widely used in generative LLMs.
Let’s introduce new tensor names to simplify the notation: we’ll call dot product $\mathbf{S} := \mathbf{QK}^T \in \mathbb{R}^{L \times L}$, normalized attention weights $\mathbf{P} := \operatorname{softmax}(\mathbf{S} \circ \text{mask}) \in \mathbb{R}^{L \times L}$ ($\text{mask}$ is broadcastable to $\mathbf{S}$) and output $\mathbf{O} := \mathbf{PV} \in \mathbb{R}^{L \times d}$.
KV Cache
When we work with models like GPT, text generation occurs in two stages:
- Prefill - the model ingests a large chunk of prompt tokens in parallel, computing all hidden states and outputs in one pass.
- Decoding - when prefill is finished, auto-regressive decoding is launched. Decoding is generally more time-consuming than prefill due to its sequential nature: response tokens are always generated one after another.
Representation of causal self-attention during text generation for different sequence lengths $L$. Scaling coefficient $\sqrt{d}$ is omitted. Attention values can be computed only once for the whole bunch of prompt tokens, but then sequential one-by-one computations are required to generate response tokens.
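```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def dot_product_attention(query, key, value, mask=None):
    # query: [batch..., q_seq_len, num_heads, head_dim]
    # key, value: [batch..., kv_seq_len, num_heads, head_dim]
    # mask (optional): bool, broadcastable to the logits; True = do NOT attend
    d = query.shape[-1]
    # attn_logits shape is [batch..., num_heads, q_seq_len, kv_seq_len]
    attn_logits = jnp.einsum('...lhd,...mhd->...hlm', query, key)
    attn_logits = attn_logits / jnp.sqrt(d)  # scale logits by sqrt(head_dim)
    if mask is not None:
        # most negative finite value instead of -inf avoids NaNs in fully masked rows
        big_neg = jnp.finfo(attn_logits.dtype).min
        attn_logits = jnp.where(mask, big_neg, attn_logits)
    # logits -> weights
    attention = jax.nn.softmax(attn_logits, axis=-1)
    # return weighted sum over values for each query position
    output = jnp.einsum('...hlm,...mhv->...lhv', attention, value)
    return output, attention
```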
Let’s check whether MHA inference is compute- or memory-bound.
Computational complexity:
- To compute query $\mathbf{Q}$ we multiply input matrix $x \in \mathbb{R}^{L \times d}$ with matrices $\mathbf{W}^Q_{1 \dots h} \in \mathbb{R}^{d \times \frac{d}{h}}$ across $h$ heads which takes $\mathcal{O}(Ld^2)$ operations. The same amount of compute is needed for $\mathbf{K}$ and $\mathbf{V}$.
- Attention computation requires $\mathcal{O}(L^2d)$ for both $\mathbf{S} = \mathbf{QK}^T$ and $\mathbf{O}=\mathbf{PV}$.
In total, we need to perform $\mathcal{O}(Ld^2 + L^2d)$ operations.
Memory accesses:
- Input $x$ and intermediate tensors $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$, $\mathbf{O}$ occupy $\mathcal{O}(Ld)$ bytes.
- Attention logits $\mathbf{S}$ and weights $\mathbf{P}$ take $\mathcal{O}(L^2h)$ bytes for all heads in total.
- Projection weights $\mathbf{W}^{Q,K,V}_{1 \dots h}$ require $\mathcal{O}(d^2)$ memory space.
The entire memory size to be accessed is equal to the sum of the sizes of all the tensors involved, which is $\mathcal{O}(Ld + L^2h + d^2)$ bytes. Therefore, we have arithmetic intensity proportional to
\[\frac{L^2d + Ld^2}{L^2h + Ld + d^2} \xrightarrow[L \gg d]{} \frac{d}{h}\]
Arithmetic intensity vs sequence length $L$ for multi-head self-attention with embedding dimension $d=12,288$ and number of heads $h=96$ (GPT-3 scale).
We’ve already noticed that modern GPU hardware has computational capacity orders of magnitude higher than its memory bandwidth. As the graph shows, for sufficiently large sequence lengths the arithmetic intensity stays above the embedding dimension per attention head $\frac{d}{h}$, which is usually one or a few hundred.4 Hence, the arithmetic intensity is comparable to, if not greater than, the ops:byte ratio.
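The trend in the plot can be reproduced by evaluating the asymptotic expression above directly. The sketch below drops all constant factors (so only the trend is meaningful, not the absolute values) and uses the GPT-3 sizes from the caption; the function name is illustrative.

```python
def mha_prefill_intensity(L: int, d: int, h: int) -> float:
    """Asymptotic arithmetic intensity of multi-head self-attention (constants dropped)."""
    flops = L**2 * d + L * d**2        # projections + attention
    memory = L**2 * h + L * d + d**2   # S and P for all heads, activations, weights
    return flops / memory


d, h = 12_288, 96  # GPT-3 scale
for L in (1_024, 8_192, 65_536, 1_048_576):
    print(L, round(mha_prefill_intensity(L, d, h), 1))
print(d / h)  # limit for L >> d: 128.0
```

The values decrease with $L$ but stay above $\frac{d}{h}$, approaching it from above.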
Generally, this would imply high algorithmic efficiency, but the situation is different for the second phase, text generation. The first thing to notice is that in the generation scenario there is no need to compute attention outputs for every token in the input sequence $x$, only for the last one, to decode the next token $x_{L+1}$. Thus there is no need to send the whole query matrix $\mathbf{Q}$ into the attention mechanism; its last row is enough.
The second important thing is that we can reuse previously computed activations: namely, we can cache the $\mathbf{K}$ and $\mathbf{V}$ values during the generation process, hence the name. We store the KV cache to avoid redundant computation, especially for long sequences.
Representation of causal self-attention with KV cache during the decoding phase. At each timestep, the $\mathbf{K}$ and $\mathbf{V}$ computed for the last token are added to the cache and reused in future steps.
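To make the mechanics concrete, here is a minimal JAX sketch of one decoding step for a single attention head (a simplification for brevity; real implementations batch over heads and layers and preallocate a fixed-size cache instead of concatenating). The hypothetical `decode_step` appends the newest key and value to the cache and attends with only the newest query; running it token by token should reproduce full causal attention.

```python
import jax
import jax.numpy as jnp


def decode_step(q_t, k_t, v_t, k_cache, v_cache):
    """One decoding step of single-head attention with a KV cache.

    q_t, k_t, v_t: [d] projections of the newest token.
    k_cache, v_cache: [t, d] keys / values of all previous tokens.
    Returns the attention output [d] and the caches extended to [t + 1, d].
    """
    # append the newest key/value pair to the cache
    k_cache = jnp.concatenate([k_cache, k_t[None, :]], axis=0)
    v_cache = jnp.concatenate([v_cache, v_t[None, :]], axis=0)
    # only past tokens (and the current one) are cached, so no causal mask is needed
    scores = k_cache @ q_t / jnp.sqrt(q_t.shape[-1])  # [t + 1]
    weights = jax.nn.softmax(scores)
    return weights @ v_cache, k_cache, v_cache


def causal_attention(q, k, v):
    """Reference: full causal attention over the whole sequence (prefill-style)."""
    L, d = q.shape
    scores = q @ k.T / jnp.sqrt(d)
    allowed = jnp.arange(L)[:, None] >= jnp.arange(L)[None, :]
    scores = jnp.where(allowed, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v


# Usage: decoding token by token reproduces the causal attention outputs.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
L, d = 6, 8
q = jax.random.normal(kq, (L, d))
k = jax.random.normal(kk, (L, d))
v = jax.random.normal(kv, (L, d))

k_cache, v_cache = jnp.zeros((0, d)), jnp.zeros((0, d))
outputs = []
for t in range(L):
    o, k_cache, v_cache = decode_step(q[t], k[t], v[t], k_cache, v_cache)
    outputs.append(o)
decoded = jnp.stack(outputs)
print(jnp.allclose(decoded, causal_attention(q, k, v), atol=1e-5))  # True
```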
Let’s get flops and memory accesses count for text generation at each generation step (we can simply multiply these counts by $L$ steps to get values for the whole sequence).
Compute:
- To compute the query $\mathbf{Q}$ we multiply the input vector $x \in \mathbb{R}^{d}$ with matrices $\mathbf{W}^Q_{1 \dots h} \in \mathbb{R}^{d \times \frac{d}{h}}$ across $h$ heads, which takes $\mathcal{O}(d^2)$ operations. The same holds for the new token’s $\mathbf{K}$ and $\mathbf{V}$ when we store the KV cache.
- Attention computation requires at most $\mathcal{O}(Ld)$ for both $\mathbf{S} = \mathbf{QK}^T$ and $\mathbf{O}=\mathbf{PV}$.
In total, we need to perform $\mathcal{O}(d^2 + Ld)$ operations at each step. Summed over $L$ steps, this gives the same number of operations as in the prefill stage.
Memory:
- Input $x$ and intermediate tensors $\mathbf{Q}$, $\mathbf{O}$ occupy $\mathcal{O}(d)$ bytes, but $\mathbf{K}$ and $\mathbf{V}$ in cache require $\mathcal{O}(Ld)$ space.
- Attention logits $\mathbf{S}$ and weights $\mathbf{P}$ take at most $\mathcal{O}(Lh)$ bytes across all heads.
- Projection weights $\mathbf{W}^{Q,K,V}_{1 \dots h}$ take again $\mathcal{O}(d^2)$ bytes.
Hence in total we need $\mathcal{O}(Ld + Lh + d^2)$ bytes. And finally, the arithmetic intensity is proportional to
\[\frac{Ld + d^2}{Ld + Lh + d^2} \xrightarrow[L \gg d]{} \frac{d}{d + h} < 1,\]which is definitely smaller than the ops:byte ratio, so we end up memory-bound. Although we reduced the number of operations by a factor of $L$ by removing $L-1$ queries, the number of memory accesses did not decrease as much. The reason is that at each step we retrieve all the values from the KV cache, whose size grows in proportion to the sequence length.
And that brings us to another drawback: storing the KV cache requires a lot of HBM capacity. For example, when we run decoding on a transformer with $n$ layers, we need to store $\mathcal{O}(Lnd)$ bytes of KV cache. Thus we either need to make sure that we have enough memory to accommodate it, or load it from CPU DRAM, which is one or two orders of magnitude slower than reading from HBM.
A real-world example: take an A100 GPU with $80$ GB of HBM. Say we work with the GPT-3 model (I believe it can still be considered large these days) with $n=96$ and $d=12,288$ and try to fit a context of length $L=4096$. Then the additional space we need is
\[\underset{\mathbf{K/V}}{2} \cdot \underset{\text{float16}}{2} \cdot \underset{\text{sequence length}}{4096} \cdot \underset{\text{number of layers}}{96} \cdot \underset{\text{embedding dimension}}{12,288} = \underset{\text{bytes}}{19,327,352,832}.\]So $18$ GB, or $22.5\%$ of the A100 memory, is required for the KV cache of just one sequence. Keeping in mind that most of the GPU memory would be taken by model parameters, we can conclude that even without enlarging our batch size we may quickly run out of memory.
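The same arithmetic as a small Python helper (hypothetical name; `kv_dim` is the per-token width of the cached keys or values in one layer, i.e. $d$ for multi-head attention):

```python
def kv_cache_bytes(seq_len: int, num_layers: int, kv_dim: int,
                   batch: int = 1, bytes_per_value: int = 2) -> int:
    """Bytes needed to cache keys and values for every layer.

    The leading factor 2 accounts for storing both K and V;
    bytes_per_value = 2 corresponds to fp16 / bf16.
    """
    return 2 * bytes_per_value * batch * seq_len * num_layers * kv_dim


GiB = 2**30
gpt3 = kv_cache_bytes(seq_len=4096, num_layers=96, kv_dim=12_288)
print(gpt3)               # 19327352832
print(gpt3 / GiB)         # 18.0
print(gpt3 / (80 * GiB))  # 0.225
```

Dividing by $2^{30}$ gives the 18 GB (strictly, GiB) figure quoted above; passing `kv_dim = 12_288 // 96` models a single shared key-value head of size $\frac{d}{h}$, the case discussed in the next section.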
Multi-query / Grouped-query attention
In the standard multi-head attention mechanism, every head has its own key and value projections, so for each token in the input sequence $h$ separate key-value pairs are computed and cached. However, in many cases different query heads may share similar attention patterns, and thus the corresponding keys and values could be reused across them. Multi-query attention (MQA) (Shazeer, 2019) shares a single cached key-value head across all query heads, thus substantially reducing the memory requirements associated with the KV cache during text generation.
MQA not only lowers memory consumption, but also leads to higher inference throughput. Our computational complexity doesn’t change, because from an algorithmic point of view the number of matrix multiplications stays the same and we only reuse $\mathbf{K}$ and $\mathbf{V}$ across heads. But in terms of memory, the KV cache now requires only $\mathcal{O}\big( L \frac{d}{h} \big)$ space, and the arithmetic intensity is proportional to
\[\frac{Ld + d^2}{L\frac{d}{h} + Lh + d^2} \xrightarrow[L \gg d]{} \frac{dh}{d+h^2} \approx h,\]meaning it grows steadily with increasing $L$ until it plateaus at a value of the order of $h$. We might still be memory-bound in most cases, but with the multi-query attention technique we can
- reduce the size of KV cache by a factor of $h$ and
- potentially make decoding algorithm $h$ times faster.
Arithmetic intensity vs sequence length $L$ for multi-query and multi-head self-attention mechanisms during auto-regressive generation with embedding dimension $d=12,288$ and number of heads $h=96$.
In our example above with GPT-3 model, using MQA would make KV cache $h=96$ times smaller, thus the required space would take around $200$ MB which is just $0.25\%$ of one A100 GPU memory.
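To see the difference between the two curves numerically, the sketch below evaluates the asymptotic per-step intensities of multi-head attention with a KV cache and of MQA, with constant factors dropped as in the formulas above (the helper name is hypothetical). For GPT-3 sizes the MQA plateau $\frac{dh}{d+h^2}$ comes out at roughly 55, i.e. of the order of $h$ rather than exactly $h$, because $d$ and $h^2$ are comparable here.

```python
def decode_intensity(L: int, d: int, h: int, multi_query: bool = False) -> float:
    """Asymptotic arithmetic intensity of one decoding step with a KV cache.

    Constants are dropped, as in the formulas in the text; only trends matter.
    """
    flops = L * d + d**2
    kv_bytes = L * d / h if multi_query else L * d  # MQA caches one head of size d / h
    return flops / (kv_bytes + L * h + d**2)


d, h = 12_288, 96
for L in (1_024, 16_384, 262_144):
    mha = decode_intensity(L, d, h)
    mqa = decode_intensity(L, d, h, multi_query=True)
    print(L, round(mha, 3), round(mqa, 1))
print(d * h / (d + h**2))  # MQA plateau for L >> d, about 54.9
```

The multi-head values stay below 1 for every $L$, while the MQA values keep growing towards the plateau.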
Of course, such acceleration and memory reduction come at a price: we cut model parameters and therefore the model’s potential capacity. One way to avoid quality degradation is to use a technique that interpolates between MHA and MQA: grouped-query attention (GQA) (Ainslie et al. (2023)). With GQA we split the $h$ query heads into $g$ groups, each with its own keys and values. Note that for $g=1$ GQA is equivalent to multi-query attention, and for $g=h$ it is the same as multi-head attention. The choice of $g$ is a trade-off between memory savings and potential accuracy loss. Fewer groups (i.e. more query heads sharing each key-value head) result in more memory savings but may also lead to a larger approximation error in the attention computations. In practice, the optimal number of groups may need to be determined empirically, based on the specific model architecture and the trade-off between memory efficiency and model performance.
MHA ($h=6$) vs GQA ($g=3$) vs MQA
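```python
import jax.numpy as jnp
from jax import jit

@jit
def gqa_dot_product_attention(query, key, value, mask=None):
    # query: [batch..., seq_len, num_heads, head_dim]
    # key, value: [batch..., seq_len, num_kv_heads, head_dim];
    # num_heads must be divisible by num_kv_heads
    num_heads, num_kv_heads = query.shape[-2], key.shape[-2]
    # broadcast K/V heads to match number of Q heads: each KV head is
    # repeated for the num_heads_per_kv consecutive query heads of its group
    num_heads_per_kv = num_heads // num_kv_heads
    key = jnp.repeat(key, num_heads_per_kv, axis=-2)
    value = jnp.repeat(value, num_heads_per_kv, axis=-2)
    # dot_product_attention is the function defined in the previous snippet
    return dot_product_attention(query, key, value, mask)
```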
Below is a comparison table of the complexities of batched decoding/inference algorithms for input $x \in \mathbb{R}^{B \times L \times d}$ and a large context size $L$. Note that the computational complexity is the same for all algorithms in the table! However, the real effectiveness can vary greatly depending on the setting.

| | Vanilla Attention | Attention with KV Cache | GQA with KV Cache |
|---|---|---|---|
| Compute | $\mathcal{O}(BLd^2 + BL^2d)$ | $\mathcal{O}(BLd^2 + BL^2d)$ | $\mathcal{O}(BLd^2 + BL^2d)$ |
| Memory accesses | $\mathcal{O}(BLd + BL^2h + d^2)$ | $\mathcal{O}(BL^2d + BL^2h + d^2)$ | $\mathcal{O}(BL^2dg/h + BL^2h + d^2)$ |
| Arithmetic intensity for large $L$ | $\mathcal{O}\big(\frac{d}{h}\big)$ | $\mathcal{O}\big(\frac{d}{d + h}\big)$ | $\mathcal{O}\big(\frac{dh}{dg+h^2} \big)$ |
Prefill with chunking
A large KV cache is not the only source of memory problems as the sequence length increases. During the prefill phase we compute all the outputs and $\mathbf{K}\mathbf{V}$ pairs in one pass. This requires us to compute the attention matrix $\mathbf{S} \in \mathbb{R}^{L \times L}$, whose size grows quadratically with the context length. What if the prompt is so long that we can’t fit all the attention weights in memory? We could compute them by passing tokens one by one, as we do in the decoding stage, though this procedure is much slower since it’s memory-bound. But since we know the future tokens in advance, we can feed them to the model in chunks:
- Take first $C$ tokens ($C < L$) from the prompt, run them through the prefill stage and store their KV cache values. Attention weights will be $\mathbf{S} \in \mathbb{R}^{C \times C}$.
- Then apply the same procedure for the next $C$ tokens, but now use cached $\mathbf{KV}$ pairs to attend to the tokens in a previous chunk. Attention weights then will be $\mathbf{S} \in \mathbb{R}^{C \times 2C}$.
- Repeat until the whole prompt is prefilled. The maximum size of $\mathbf{S}$ at the end will be ${C \times L}$.
Prefill with chunking with $C = 4$ and prompt length $10$.
With chunking, the maximum size of $\mathbf{S}$ is $C \times L$: it grows only linearly with $L$, multiplied by the controllable constant $C$.
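A minimal JAX sketch of chunked prefill for a single head follows. For brevity it assumes keys and values for the whole prompt are already available and slices them; in a real model each chunk’s $\mathbf{K}$ and $\mathbf{V}$ would be computed when the chunk is processed and appended to the KV cache. The function names are illustrative, and the check at the end compares the result with ordinary causal attention.

```python
import jax
import jax.numpy as jnp


def causal_attention(q, k, v):
    """Reference: causal attention with the full L x L score matrix."""
    L, d = q.shape
    scores = q @ k.T / jnp.sqrt(d)
    allowed = jnp.arange(L)[:, None] >= jnp.arange(L)[None, :]
    scores = jnp.where(allowed, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v


def chunked_prefill(q, k, v, chunk_size):
    """Same result, but the largest score matrix is chunk_size x L."""
    L, d = q.shape
    outputs = []
    for start in range(0, L, chunk_size):
        end = min(start + chunk_size, L)
        # keys/values of previous chunks (the KV cache) plus the current chunk
        k_seen, v_seen = k[:end], v[:end]
        scores = q[start:end] @ k_seen.T / jnp.sqrt(d)  # [<= C, end]
        # a query at absolute position p may attend to key positions <= p
        q_pos = jnp.arange(start, end)[:, None]
        k_pos = jnp.arange(end)[None, :]
        scores = jnp.where(k_pos <= q_pos, scores, -jnp.inf)
        outputs.append(jax.nn.softmax(scores, axis=-1) @ v_seen)
    return jnp.concatenate(outputs, axis=0)


# Usage with the setting from the figure: prompt length 10, chunk size C = 4.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(key, (10, 16)) for key in (kq, kk, kv))
out = chunked_prefill(q, k, v, chunk_size=4)
print(out.shape)                                                # (10, 16)
print(jnp.allclose(out, causal_attention(q, k, v), atol=1e-5))  # True
```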
Sliding window attention
We can see that even with multi-query attention and chunked prefill, our KV cache and attention weights still grow with the context during both the prefill and decoding phases. If we limit the number of tokens each token can attend to by some constant $L_w$, our memory requirements will no longer depend on the input sequence length. This is exactly what a technique called sliding window attention does: it changes the attention mask from a lower-triangular matrix to a band matrix.
Mask for attention weights with $L=10$ and sliding window $L_w=4.$
This makes the attention layer focus only on local context. But notice that tokens can implicitly attend to the previous $n \cdot L_w$ tokens, where $n$ is the number of layers in our transformer model, because information propagates from layer to layer. This is very similar to how the receptive field works in convolutional networks. Choosing the optimal window size involves a trade-off between memory efficiency and preserving context: a larger window preserves more context but requires more memory, while a smaller window is more memory-efficient but may lose some context.
When we use sliding window attention, we can use a rolling KV cache as well. Since the KV cache is now limited to a constant $2 \cdot L_w$ vectors per layer and only one $\mathbf{KV}$ pair changes at each step, we can evict the pair related to the oldest token, which we won’t attend to anymore, and replace it with the newest one. In practice, we keep a write pointer at the oldest pair and advance it by one after each replacement. When we reach the end of the buffer, we move it back to the first position.
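A band mask like the one in the figure can be built in a couple of lines. This sketch assumes the window $L_w$ counts the current token itself, so each query sees at most $L_w$ keys (implementations differ on this off-by-one convention); `True` marks allowed positions, and the function name is illustrative.

```python
import jax.numpy as jnp


def sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray:
    """Boolean [seq_len, seq_len] mask, True where query i may attend to key j.

    Causal band: i - window < j <= i, so each token sees itself and
    the window - 1 tokens before it.
    """
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)


mask = sliding_window_mask(10, 4)
print(mask.astype(int))
print(mask.sum(axis=-1))  # [1 2 3 4 4 4 4 4 4 4]
```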
Decoding with sliding window attention with $L_w=4$.
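The rolling cache described above is essentially a ring buffer. Below is a minimal single-head, single-layer sketch in JAX; the class name is hypothetical. Because attention over the cached pairs is a weighted sum, the order of slots inside the buffer does not matter, provided positional information (e.g. rotary embeddings) is applied to keys before they are cached, which is an assumption here.

```python
import jax.numpy as jnp


class RollingKVCache:
    """Fixed-size ring buffer with keys/values of the last `window` tokens."""

    def __init__(self, window: int, dim: int):
        self.window = window
        self.keys = jnp.zeros((window, dim))
        self.values = jnp.zeros((window, dim))
        self.write_pos = 0  # slot of the oldest entry, i.e. the next one to overwrite
        self.size = 0       # number of valid entries (< window until the buffer fills)

    def append(self, k, v):
        # overwrite the oldest pair, then advance the pointer, wrapping around at the end
        self.keys = self.keys.at[self.write_pos].set(k)
        self.values = self.values.at[self.write_pos].set(v)
        self.write_pos = (self.write_pos + 1) % self.window
        self.size = min(self.size + 1, self.window)

    def valid(self):
        # slots are filled from 0 upwards, so the first `size` slots are valid
        return self.keys[: self.size], self.values[: self.size]


# Usage: decode 6 tokens with L_w = 4; token t stores the key/value vector [t, t].
cache = RollingKVCache(window=4, dim=2)
for t in range(6):
    cache.append(jnp.full(2, float(t)), jnp.full(2, float(t)))
k, _ = cache.valid()
print(k[:, 0])  # [4. 5. 2. 3.]: tokens 2..5, with 4 and 5 written over the oldest slots
```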
Another advantage of sliding window attention is that combining it with chunking during the prefill phase not only keeps the maximum size of the attention matrix constant ($\mathbf{S} \in \mathbb{R}^{C \times L_w}$), but also reduces the number of dot products needed to compute it.
The drawback of sliding window attention is that it may lead to quality degradation, as not all interactions between tokens are captured. An interesting phenomenon was found by Xiao et al. (2024), which they called attention sink: keeping the $\mathbf{KV}$ pairs of a small number of tokens at the beginning of the sequence largely recovers the performance of window attention. They observe that LLMs assign strong attention scores to the initial tokens, which act as a “sink”, even if they are not semantically important.
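At the mask level, keeping a few sink tokens simply adds the first columns back to the sliding-window band. The sketch below is only an illustration of that idea (the function name and the `num_sink` parameter are hypothetical, and the window again counts the current token); Xiao et al. apply it to the KV cache during streaming decoding, which is not shown here.

```python
import jax.numpy as jnp


def sink_window_mask(seq_len: int, window: int, num_sink: int) -> jnp.ndarray:
    """Causal mask keeping a sliding window plus the first `num_sink` tokens."""
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    causal = j <= i
    in_window = j > i - window
    is_sink = j < num_sink  # initial tokens stay visible to every later query
    return causal & (in_window | is_sink)


mask = sink_window_mask(10, window=4, num_sink=1)
print(mask[9].astype(int))  # [1 0 0 0 0 0 1 1 1 1]
```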
Linear attention
Linear attention mechanism (Katharopoulos et al. (2020)) is an alternative family of methods to avoid $\mathcal{O}(L^2)$ scaling for long sequences. Linear attention approximates the standard attention mechanism while achieving linear time and space complexity. The key idea is to reformulate the attention operation using the associative property of matrix multiplication and kernel functions.
A kernel function $\mathcal{K}(q, k)$ can be thought of as a similarity measure between a pair of inputs $q$ and $k$, exactly like in any attention mechanism. To simplify computation, the kernel function is often chosen so that it can be represented in terms of a feature map $\phi$:
\[\mathcal{K}(q, k) = \phi(q)^T \phi(k).\]If we find such a feature map, it allows us to compute similarities between queries and keys implicitly, without explicitly computing the full attention matrix $\mathbf{QK^T}$.
In the attention mechanism, the unnormalized similarity between query embedding $\mathbf{q}$ and key embedding $\mathbf{k}$ is measured as $\mathcal{K}(\mathbf{q}, \mathbf{k}) = \exp \big( \frac{\mathbf{q}^T\mathbf{k}}{\sqrt{d}} \big)$. Each element of the softmax masked attention matrix $\mathbf{P} = \operatorname{softmax}(\mathbf{S} \circ \operatorname{mask})$, i.e. the normalized similarity between query row $\mathbf{Q}_i$ and key row $\mathbf{K}_j$ ($j \leq i$), can be represented as
\[\mathbf{P}_{ij}=\frac{\mathcal{K}(\mathbf{Q}_i, \mathbf{K}_j)}{\sum_{j \leq i} \mathcal{K}(\mathbf{Q}_i, \mathbf{K}_j)}.\]Using feature maps we can rewrite each row $i$ of dot-product attention output $\mathbf{O} = \mathbf{P} \mathbf{V}$ as
\[\begin{aligned} \mathbf{O}_{i} &= \sum_{j \leq i} \mathbf{P}_{ij} \mathbf{V}_j \\ &= \frac{\sum_{j \leq i} \mathcal{K}(\mathbf{Q}_i, \mathbf{K}_j) \cdot \mathbf{V}_{j}}{\sum_{j \leq i} \mathcal{K}(\mathbf{Q}_i, \mathbf{K}_j)} \\ &= \frac{\sum_{j \leq i} \phi(\mathbf{Q}_i)^T \phi(\mathbf{K}_j) \cdot \mathbf{V}_{j}}{\sum_{j \leq i} \phi(\mathbf{Q}_i)^T \phi(\mathbf{K}_j) } \\ &= \frac{ \phi(\mathbf{Q}_i)^T \cdot \color{Salmon}{\sum_{j \leq i} \phi(\mathbf{K}_j) \mathbf{V}^T_{j}} }{ \phi(\mathbf{Q}_i)^T \cdot \color{#007BA7}{\sum_{j \leq i} \phi(\mathbf{K}_j)}} \\ &= \frac{ \phi(\mathbf{Q}_i)^T \cdot \color{Salmon}{\mathbf{U}_i}}{ \phi(\mathbf{Q}_i)^T \cdot \color{#007BA7}{\mathbf{Z}_i} }. \end{aligned}\]The above equation is simpler to follow when the numerator is written in vectorized form as follows,
\[\big( \phi(\mathbf{Q})\phi(\mathbf{K})^T \big) \mathbf{V} = \phi(\mathbf{Q})\big( \phi(\mathbf{K})^T \mathbf{V} \big).\]Regardless of the value of $L$, we no longer need to store the quadratically growing attention matrix; we only need $\mathcal{O}(d^2)$ space for $\mathbf{U}_L = \phi(\mathbf{K})^T \mathbf{V} \in \mathbb{R}^{d \times d}$:
Neither prefill nor decoding phase with linear attention require $\mathcal{O}(L^2)$ space anymore. Scalar denominator $\phi(\mathbf{Q}_L)^T \cdot \mathbf{Z}_L$ is omitted here.
Another interesting property emerges with the introduction of feature maps: linear attention can be computed recurrently. Note that we have
\[\begin{aligned} \mathbf{U}_i &= \mathbf{U}_{i-1} + \phi ( \mathbf{K}_i ) \mathbf{V}_{i}^T, \\ \mathbf{Z}_i &= \mathbf{Z}_{i-1} + \phi (\mathbf{K}_i), \end{aligned}\]assuming $\mathbf{U}_0, \mathbf{Z}_0$ are both 0-valued.
This allows us to keep only the constant-sized hidden states $\mathbf{U}$ and $\mathbf{Z}$ to compute attention during auto-regressive decoding, so we don’t need to feed linearly increasing inputs to the model.
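The parallel and recurrent forms can be checked against each other with a short JAX sketch. It uses the $\operatorname{ELU}(x)+1$ feature map from the original linear attention paper, a single head and causal masking; the function names are illustrative. The recurrent version carries only $\mathbf{U}_i$ and $\mathbf{Z}_i$ through `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp


def feature_map(x):
    # phi(x) = ELU(x) + 1 is strictly positive, so the denominators never vanish
    return jax.nn.elu(x) + 1.0


def linear_attention_parallel(q, k, v):
    """Causal linear attention built from the explicit L x L similarity matrix (reference)."""
    fq, fk = feature_map(q), feature_map(k)
    sim = jnp.tril(fq @ fk.T)  # phi(Q_i)^T phi(K_j) for j <= i, else 0
    return (sim @ v) / sim.sum(axis=-1, keepdims=True)


def linear_attention_recurrent(q, k, v):
    """Same outputs, carrying only the constant-size states U (D x d_v) and Z (D)."""
    fq, fk = feature_map(q), feature_map(k)

    def step(state, inputs):
        U, Z = state
        fq_i, fk_i, v_i = inputs
        U = U + jnp.outer(fk_i, v_i)   # U_i = U_{i-1} + phi(K_i) V_i^T
        Z = Z + fk_i                   # Z_i = Z_{i-1} + phi(K_i)
        out = (fq_i @ U) / (fq_i @ Z)  # phi(Q_i)^T U_i / phi(Q_i)^T Z_i
        return (U, Z), out

    init = (jnp.zeros((fk.shape[-1], v.shape[-1])), jnp.zeros(fk.shape[-1]))
    _, outputs = jax.lax.scan(step, init, (fq, fk, v))
    return outputs


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(key, (7, 4)) for key in (kq, kk, kv))
print(jnp.allclose(linear_attention_recurrent(q, k, v),
                   linear_attention_parallel(q, k, v), atol=1e-4))  # True
```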
The Hedgehog & the Porcupine
Choosing a feature map $\phi$ such that
\[\phi(q) \phi(k)^T \approx \exp(qk^T)\]is a non-trivial task. Although linear attention reduces computational complexity, it may lead to a decrease in model performance if the kernel approximation doesn’t capture key properties of full attention.
The authors of the original linear attention paper used $\phi(x) = \operatorname{ELU}(x)+1$ as a feature map in their experiments. Another option is to use the standard $\operatorname{ReLU}$ function (though it sets the gradients to $0$ for negative inputs). But while such choices lead to simple and effective computations, Zhang et al. (2024) showed that, unlike softmax attention, these feature maps lose two key properties associated with higher performance:
- Low-entropy “spikiness”: intuitively, attention is expected to focus only on relevant tokens and ignore irrelevant ones.
- Dot-product monotonicity: attention weights increase as the dot products of their corresponding queries and keys increase. Without this monotonicity, training can suffer from unstable gradients.
They observed that the second-degree Taylor approximation of the exponential function, $\phi_{\text{taylor}}(\mathbf{x}) = \big[1, x_1, \dots, x_d\big] \cup \big[x_i \cdot x_j \vert i, j \in [d] \big]$ for a $d$-dimensional vector $\mathbf{x}$, retains both spikiness and monotonicity, and this corresponds to (near-)matching softmax attention performance. Unfortunately, $\phi_{\text{taylor}}(\mathbf{x})$ maps to $\mathbb{R}^{1 + d + d^2}$, resulting in $\mathcal{O}(L d^3)$ attention complexity, which becomes costly as the embedding dimension grows.
To solve this problem they propose Hedgehog5, a learnable linear layer with an exponential activation function, trained to capture these properties and mimic softmax attention weights:
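A sketch of the Taylor feature map in JAX, with one assumption worth stating: the quadratic terms are scaled by $\frac{1}{\sqrt{2}}$ so that $\phi(\mathbf{q})^T\phi(\mathbf{k}) = 1 + s + \frac{s^2}{2}$ with $s = \mathbf{q}^T\mathbf{k}$, which is exactly the second-order Taylor polynomial of $\exp(s)$; without the scaling the quadratic term would be $s^2$. The output has $1 + d + d^2$ features, which is where the extra cost comes from. The function name is illustrative.

```python
import jax.numpy as jnp


def taylor_feature_map(x):
    """Second-order Taylor feature map: [1, x, vec(x x^T) / sqrt(2)], size 1 + d + d^2."""
    ones = jnp.ones(x.shape[:-1] + (1,), dtype=x.dtype)
    # all pairwise products x_i * x_j, flattened into d^2 features
    quadratic = (x[..., :, None] * x[..., None, :]).reshape(x.shape[:-1] + (-1,))
    return jnp.concatenate([ones, x, quadratic / jnp.sqrt(2.0)], axis=-1)


q = jnp.array([0.3, -0.2, 0.5])
k = jnp.array([0.1, 0.4, -0.6])
s = q @ k
print(taylor_feature_map(q).shape)                    # (13,) = 1 + 3 + 3**2
print(taylor_feature_map(q) @ taylor_feature_map(k))  # equals 1 + s + s**2 / 2
print(1 + s + s**2 / 2, jnp.exp(s))                   # Taylor value vs exact exp(s)
```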
\[\phi_\text{mlp}(\mathbf{x}) = \exp (\mathbf{xW}).\]To learn a softmax approximation, they train $\phi_\text{mlp}(\mathbf{x})$ to minimize the cross-entropy loss between the computed linear attention weights and those that would have been computed via softmax masked attention $\mathbf{P}$:
\[\mathcal{L}_i = -\sum_{j \leq i} \mathbf{P}_{ij} \cdot \log \frac{\phi_\text{mlp}(\mathbf{Q}_i)^T\phi_\text{mlp}(\mathbf{K}_j)}{\sum_{k \leq i} \phi_\text{mlp}(\mathbf{Q}_i)^T\phi_\text{mlp}(\mathbf{K}_k)}.\]
Low-level hardware optimizations
CUDA kernel fusion
We’ve observed before that bandwidth cost (moving data from one place in memory to another) is usually the most substantial factor when we talk about performance. Let’s step back a bit and take another detailed look at the GPU hardware to understand why this is the case.
A GPU is designed to run thousands of simple operations, called threads, in parallel. Each thread has its own private memory, its registers, which are the fastest memory on the GPU. Threads are grouped into thread blocks, and all threads of a block run on the same SM at the same time. The maximum number of threads in a block is limited by the architecture; on current NVIDIA GPUs it is 1024. Threads within a thread block can load data from global memory (HBM) into shared memory (SRAM), which is used for communication between threads, perform computations, and write results back to global memory.
When an SM is given one or more thread blocks to execute, it partitions them into warps. A warp is a set of 32 threads that all execute the same instruction. Finally, multiple thread blocks are combined to form a grid. All the blocks in the same grid contain the same number of threads. Since the number of threads in a block is limited, grids are used for computations that require a large number of thread blocks to operate in parallel.
The computations are defined in kernels: small C++ functions that are executed many times in parallel by different threads (as opposed to only once, like regular C++ functions). A kernel is launched as a grid, and different kernels can have different grid and block configurations.
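The thread hierarchy can be emulated sequentially in plain Python to show how each thread finds its element. This is only a sketch, not CUDA. The function names and the block size of 256 are our own choices. In real CUDA, the global index below corresponds to blockIdx.x * blockDim.x + threadIdx.x.

```python
import math

WARP_SIZE = 32

def grid_size(n, threads_per_block):
    """Number of thread blocks needed so that every element gets a thread."""
    return math.ceil(n / threads_per_block)

def emulate_scale_kernel(x, alpha, threads_per_block=256):
    """Sequential emulation of a 1-D launch computing y[i] = alpha * x[i].

    On a GPU both loops run in parallel: blocks are distributed over SMs,
    and the threads of each block execute in warps of 32.
    """
    n = len(x)
    y = [0.0] * n
    for block_idx in range(grid_size(n, threads_per_block)):   # blockIdx.x
        for thread_idx in range(threads_per_block):            # threadIdx.x
            i = block_idx * threads_per_block + thread_idx      # global index
            if i < n:  # the last block may contain idle threads
                y[i] = alpha * x[i]
    return y

n, tpb = 1000, 256
print(grid_size(n, tpb), "blocks of", tpb // WARP_SIZE, "warps each")  # 4 blocks of 8 warps each
y = emulate_scale_kernel(list(range(n)), 2.0, tpb)
```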
Let’s take A100 as an example again. It has 108 SMs. Each SM can run up to 2048 threads and has 192KB of combined L1 cache and shared memory, of which up to 164KB can be configured as shared memory. This means that we can take large matrices (up to a few MBs), split them into smaller chunks that fit into A100 registers and SRAM, and then feed the matrix multiplication at SRAM bandwidth, about 18 TB/s. Altogether, this makes a GPU much more efficient than a CPU for matrix multiplications.
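As a quick back-of-the-envelope check, the per-SM figures quoted above give the following totals across the whole chip:

\[\begin{aligned} 108 \text{ SMs} \times 2048 \text{ threads} &= 221{,}184 \text{ resident threads}, \\ 108 \text{ SMs} \times 192 \text{ KB} &= 20{,}736 \text{ KB} \approx 20 \text{ MB of on-chip L1/shared memory}. \end{aligned}\]

This is why only matrices of a few MBs can be staged in SRAM at once, while everything larger has to stream through HBM.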
But we have a lot of non-matmul operations in deep learning, such as normalization layers, activation functions or dropout. Even though they account for only a small fraction of the total FLOPs, GPUs execute them at a much lower rate. Firstly, GPUs have specialized units for matrix multiplication, called Tensor Cores, which these operations cannot use. That’s why matmul throughput can be up to 16× higher than non-matmul throughput, e.g. 312 TFLOPs vs 19.5 TFLOPs on A100.⁶ Secondly, these operations can be extremely memory-bound.
Let’s take $\mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{L \times L}$ from attention. For a large sequence length $L$, tensor $\mathbf{S}$ has to reside on HBM, since it would be too large to fit on SRAM. The simplest implementation of softmax requires us to
- Read $L^2$ floats of $\mathbf{S}$, compute $\exp(\mathbf{S})$, write $L^2$ floats back to HBM.
- Get sum of $\exp(\mathbf{S})$ along row axis: read $L^2$, write $L$.
- Divide $\exp(\mathbf{S})$ by denominator: read $L^2+L$, write $L^2$.
In total, we need to move $5L^2+2L$ floats back and forth. It would be much faster if we could just read $L^2$ floats of $\mathbf{S}$ and write $L^2$ floats of $\mathbf{P}$. That would cut memory traffic by more than 2.5×, and for a memory-bound operation the runtime falls roughly in proportion.
That’s where kernel fusion comes into play. Instead of writing the output $y=f(x)$ to low-bandwidth global memory only to read it back to compute $z=g(y)$, we can implement a single kernel that computes $z=(g \circ f)(x)$ without extra memory accesses. The XLA compiler used by JAX performs simple fusions automatically, but a programmer can also write custom GPU kernels with Triton or Pallas.
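The sketch below shows a hand-written fused kernel: the softmax example above written as a single Pallas kernel. $\mathbf{S}$ is read once and $\mathbf{P}$ is written once, and the intermediate $\exp(\mathbf{S})$ and row sums stay on chip. The kernel already subtracts the row maximum, a stabilization trick explained in the next section.

It assumes a recent JAX version in which pallas_call accepts interpret=True, which runs the kernel on CPU for testing. On a real GPU, one would drop that flag and tile large inputs with a grid and BlockSpecs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def fused_softmax_kernel(s_ref, p_ref):
    s = s_ref[...]                                       # single read of the S tile
    e = jnp.exp(s - jnp.max(s, axis=-1, keepdims=True))  # stays on chip
    p_ref[...] = e / jnp.sum(e, axis=-1, keepdims=True)  # single write of P

@jax.jit
def fused_softmax(s):
    return pl.pallas_call(
        fused_softmax_kernel,
        out_shape=jax.ShapeDtypeStruct(s.shape, s.dtype),
        interpret=True,  # emulate the kernel on CPU; remove on GPU/TPU
    )(s)

s = jax.random.normal(jax.random.PRNGKey(0), (8, 128))
p = fused_softmax(s)
```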
Memory-efficient attention
Online softmax
Function $\mathbf{y} = \operatorname{softmax}(\mathbf{x})$ is defined as
\[\mathbf{y}_i = \frac{e^{\mathbf{x}_i}}{\sum_j e^{\mathbf{x}_j}}.\]
```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def naive_softmax(logits):
    exp_logits = jnp.exp(logits)
    return exp_logits / exp_logits.sum()
```
The naive implementation of softmax scans $\mathbf{x}$ twice: once to calculate the normalization term and once more to compute the output vector $\mathbf{y}$. Unfortunately, on real hardware such an implementation has a serious flaw: for $\mathbf{x}_i \geq 89$, exponentiation results in infinity for both bf16 and fp32. Here’s a trick to avoid overflow: notice that for any constant $m$:
\[\begin{aligned} \operatorname{softmax}(\mathbf{x})_i & = \frac{e^{\mathbf{x}_i}}{\sum_j e^{\mathbf{x}_j}} \\ & = \frac{e^{\mathbf{x}_i}}{\sum_j e^{\mathbf{x}_j}} \cdot \frac{e^{-m}}{e^{-m}}\\ & = \frac{e^{\mathbf{x}_i-m}}{\sum_j e^{\mathbf{x}_j-m}} \\ &= \operatorname{softmax}(\mathbf{x}-m)_i . \end{aligned}\]If we set $m(\mathbf{x}) = \max_i \mathbf{x}_i$ and $\ell(\mathbf{x}) = \sum_j e^{\mathbf{x}_j - m(\mathbf{x})}$ to compute
\[\mathbf{y}_i = \operatorname{softmax}(\mathbf{x}-m(\mathbf{x}))_i = \frac{e^{\mathbf{x}_i - m(\mathbf{x})}}{\ell(\mathbf{x})}\]we implement a numerically stable version of softmax, which is sometimes called safe softmax.
```python
@jit
def safe_softmax(logits):
    exp_logits = jnp.exp(logits - logits.max())
    return exp_logits / exp_logits.sum()
```
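The overflow is easy to reproduce. This sketch uses NumPy float32 rather than JAX so that it runs anywhere. The _np function names are ours and mirror the two JAX versions above.

```python
import numpy as np

def naive_softmax_np(x):
    e = np.exp(x)
    return e / e.sum()

def safe_softmax_np(x):
    e = np.exp(x - x.max())   # the largest exponent is now exp(0) = 1
    return e / e.sum()

x = np.array([1.0, 50.0, 100.0], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    # float32 max is about 3.4e38: exp(88) still fits, exp(89) overflows to inf
    print(np.exp(np.float32(88)), np.exp(np.float32(89)))
    # exp(100) is inf, so the sum is inf: values are 0, 0 and inf / inf = nan
    print(naive_softmax_np(x))
print(safe_softmax_np(x))     # finite values that sum to 1
```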
But stability comes at a price in efficiency, since we now make one more pass over $\mathbf{x}$ to calculate $m(\mathbf{x})$. This results in 4 memory accesses per vector element overall (3 loads and 1 store), and we want to improve on that.
For two vectors $\mathbf{x}^1, \mathbf{x}^2$, we can decompose the statistics of the concatenated vector $\mathbf{x}=[\mathbf{x}^1, \mathbf{x}^2]$ as
- $m(\mathbf{x}) = \max \big(m(\mathbf{x}^1), m(\mathbf{x}^2)\big)$
- $\ell(\mathbf{x}) = e^{m(\mathbf{x}^1) - m(\mathbf{x})} \ell(\mathbf{x}^1) + e^{m(\mathbf{x}^2) - m(\mathbf{x})} \ell(\mathbf{x}^2)$
Based on this property Milakov and Gimelshein (2018) presented online softmax, which calculates both $m(\mathbf{x})$ and $\ell(\mathbf{x})$ in one pass: initialize $m_0=-\infty$ and $\ell_0 = 0$, then for each iteration $i = 1, \dots, L$:
- $m_{i} \leftarrow \max(m_{i-1}, \mathbf{x}_i),$
- $\ell_{i} \leftarrow \ell_{i-1} e^{m_{i-1}-m_i} + e^{\mathbf{x}_i - m_i}.$
This algorithm keeps the maximum value
\[m_{i} = m([\mathbf{x}_{1}, \dots, \mathbf{x}_{i}])\]and the normalization term
\[\ell_{i}=\ell([\mathbf{x}_{1}, \dots, \mathbf{x}_{i}])\]as it iterates over elements of the input array. At each iteration it needs to adjust the normalizer to the new maximum $m_{i}$ and only then add new value to $\ell_{i}$.
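The recurrence above translates almost line by line into a sequential loop. This NumPy sketch (the function name is ours) computes both statistics in a single pass over $\mathbf{x}$. Producing the normalized output $\mathbf{y}$ still takes one more pass.

```python
import numpy as np

def online_softmax_stats(x):
    """One pass over x computing m = max(x) and l = sum(exp(x - m))."""
    m, l = -np.inf, 0.0
    for xi in x:
        m_new = max(m, xi)
        # rescale the old normalizer to the new maximum, then add the new term
        l = l * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    return m, l

x = np.random.default_rng(0).normal(size=1000) * 30   # large values on purpose
m, l = online_softmax_stats(x)
y = np.exp(x - m) / l                                 # second pass: normalize
```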
There also exists a parallel version to fully utilize devices:
\[\begin{bmatrix} m(\mathbf{x}) \\ \ell(\mathbf{x}) \end{bmatrix} = \begin{bmatrix} \mathbf{x}_1 \\ 1 \end{bmatrix} \oplus \begin{bmatrix} \mathbf{x}_2 \\ 1 \end{bmatrix} \oplus \cdots \oplus \begin{bmatrix} \mathbf{x}_L \\ 1 \end{bmatrix},\]where the binary operation $\oplus: \mathbb{R}^2 \times \mathbb{R}^2 \rightarrow \mathbb{R}^2$ is defined as
\[\begin{bmatrix} m_i \\ \ell_i \end{bmatrix} \oplus \begin{bmatrix} m_j \\ \ell_j \end{bmatrix} = \begin{bmatrix} \max(m_i, m_j) \\ \ell_i e^{m_i - \max(m_i, m_j)} + \ell_j e^{m_j - \max(m_i, m_j)} \end{bmatrix}.\]The operation $\oplus$ is associative and commutative, which enables parallel and efficient evaluation.
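```python
@jit
def online_softmax(logits):
    def reducer(x, y):
        m_i, l_i = x
        m_j, l_j = y
        m = jnp.maximum(m_i, m_j)
        # guard: if both maxima are -inf (identity elements), avoid exp(-inf + inf) = nan
        m_safe = jnp.where(jnp.isneginf(m), 0., m)
        l = l_i * jnp.exp(m_i - m_safe) + l_j * jnp.exp(m_j - m_safe)
        return (m, l)
    # every element x_k enters as the pair (x_k, 1); (-inf, 0) is the identity of ⊕
    m, l = jax.lax.reduce(
        (logits, jnp.ones_like(logits)),
        (-jnp.inf, 0.),
        reducer,
        (0,)
    )
    exp_logits = jnp.exp(logits - m)
    return exp_logits / l
```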
One can run this little test script to evaluate the efficiency of each implementation:
```python
# create a large random vector
logits = jax.random.uniform(jax.random.PRNGKey(42), shape=(1_000_000,))
# one warmup run for each function to compile
naive_softmax(logits)
safe_softmax(logits)
online_softmax(logits)
# %timeit is an IPython/Jupyter magic, so run this in a notebook cell
print('Naive:')
%timeit naive_softmax(logits).block_until_ready()
print('\nSafe:')
%timeit safe_softmax(logits).block_until_ready()
print('\nOnline:')
%timeit online_softmax(logits).block_until_ready()
```
This is the output of the script running on TPU-v3:
```text
Naive:
194 μs ± 15.4 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Safe:
254 μs ± 17.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Online:
199 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
Check out the original paper for more details, including an algorithm that fuses softmax with top-k.
Lazy softmax
Now let’s get back to computing attention. Given query, key and value tensors $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{L \times d}$, our aim is to compute the following (we omit masking and the $\frac{1}{\sqrt{d}}$ scaling here and hereafter for simplicity):
\[\mathbf{S} = \mathbf{Q}\mathbf{K}^T, \quad \mathbf{P} = \operatorname{softmax}({\mathbf{S}}), \quad \mathbf{O} = \mathbf{PV}.\]The straightforward implementation would be
\[\mathbf{S}_{ij} = \mathbf{Q}_i^T\mathbf{K}_j, \quad \mathbf{P}_{ij}= \frac{e^{\mathbf{S}_{ij}}}{\sum_{l=1}^L e^{\mathbf{S}_{il}}},\quad \mathbf{O}_{i}= \sum_{l=1}^L \mathbf{P}_{il} \mathbf{V}_l \quad \forall i, j = 1, \dots, L\]The problem with the implementation above is that it requires us to first compute and store $\mathbf{S}_{ij}$ for all $j$. That takes linear time and memory per query, and $\mathcal{O}(L^2)$ time and space overall. Rabe and Staats (2022) suggested moving the division by the normalization term to the very end of the attention operation, using the distributive law:
\[\mathbf{O}_{i}= \frac{ \sum_{l=1}^L \mathbf{V}_l e^{\mathbf{S}_{il}} } {\sum_{l=1}^L e^{\mathbf{S}_{il}}} \quad \forall i = 1, \dots, L.\]This implementation, called lazy softmax, can be computed with constant memory for each query: we start from vector $\mathbf{v}_0 \in \mathbb{R}^d$ and scalar $\ell_0$, both initialized with $0$, and when we process key/value pairs sequentially for $j=1, \dots, L$, we only update
\[\begin{aligned} \mathbf{v}_j &\leftarrow \mathbf{v}_{j-1} + \mathbf{V}_j e^{\mathbf{S}_{ij}}, \\ \ell_j &\leftarrow \ell_{j-1} + e^{\mathbf{S}_{ij}}. \end{aligned}\]After processing all keys and values, we divide $\frac{\mathbf{v}_L}{\ell_L}$ to get the final result.
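The update above can be written as a per-query NumPy sketch (our illustration; the function name is hypothetical). It deliberately omits a running maximum, so it inherits the overflow problem of naive softmax for large scores. Only $\mathbf{v}$ and $\ell$ are kept between steps.

```python
import numpy as np

def lazy_attention_query(q, K, V):
    """Attention output for one query using O(d) extra memory (no max trick)."""
    v = np.zeros(V.shape[1])
    l = 0.0
    for k_j, v_j in zip(K, V):       # key/value pairs processed sequentially
        w = np.exp(q @ k_j)          # unnormalized weight e^{S_ij}
        v = v + v_j * w
        l = l + w
    return v / l                     # divide once, at the very end

rng = np.random.default_rng(0)
L, d = 64, 8
Q, K, V = (rng.normal(size=(L, d)) for _ in range(3))
O = np.stack([lazy_attention_query(q, K, V) for q in Q])
```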
This algorithm has the same numerical problem as the naive implementation of softmax: the incrementally computed sums of exponentiated scores (and of exponent-weighted values) can overflow. The standard safe-softmax trick cannot be applied here, as the maximum may depend on the last score in the sequence. The subtraction cannot be delayed either, since the scores must be exponentiated before they can be added to the cumulative sum.
To resolve this problem, the authors introduce an additional scalar $m$, as in online softmax. It keeps track of the maximum score that the incremental algorithm has seen so far, and the sums of exponentiated values are renormalized as needed:
\[\begin{aligned} m_j &\leftarrow \max(m_{j-1}, \mathbf{S}_{ij}), \\ \mathbf{v}_j &\leftarrow \mathbf{v}_{j-1} e^{m_{j-1}-m_j} + \mathbf{V}_j e^{\mathbf{S}_{ij} - m_j}, \\ \ell_j &\leftarrow \ell_{j-1} e^{m_{j-1}-m_j} + e^{\mathbf{S}_{ij} - m_j}. \end{aligned}\]The authors also exploited massive parallelism and provided JAX code for a memory-efficient parallel algorithm. Notice that they use the jax.checkpoint decorator in the summarize_chunk function. The reason is that, during the forward pass, this algorithm saves memory by summarizing parts of the attention matrix sequentially, which lets it forget the parts it has already summarized. A naive application of differentiation would have to store all those intermediate results, and the algorithm would lose its memory advantage entirely. So the authors apply gradient checkpointing to the function that summarizes the individual chunks. The intermediate results can thus be forgotten during the forward pass and recomputed during backpropagation.
Applying checkpointing to the standard attention algorithm would not achieve these results. The standard attention algorithm with checkpointing would forget the attention matrix after it is formed; query chunk attention algorithm never forms the full attention matrix at all.
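The key-chunking idea and the role of jax.checkpoint fit in a few lines of JAX. This is a simplified sketch: a single head, no masking, key chunks only, and chunk summaries combined at the end. It is not the authors' code. We reuse the name summarize_chunk for readability, but its body, chunked_attention and the chunk size are our own.

```python
import jax
import jax.numpy as jnp

@jax.checkpoint  # don't keep s and p for backprop: recompute them in the backward pass
def summarize_chunk(q, k_chunk, v_chunk):
    s = q @ k_chunk.T                          # (Lq, C) scores for one key chunk
    m = s.max(axis=-1)                         # per-query maximum within the chunk
    p = jnp.exp(s - m[:, None])
    return p @ v_chunk, p.sum(axis=-1), m      # unnormalized values, l, m

def chunked_attention(q, k, v, chunk=16):
    summaries = [summarize_chunk(q, k[i:i + chunk], v[i:i + chunk])
                 for i in range(0, k.shape[0], chunk)]
    vals, ls, ms = (jnp.stack(t) for t in zip(*summaries))  # leading chunk axis
    m = ms.max(axis=0)                         # global per-query maximum
    scale = jnp.exp(ms - m)                    # rescale each chunk to the global max
    out = (vals * scale[..., None]).sum(axis=0)
    l = (ls * scale).sum(axis=0)
    return out / l[:, None]

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (32, 8))
k = jax.random.normal(kk, (64, 8))
v = jax.random.normal(kv, (64, 8))
out = chunked_attention(q, k, v)
dq = jax.grad(lambda q: chunked_attention(q, k, v).sum())(q)
```

Stacking all chunk summaries keeps the sketch short. A version with truly constant memory per query would fold them in one at a time, for example with jax.lax.scan.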
FlashAttention
FlashAttention might be the most popular implementation of the attention mechanism nowadays. It actually does more FLOPs than standard attention, yet it runs up to 3× faster just by making the attention algorithm IO-aware, i.e. by accounting for reads and writes between levels of GPU memory. Recall from the first section of this post that moving tensors from SRAM can be about 10× faster on a modern GPU than moving them from HBM.
The standard attention forward pass looks like this:
- Load $\mathbf{Q}$, $\mathbf{K}$ by blocks from HBM, compute $\mathbf{S}=\mathbf{QK^T}$, write $\mathbf{S}$ to HBM.
- Read $\mathbf{S}$ from HBM, compute $\mathbf{P} = \operatorname{softmax}(\mathbf{S})$, write $\mathbf{P}$ to HBM.
- Load $\mathbf{P}$ and $\mathbf{V}$ by blocks from HBM, compute $\mathbf{O} = \mathbf{PV}$, write $\mathbf{O}$ to HBM.
- Return $\mathbf{O}$
Dao et al. (2022) modified it with two techniques to reduce the amount of HBM accesses to sub-quadratic in $L$:
Tiling: inputs $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ are divided into blocks that fit into SRAM. Online softmax is then used to compute attention for each block, and the results are combined. With tiling, the whole attention mechanism can be implemented in one CUDA kernel:
- Load inputs from HBM
- Perform all attention computation steps: $\mathbf{QK}^T$ matrix multiply, softmax, optionally masking and dropout, $\mathbf{PV}$ matrix multiply
- And only then write the result back to HBM
Recomputation:
Training requires storing intermediate outputs, such as $\mathbf{S}, \mathbf{P} \in \mathbb{R}^{L \times L}$, to compute gradients with respect to $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ during the backward pass. The standard way to reduce the memory footprint during training is gradient checkpointing: forgetting these outputs in the forward pass and recalculating them in the backward pass. However, this trades speed for memory.
Authors propose to use selective gradient checkpointing and to store only the output $\mathbf{O}$ and the softmax normalization statistics $(m, \ell)$ to recompute the attention matrices $\mathbf{S}$ and $\mathbf{P}$ easily in the backward pass from blocks of $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ in SRAM.
FlashAttention forward pass:
- Set block sizes $B_{\mathbf{KV}} = \lceil \frac{M}{4d} \rceil$, $B_{\mathbf{Q}} = \min \big( \lceil \frac{M}{4d} \rceil, d \big)$, where $M$ is an on-chip SRAM size.
- Initialize $\color{#E86456}{\mathbf{O} \in \mathbb{R}^{L \times d}}$, $\color{#E86456}{\ell \in \mathbb{R}^{L}}$ both $0$-valued and $\color{#E86456}{m \in \mathbb{R}^L}$ with values set to $\mathbf{-\infty}$, all stored in HBM.
- Split $\color{#E86456}{\mathbf{Q}}$ into $T_{\mathbf{Q}} = \lceil \frac{L}{B_{\mathbf{Q}}} \rceil$ blocks and split $\color{#E86456}{\mathbf{K}, \mathbf{V}}$ into $T_{\mathbf{KV}} = \lceil \frac{L}{B_{\mathbf{KV}}} \rceil$ blocks along sequence axis.
- Split $\color{#E86456}{\mathbf{O}, \ell, m}$ into $T_{\mathbf{Q}}$ blocks along sequence axis.
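- For $j = 1, \dots, T_{\mathbf{KV}}$:
  - Load $\color{#65AD69}{\mathbf{K}_j, \mathbf{V}_j}$ from HBM to SRAM.
  - For $i = 1, \dots, T_{\mathbf{Q}}$:
    - Load $\color{#65AD69}{\mathbf{Q}_i, \mathbf{O}_i, \ell_i, m_i}$ from HBM to SRAM.
    - Compute unnormalized attention scores
      \[\color{#65AD69}{\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}^T_j \in \mathbb{R}^{B_{\mathbf{Q}} \times B_{\mathbf{KV}}}}.\]
    - Compute
      \[\color{#65AD69}{\begin{aligned} \tilde{m}_{ij} &= \operatorname{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_{\mathbf{Q}}}, \\ \tilde{\mathbf{P}}_{ij} &= e^{\mathbf{S}_{ij} - \tilde{m}_{ij}} \in \mathbb{R}^{B_{\mathbf{Q}} \times B_{\mathbf{KV}}}, \\ \tilde{\ell}_{ij} &= \operatorname{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_{\mathbf{Q}}}. \end{aligned}}\]
    - Renew statistics
      \[\color{#65AD69}{\begin{aligned} m_i^{\text{new}} &= \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_{\mathbf{Q}}}, \\ \ell_i^{\text{new}} &= e^{m_i - m_i^{\text{new}}} \ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\ell}_{ij} \in \mathbb{R}^{B_{\mathbf{Q}}}. \end{aligned}}\]
    - Write \(\color{#E86456}{\mathbf{O}_i} \color{#EDA137}{ \leftarrow } \color{#65AD69}{\operatorname{diag}(\ell_i^{\text{new}})^{-1} \big( \operatorname{diag}(\ell_i) e^{m_i-m_i^{\text{new}}} \mathbf{O}_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\mathbf{P}}_{ij} \mathbf{V}_j \big) }\) to HBM.
    - Write $\color{#E86456}{\ell_i} \color{#EDA137}{ \leftarrow } \color{#65AD69}{ \ell_{i}^{\text{new}}}$, $\color{#E86456}{m_i} \color{#EDA137}{ \leftarrow } \color{#65AD69}{m_{i}^{\text{new}}}$ to HBM.
- Return $\color{#E86456}{\mathbf{O}}$.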
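The forward pass above can be followed step by step in NumPy. In this sketch (our illustration, not the authors' CUDA kernel), the full arrays O, l and m stand in for HBM and the per-iteration slices stand in for SRAM tiles. The block sizes are arbitrary rather than derived from $M$, and masking and dropout are omitted.

```python
import numpy as np

def flash_attention_forward(Q, K, V, B_q=16, B_kv=16):
    """Tiled attention forward pass in the loop order of the algorithm above."""
    L, d = Q.shape
    O = np.zeros((L, V.shape[1]))
    l = np.zeros(L)
    m = np.full(L, -np.inf)
    for j in range(0, L, B_kv):                     # outer loop over K/V blocks
        K_j, V_j = K[j:j + B_kv], V[j:j + B_kv]
        for i in range(0, L, B_q):                  # inner loop over Q blocks
            rows = slice(i, i + B_q)
            S_ij = Q[rows] @ K_j.T                  # (B_q, B_kv) scores
            m_tilde = S_ij.max(axis=1)
            P_tilde = np.exp(S_ij - m_tilde[:, None])
            l_tilde = P_tilde.sum(axis=1)
            m_new = np.maximum(m[rows], m_tilde)
            l_new = np.exp(m[rows] - m_new) * l[rows] + np.exp(m_tilde - m_new) * l_tilde
            # rescale the old normalized output, add this block, renormalize
            O[rows] = (
                (l[rows] * np.exp(m[rows] - m_new))[:, None] * O[rows]
                + np.exp(m_tilde - m_new)[:, None] * (P_tilde @ V_j)
            ) / l_new[:, None]
            l[rows], m[rows] = l_new, m_new
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(50, 8)) for _ in range(3))
O = flash_attention_forward(Q, K, V)
```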
In the end, the FlashAttention algorithm returns $\mathbf{O} = \operatorname{softmax}(\mathbf{QK}^T)\mathbf{V}$ with $\mathcal{O}(L^2 d)$ FLOPs and requires $\mathcal{O}(L)$ additional memory beyond inputs and output. In terms of memory accesses, it requires $\mathcal{O}(L^2d^2M^{-1})$ HBM accesses, where $d \leq M \leq Ld$, compared to $\mathcal{O}(Ld + L^2)$ for standard attention.
Schematic diagram of how FlashAttention forward pass is performed, when $\mathbf{Q}$ is partitioned into $T_{\mathbf{Q}}=1$ block of size $B_{\mathbf{Q}} \times d$ with $B_{\mathbf{Q}}=3$ and $\mathbf{K}/\mathbf{V}$ are partitioned into $T_{\mathbf{KV}} = 2$ blocks of size $B_{\mathbf{KV}} \times d$ with $B_{\mathbf{KV}}=2$ each. Here $\ell_1=\sum e^{\mathbf{S}_{11}}$, $\ell_2=\ell_1 + \sum e^{\mathbf{S}_{12}}$. The step with subtracting $m$ in softmax is omitted for simplification.
The authors of FlashAttention also compared it to the query chunk attention algorithm (memory-efficient attention), stating three major differences:
- Memory-efficient attention focuses on memory footprint, while FlashAttention focuses on reducing HBM accesses.
- Memory-efficient attention summarizes each block with its temporary output along with the softmax normalization statistics. FlashAttention instead incrementally updates the output after processing each block, so only one copy of the output is needed.
- Memory-efficient attention uses gradient checkpointing to recompute the attention matrix and the temporary output of each block. FlashAttention only recomputes the attention matrix and does not recompute the temporary output of each block.
“FlashAttention” is an amazingly well-written paper, and it’s definitely worth reading for anyone who’s training large language models. The paper has more details about the FlashAttention backward pass, theoretical proofs and comparisons to other optimization algorithms.
FlashAttention + Parallelism
FlashAttention significantly speeds up attention computation and reduces memory usage from quadratic to linear in sequence length. While it works for most cases, it’s not optimized for the case of long sequences with small batch size and/or small number of attention heads, due to insufficient parallelism.
The first version of the FlashAttention kernel uses one thread block per attention head, leading to $Bh$ thread blocks overall ($B$ is the batch size, $h$ is the number of attention heads). Each thread block is scheduled to run on an SM. This scheduling keeps all compute resources busy only when $Bh$ is at least as large as the number of SMs (e.g. 108 SMs for an A100 GPU). If one trains an LLM with modern parallelism techniques, the per-device batch size shrinks by the data-parallel (DP) degree and the per-device number of heads by the tensor-parallel (TP) degree.
Tri Dao, author of the original FlashAttention, applied additional parallelization along sequence axis to make better use of the multiprocessors on the GPU. In such regime in forward pass one attention head is now processed by multiple thread blocks and each block takes care of its own segment of rows of the attention matrix. As the rows of the attention matrix don’t depend on each other, we don’t need to communicate between the blocks.
In the backward pass each thread block now takes care of a segment of columns of the attention matrix. Parallelizing by columns is faster than parallelizing by rows due to the reduced communication between the workers (parallelizing by columns requires aggregating the gradient of the query, while parallelizing by rows requires aggregating the gradient of the key and value).
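Counting thread blocks shows why this matters. The sketch below uses a hypothetical long-context setting: 1 sequence and 16 heads per device after parallelism, 16K tokens, and 128 query rows per thread block. The rows-per-block value is our assumption, not a FlashAttention constant.

```python
import math

NUM_SMS = 108  # A100

def blocks_v1(batch, heads):
    """First FlashAttention forward pass: one thread block per (sequence, head)."""
    return batch * heads

def blocks_seq_parallel(batch, heads, seq_len, rows_per_block=128):
    """Additional split along the sequence: one block per row segment of each head."""
    return batch * heads * math.ceil(seq_len / rows_per_block)

b, h, L = 1, 16, 16_384
print(blocks_v1(b, h), "blocks for", NUM_SMS, "SMs")               # 16 blocks: most SMs idle
print(blocks_seq_parallel(b, h, L), "blocks for", NUM_SMS, "SMs")  # 2048 blocks
```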
FlashAttention-2
A new version, FlashAttention-2 (Tri Dao, 2023), included this parallelization along the sequence axis across thread blocks to increase occupancy for long sequences. Besides that, it reduced two quantities:
- the number of non-matmul FLOPs,
- the amount of communication through shared memory (SRAM).
The first tweak is to rewrite the online softmax trick to reduce the number of rescaling operations, bound checks and causal-masking operations, without changing the output. Remember that matmul throughput can be up to 16× higher than non-matmul throughput.
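To illustrate the first tweak, here is a NumPy sketch of the inner loop for one query block in the FlashAttention-2 style. It is our simplified reading of the idea; masking, bound checks and the backward pass are omitted. The running output stays unnormalized and is only rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$. The division by $\ell$ happens once at the end, instead of applying $\operatorname{diag}(\ell_i^{\text{new}})^{-1}$ at every step as in the original FlashAttention update.

```python
import numpy as np

def fa2_style_query_block(Q_i, K, V, B_kv=16):
    """Online-softmax loop over K/V blocks with a single final normalization."""
    O = np.zeros((Q_i.shape[0], V.shape[1]))
    l = np.zeros(Q_i.shape[0])
    m = np.full(Q_i.shape[0], -np.inf)
    for j in range(0, K.shape[0], B_kv):
        S = Q_i @ K[j:j + B_kv].T
        m_new = np.maximum(m, S.max(axis=1))
        P = np.exp(S - m_new[:, None])
        alpha = np.exp(m - m_new)                   # rescale factor for old statistics
        l = alpha * l + P.sum(axis=1)
        O = alpha[:, None] * O + P @ V[j:j + B_kv]  # no division inside the loop
        m = m_new
    return O / l[:, None]                           # one normalization at the end

rng = np.random.default_rng(1)
Q_i = rng.normal(size=(16, 8))
K, V = rng.normal(size=(70, 8)), rng.normal(size=(70, 8))
O_i = fa2_style_query_block(Q_i, K, V)
```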
The second tweak is better work partitioning between warps (groups of 32 threads working together). In the first FlashAttention, $\mathbf{K}$ and $\mathbf{V}$ were divided across 4 or 8 warps per thread block, while $\mathbf{Q}$ was accessible by all of them. Each warp multiplies to get a slice of $\mathbf{QK}^T$, then multiplies it with its slice of $\mathbf{V}$, and the warps must communicate to add up the results. This is inefficient, since all warps need to write their intermediate results out to shared memory, synchronize, and then add them up.
In FlashAttention-2, $\mathbf{Q}$ is instead divided across 4 warps, while $\mathbf{K}$ and $\mathbf{V}$ are accessible by all warps. After each warp performs a matrix multiply to get a slice of $\mathbf{QK}^T$, it just needs to multiply that slice with the shared $\mathbf{V}$ to get its corresponding slice of the output. There is no need for communication between warps, and the reduction in SRAM reads/writes yields a speedup.
The second version also introduced support for larger head dimensions ($d = 128 \rightarrow 256$) and support for MQA and GQA.
FlashAttention-3
The most recent version, FlashAttention-3 (2024), focuses on optimizations specific to the H100 GPU, as previous versions only achieved up to 35% utilization on the Hopper architecture. The authors exploited new hardware features such as the Tensor Memory Accelerator (TMA), a dedicated unit that asynchronously generates addresses and fetches data. The main techniques used for the speedup are:
- Overlap overall computation and data movement
- Interleave block-wise matmul and softmax operations
- Block quantization and incoherent processing that leverages hardware support for FP8 low-precision
Ring Attention
Even with FlashAttention, memory complexity is linear in $L$, so scaling the sequence length is limited by memory capacity. We could scale the context with the number of devices $N$: split the inputs into $N$ parts, perform the computations in parallel, then gather the results. However, attention requires $\mathbf{Q}$ to access all elements of the $\mathbf{K}, \mathbf{V}$ matrices.⁷ Sending large matrices between devices can add a huge communication overhead (e.g. A100 throughput is 600GB/s with NVLink and only 64GB/s with PCIe).
Ring Attention (Liu et al. (2023)) addresses this problem by hiding communication overhead behind computation, targeting extremely long contexts. The algorithm is the following:
- Split input sequence into $N$ blocks along sequence axis, such that each device stores one input block of size $C = \lceil \frac{L}{N} \rceil$. Compute $\mathbf{Q}_i$, $\mathbf{K}_i$, $\mathbf{V}_i$ for its input block on each $i$-th device.
- For $\text{iter} = 0, \dots, N - 1$:
  - For each $i$-th device in parallel:
    - Let $j = (i - \text{iter}) \bmod N$.
    - Compute memory-efficient attention incrementally using local $\mathbf{Q}_i$, $\mathbf{K}_j$, $\mathbf{V}_j$ blocks.
    - Simultaneously, send $\mathbf{K}_j, \mathbf{V}_j$ blocks to the next device and receive new blocks from the previous device.
- For each $i$-th device in parallel:
GPUs are arranged in a ring, each holding a portion of $\mathbf{Q}$. During the Ring Attention process, the GPUs pass along blocks of $\mathbf{K}, \mathbf{V}$ to each other.
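The loop can be simulated on a single machine. In this NumPy sketch (our illustration; the function name and structure are hypothetical), "devices" are list entries and the send/receive step is a list rotation. Each device accumulates its output with the same online-softmax update as memory-efficient attention. As a final per-device step, it divides by the normalizer after the last iteration; this step is our assumption. No causal mask is applied.

```python
import numpy as np

def ring_attention(Q, K, V, N):
    """Simulate Ring Attention on N devices; device i owns query block i."""
    Qb, Kb, Vb = np.array_split(Q, N), np.array_split(K, N), np.array_split(V, N)
    d_v = V.shape[1]
    state = [(np.zeros((len(q), d_v)), np.zeros(len(q)), np.full(len(q), -np.inf))
             for q in Qb]                        # per-device (o, l, m)
    kv = list(zip(Kb, Vb))                       # kv[i] = block currently on device i
    for t in range(N):
        for i in range(N):                       # "in parallel" on every device
            k, v = kv[i]                         # block (i - t) mod N
            o, l, m = state[i]
            s = Qb[i] @ k.T
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            scale = np.exp(m - m_new)
            state[i] = (scale[:, None] * o + p @ v, scale * l + p.sum(axis=1), m_new)
        kv = [kv[(i - 1) % N] for i in range(N)]  # send to next, receive from previous
    return np.concatenate([o / l[:, None] for o, l, _ in state])

rng = np.random.default_rng(2)
Q, K, V = (rng.normal(size=(50, 8)) for _ in range(3))
O = ring_attention(Q, K, V, N=4)
```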
Local attention at each iteration requires $\mathcal{O}(C^2d)$ operations, while each device needs to send $\mathbf{K}_j, \mathbf{V}_j \in \mathbb{R}^{C \times d}$ tensors, i.e. $\mathcal{O}(Cd)$ bytes. To hide the communication overhead behind computation, the block size must therefore satisfy
\[C = \big \lceil \frac{L}{N} \big \rceil \geq \frac{\text{compute performance}}{\text{communication bandwidth}}.\]The authors also shared implementation code in JAX.
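Plugging in numbers helps. Assume bf16 $\mathbf{K}, \mathbf{V}$ (2 bytes per element). Each iteration then sends $2 \cdot Cd \cdot 2 = 4Cd$ bytes and computes about $2C^2d$ FLOPs for $\mathbf{Q}_i\mathbf{K}_j^T$ plus $2C^2d$ for the product with $\mathbf{V}_j$. The constants cancel, and the condition becomes $C \geq F/B$ for peak compute $F$ (FLOP/s) and bandwidth $B$ (bytes/s). The sketch below uses the ideal peak A100 figures quoted above; real thresholds will differ.

```python
def min_block_size(flops_per_s, bytes_per_s):
    """Smallest per-device block C (tokens) for which computing 4*C^2*d FLOPs
    takes at least as long as sending 4*C*d bytes of bf16 K/V: C >= F / B."""
    return flops_per_s / bytes_per_s

A100_BF16 = 312e12                        # FLOP/s with Tensor Cores
print(min_block_size(A100_BF16, 600e9))   # NVLink: 520 tokens per device
print(min_block_size(A100_BF16, 64e9))    # PCIe: 4875 tokens per device
```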
Striped Attention
Brandon et al. (2023) studied Ring Attention performance for causal transformers and found that it leads to a highly imbalanced workload due to the triangular structure of causal attention. They propose a simple extension called Striped Attention to fix this, achieving a speedup of nearly 1.5×.
The problem with Ring Attention is that on all but the first iteration, the workload of some devices is entirely necessary (unmasked) for the final output. The workload of others is entirely unnecessary (masked), so it need not be computed at all. The total latency is determined by the maximum latency of any participating device per iteration. As a result, regardless of per-device optimizations, the latency per iteration equals the time taken to compute a fully unmasked workload. Ring Attention therefore runs only as fast as a workload with no attention masking, even though in principle it needs to compute only half the operations.
Workload distribution of Ring Attention (left) and Striped Attention (right). The square matrices represent the set of all possible pairwise query/key interactions; row indices correspond to queries, and column indices correspond to keys. All cells above the diagonal are masked out by the causal mask and can be skipped. The colors indicate which devices are responsible for which parts of the computation. We can see that some devices in Ring Attention are responsible for workloads which are entirely masked out, whereas Striped Attention maintains a balanced workload across all devices.
Rather than partitioning the tokens into contiguous blocks as in Ring Attention, Striped Attention partitions them into sets of evenly spaced stripes based on their residues modulo $N$, so that the $i$-th token resides on device $i \bmod N$. In practice, we can achieve this partitioning scheme by permuting the input tokens before the model’s first embedding layer, and then partitioning the permuted sequence into contiguous blocks as in Ring Attention. After partitioning, the Striped and Ring Attention algorithms proceed almost identically.
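The imbalance can be checked numerically. The NumPy sketch below (our illustration; helper names are hypothetical) counts the unmasked query/key pairs that each device processes per ring iteration. It charges each iteration with its busiest device and compares contiguous (Ring) and striped partitions of the same causal sequence.

```python
import numpy as np

def striped_permutation(L, N):
    """Token order after permutation: device r gets tokens r, r+N, r+2N, ..."""
    return np.concatenate([np.arange(r, L, N) for r in range(N)])

def causal_latency(positions, N):
    """Sum over ring iterations of the busiest device's unmasked (q, k) pairs.

    positions: original token index stored at each slot of the sequence.
    Device i owns block i and, at iteration t, sees K/V block (i - t) mod N.
    """
    blocks = np.array_split(positions, N)
    total = 0
    for t in range(N):
        work = [np.sum(blocks[i][:, None] >= blocks[(i - t) % N][None, :])
                for i in range(N)]
        total += max(work)   # an iteration lasts as long as its slowest device
    return total

L, N = 8, 4
print(striped_permutation(L, N))                     # [0 4 1 5 2 6 3 7]
print(causal_latency(np.arange(L), N))               # Ring: 15
print(causal_latency(striped_permutation(L, N), N))  # Striped: 12
```

With larger blocks the gap grows toward 2×, because every Ring iteration after the first costs a full $C^2$ block, while striped devices do roughly $C^2/2$ work each.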
Another recent improvement, called Tree Attention (Shyam et al. 2024), leverages a tree-reduction topology to reduce communication costs for decoding across multiple GPUs. It is asymptotically faster than Ring Attention and achieved up to 8× speedup in the authors’ experiments.
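A tree works because the $\oplus$ operator from the online softmax section is associative, so per-device partial results can be merged pairwise. Below is a minimal single-process NumPy sketch for decoding a single query; the shard count and helper names are our assumptions. It illustrates the reduction pattern only, not the authors' communication implementation.

```python
import numpy as np

def partial_attention(q, K, V):
    """Per-device summary (m, l, o) for one query over a local K/V shard."""
    s = K @ q
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V              # o is unnormalized

def combine(a, b):
    """The associative ⊕ from online softmax, extended with the output vector o."""
    (m_a, l_a, o_a), (m_b, l_b, o_b) = a, b
    m = max(m_a, m_b)
    sa, sb = np.exp(m_a - m), np.exp(m_b - m)
    return m, sa * l_a + sb * l_b, sa * o_a + sb * o_b

def tree_reduce(parts):
    """Pairwise merging: ceil(log2(N)) rounds instead of N - 1 sequential steps."""
    rounds = 0
    while len(parts) > 1:
        parts = [combine(parts[k], parts[k + 1]) if k + 1 < len(parts) else parts[k]
                 for k in range(0, len(parts), 2)]
        rounds += 1
    return parts[0], rounds

rng = np.random.default_rng(3)
L, d, N = 64, 8, 8
q, K, V = rng.normal(size=d), rng.normal(size=(L, d)), rng.normal(size=(L, d))
shards = zip(np.array_split(K, N), np.array_split(V, N))
(m, l, o), rounds = tree_reduce([partial_attention(q, k_s, v_s) for k_s, v_s in shards])
out = o / l
print(rounds)  # 3 rounds for 8 devices
```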
PagedAttention / vLLM
KV cache takes up a significant amount of memory. In a straightforward implementation we can store the KV cache of a request in contiguous memory space like all the other tensors. However, unlike the tensors in the traditional deep learning workloads, KV cache dynamically grows and shrinks over time as the model generates new tokens, and its lifetime and length are not known a priori.
To store the KV cache of a request in contiguous space, we have to pre-allocate a contiguous chunk of memory with the request’s maximum length, while the request’s actual length can be much shorter. Another memory inefficiency can be observed when we use advanced decoding techniques such as beam search or parallel sampling to generate multiple outputs per request. In these scenarios, the request consists of multiple sequences that can partially share their KV cache. However, memory sharing is not possible when KV cache is stored in separate contiguous spaces.
To address the above limitations, Kwon et al., 2023 propose PagedAttention, an algorithm inspired by the OS solution to memory fragmentation and sharing: virtual memory with paging. PagedAttention divides the request’s KV cache into blocks, each of which can contain the attention keys and values of a fixed number of tokens. In PagedAttention, the blocks for the KV cache are not necessarily stored in contiguous space and we can manage the KV cache in a more flexible way.
The key idea is to represent the KV cache as a series of logical KV blocks, filled from left to right as new tokens are generated. These logical blocks are mapped to non-contiguous physical KV blocks in GPU memory, allocated by a block engine. The KV block manager maintains block tables, the mapping between the logical and physical KV blocks of each request. Separating logical and physical KV blocks allows vLLM to grow the KV cache memory dynamically without reserving it for all positions in advance.
Storing the KV cache of two requests with PagedAttention. The logical blocks of the two sequences are mapped to different physical blocks within the space reserved by the block engine in GPU workers. The block manager maps the first three logical blocks of request A to physical blocks 7, 1 and 3, and the first two logical blocks of request B to physical blocks 5 and 2.
In parallel sampling, multiple responses share the same input prompt, so the KV cache of the prompt can be shared as well. The logical blocks for the prompts of all sequences are mapped to the same physical blocks, each of which stores a reference count. During generation, a sequence may need to write into a shared physical block (reference count greater than 1). In that case the block is copied (copy-on-write), the reference count of the original is decremented, and sampling continues with separate blocks. By sharing physical KV blocks across multiple samples, memory usage can be greatly reduced, especially for long input prompts.
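The block-table bookkeeping described above, including reference counts and copy-on-write for parallel sampling, can be sketched as a toy Python class. Everything here is hypothetical and only illustrates the idea: the class and method names, first-free allocation and the block size of 4. It does not mirror vLLM's actual code, and copying the K/V contents of a block is omitted.

```python
class BlockManager:
    """Toy logical-to-physical KV block mapping with reference counts."""

    def __init__(self, num_physical_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_physical_blocks))
        self.ref_count = [0] * num_physical_blocks
        self.tables = {}    # sequence id -> list of physical block ids (block table)
        self.lengths = {}   # sequence id -> number of tokens stored

    def _allocate(self):
        block = self.free.pop(0)
        self.ref_count[block] = 1
        return block

    def append_token(self, seq):
        """Reserve a slot for one new token; returns (physical_block, offset)."""
        table = self.tables.setdefault(seq, [])
        n = self.lengths.get(seq, 0)
        offset = n % self.block_size
        if offset == 0:                        # last logical block is full: grow
            table.append(self._allocate())
        elif self.ref_count[table[-1]] > 1:    # writing into a shared block: copy on write
            self.ref_count[table[-1]] -= 1
            table[-1] = self._allocate()       # (copying the stored K/V is omitted)
        self.lengths[seq] = n + 1
        return table[-1], offset

    def fork(self, parent, child):
        """Parallel sampling: the child shares all of the parent's blocks."""
        self.tables[child] = list(self.tables[parent])
        self.lengths[child] = self.lengths[parent]
        for block in self.tables[child]:
            self.ref_count[block] += 1

    def free_sequence(self, seq):
        for block in self.tables.pop(seq):
            self.ref_count[block] -= 1
            if self.ref_count[block] == 0:
                self.free.append(block)
        del self.lengths[seq]

mgr = BlockManager(num_physical_blocks=8, block_size=4)
for _ in range(6):              # a 6-token prompt fills 2 blocks (4 + 2 tokens)
    mgr.append_token("A")
mgr.fork("A", "B")              # the second sample shares the prompt blocks
print(mgr.tables)               # {'A': [0, 1], 'B': [0, 1]}
mgr.append_token("B")           # B writes into shared block 1: copy on write
print(mgr.tables)               # {'A': [0, 1], 'B': [0, 2]}
```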
Beam search decoding with PagedAttention is more involved, since blocks can also be shared across different candidates, and the sharing pattern changes dynamically as decoding progresses. As candidates drop out of the top beams, their logical blocks are freed and the reference counts of the corresponding physical blocks are decremented.
In addition to the algorithm, the authors released vLLM, a high-throughput distributed LLM serving engine built on top of PagedAttention. Their evaluations on various models and workloads show that vLLM improves LLM serving throughput by 2-4× compared to other existing serving systems, without affecting model accuracy.
Conclusion
This post has covered a wide range of transformer inference optimization techniques. They range from high-level algorithmic improvements like KV caching and MQA/GQA to low-level optimizations such as CUDA kernel fusion, IO-aware attention kernels and paged KV cache management in vLLM. The key takeaway is clear: LLMs are resource-intensive beasts, and taming them requires a diverse toolkit.
We’ve seen that successful ML engineering isn’t just about understanding algorithms or hardware - it’s about grasping how they work together. An algorithmic breakthrough might fall flat without the right hardware implementation, while a deep dive into GPU architecture could unlock unexpected performance gains.
For ML engineers, the message is simple: stay curious and be ready to learn across the entire stack. The most effective optimizations often come from connecting the dots between different areas of expertise. As LLMs continue to grow and evolve, so too will the tricks we use to run them efficiently.
1. Full names given to GPU architectures are: Turing, Volta, Ampere, Hopper, Blackwell.
2. In general, the number of FLOPs for a matrix-matrix multiplication $\mathbf{AB}$ with $\mathbf{A} \in \mathbb{R}^{n \times m}$ and $\mathbf{B} \in \mathbb{R}^{m \times k}$ is near $2mnk$. We perform $k$ matrix-vector multiplications (one per column of $\mathbf{B}$). Each of them can be represented as $n$ inner products of length $m$, and each inner product requires $m-1$ additions and $m$ multiplications.
3. The dimension of values $\mathbf{V}$ can differ from $d$, but usually it doesn’t.
4. A reasonable question might be: “What is the best way to utilize a GPU to generate short sequences, e.g. $L \ll d$?” A possible solution is to increase the batch size. One can compute that for $x \in \mathbb{R}^{B \times L \times d}$ the arithmetic intensity is \(\frac{BL^2d + BLd^2}{BL^2h + BLd + d^2} \xrightarrow[d \gg L]{} BL\).
5. While it is clear that hedgehog comes from modelling attention “spikiness”, I still wonder what the porcupine in the title refers to.
6. This isn’t unique to GPUs: in fact, TPUs are even less general than GPUs.
7. If we don’t use SWA.