Month 4 · Kernel & 性能
下到 CUDA 层。理解为什么 FlashAttention 这种"重写 attention"的东西能快 3 倍。
Month 1-3 一直在 Python 层。这个月下到 csrc/ 和 Triton。
目标不是能徒手写 CUDA 算子,而是能读懂、能调试、能给已有 kernel 加 benchmark。
这一层最难补,但回报是面试时立刻显出底子——多数候选人在这里说不出三句话,你能讲清 memory-bound、roofline、online softmax、CUDA graph、block table indirection,就已经在 top 10%。
softmax(QK^T / sqrt(d))V 谁都会写。
PyTorch 一行 F.scaled_dot_product_attention 就能跑。
那 FlashAttention 凭什么快 3 倍?它没改公式,到底改了什么?
附加题:为什么 vLLM 在 prefill 用 FlashAttention、在 decode 又写了自己的 PagedAttention kernel?
先猜,再展开答案
关键洞察:现代 GPU 的瓶颈不是算力,是显存带宽。
朴素 attention 要把 N×N 的 attention matrix 实例化在 HBM(GPU 主显存)。 对 N=4096,这个矩阵 16M 个 float,一来一回搬 64MB 数据。 SM (compute unit) 大部分时间在等显存,不在算。
FlashAttention 的做法:
- 把 Q、K、V 切成 tile,每次只把一对 tile 加载到 SRAM(GPU 内 L1 级缓存,几十 KB,巨快)。
- 在 SRAM 里算完这对 tile 的部分 attention,累加到输出。
- 用"online softmax"技巧让 tile-by-tile 计算等价于一次性算。
- 从头到尾 attention matrix 不实例化,只有最终 O 矩阵搬回 HBM。
本质 = IO-aware,工艺级地少搬数据。算力没省,但显存搬运省了一个数量级。
而 PagedAttention 多解了一个问题:KV 不再是连续显存,是一堆物理块。kernel 得能从 block table 聚集(gather)KV,FlashAttention 默认不支持。所以 vLLM 给 decode 写了专门的 paged kernel;prefill 阶段一个请求的 KV 还是连续的,能直接用 FlashAttention。
01核心论断:现代推理是 memory-bound
这一节是这章的"立论"。后面所有优化都是这个论断的必然推论。先把它建结实,后面读 kernel 不会迷路。
Arithmetic intensity 的定义
给定一个算子(比如 matmul),有两个可以测的量:
- FLOPs:算这个算子要做多少次浮点运算。
- Bytes loaded:从 HBM 读多少字节进 SM。
把它们的比值叫做 arithmetic intensity (AI):
AI = FLOPs / Bytes loaded 单位:FLOPs/byte
这个比值告诉你:每搬一字节数据,能"摊"多少次浮点运算。AI 越高,HBM 带宽就越能跟上算力——算力被用满;AI 越低,HBM 就跟不上——算力闲着干等。
A100 的"分水岭"
NVIDIA A100 (40GB SXM4) 的纸面参数:
| 指标 | A100 SXM4 | H100 SXM5 |
|---|---|---|
| 峰值 FP16 Tensor Core 算力 | 312 TFLOPs | ~989 TFLOPs (稠密) |
| HBM 容量 | 40 / 80 GB | 80 GB (HBM3) |
| HBM 带宽 | ~1.55 TB/s (40GB) · ~2.0 TB/s (80GB) | ~3.35 TB/s |
| L2 cache | 40 MB | 50 MB |
| 每 SM SRAM (shared memory + L1) | 192 KB | 228 KB |
| SM 数量 | 108 | 132 |
| NVLink 带宽 | 600 GB/s | 900 GB/s |
数据来自 NVIDIA A100/H100 数据手册。不同版本(PCIe / SXM)和不同 fp 精度(fp16 / bf16 / fp8)数字会有差异;本表用 SXM 版 + FP16 稠密 Tensor Core 作为代表数。
把 A100 (80GB) 的两个关键数字摆在一起:
峰值算力 = 312e12 FLOPs/s
HBM 带宽 = 2.0e12 bytes/s
分水岭 AI* = 算力 / 带宽 = 312 / 2 = 156 FLOPs/byte
如果 AI > 156:compute-bound,被算力卡住
如果 AI < 156:memory-bound,被带宽卡住
这就是roofline 分析的起点。同一段代码,部署在 A100 还是 H100,分水岭都不同。但结论形状一样:每张卡都有一个"压力点 AI*",跨过去叫 compute-bound,没跨过去叫 memory-bound。
Attention 的 AI 在 decode 阶段是 ~1
这是这门课要你背下来的一个数字。decode 阶段每生成一个 token:
- Q:1 个新 token 的 query,shape
[1, d]。 - K, V:迄今为止整个序列的 cache,shape
[N, d]。 - 计算
q @ K.T:N · d次乘加 ≈2NdFLOPs。 - softmax · @ V:再来一次
2NdFLOPs,总共~4Nd。 - 但要从 HBM 读 K + V,bytes =
2 · N · d · 2(fp16,每 element 2 bytes)=4Nd。
AI_decode = FLOPs / Bytes = 4Nd / 4Nd = 1 FLOP/byte
对比分水岭 156——decode attention 在 roofline 的极左下角,深度 memory-bound。说人话:你给 SM 多少算力都没用,HBM 跟不上,SM 大部分时间在等数据。
Prefill 的 AI 在 ~seq_len/8
prefill 时 Q 也是 [N, d](整段 prompt),所以矩阵乘是 N×N×d,FLOPs ≈ 2 N² d。bytes 读的还是 K+V 一份,~4 N d(如果能一次性留在 SRAM)。
AI_prefill ≈ 2 N² d / (4 N d) = N / 2 (理想)
实际算 + softmax 的开销,常被估为 N / 8 量级
对 N = 4096,AI_prefill ≈ 500——已经压上 compute roof。prefill 是 compute-bound 的。
这意味着 kernel 设计的两条铁律
- 能不读 HBM 就不读。能留在 SRAM 的就在 SRAM 完成。
- 能合并读就合并读。能一个 kernel 算完的事情,不要拆成两个 kernel(每次 HBM 来回都是浪费)。这就是"kernel fusion"。
FlashAttention 是这两条的极致执行。后面会逐步展开。
02GPU 内存层级 · 详图
这是最基础的硬件背景。看完这张图,FlashAttention 才能 click。每往下一层,容量大 10×、带宽小 10×、延迟高 10-100×。
关键数字:SRAM 比 HBM 快 10× 左右
把同一份 attention tile 留在 SRAM 重用 k 次,理论上能比"每次都从 HBM 读"快 ≈ min(k, 10) 倍——这就是 FlashAttention 速度上限的来源。
A100 vs H100:差距不只是 FLOPs
| 层级 | A100 80GB | H100 80GB | 差距来源 |
|---|---|---|---|
| FP16 Tensor Core | 312 TFLOPs | ~989 TFLOPs | Hopper 新 Tensor Core 单元 |
| FP8 Tensor Core | 不支持 | ~1979 TFLOPs | Hopper 首次原生 fp8 |
| HBM 带宽 | 2.0 TB/s | 3.35 TB/s | HBM3 vs HBM2e |
| 分水岭 AI* | ~156 | ~295 (fp16) | 算力涨得比带宽快 |
| SRAM/SM | 192 KB | 228 KB | tile 可以更大 |
| TMA (异步拷贝) | 无 | 有 | tile 加载与计算重叠 |
| WGMMA | 无 | 有 | warp-group 级矩阵乘指令 |
注意 H100 的 AI* 涨到了 ~295。这意味着很多在 A100 上属于 compute-bound 的算子,在 H100 上变回 memory-bound 了——优化的重心继续往"省带宽"倾斜。FlashAttention v3 专门吃 Hopper 的 TMA + WGMMA,就是为了把屋顶进一步推高。
03GPU 执行模型 · 必知
读 CUDA 之前必须把这套术语建好。thread / warp / block / SM / grid 是层层嵌套的概念,搞混了所有代码都看不懂。
层级一览
__syncthreads() 只在 block 内同步,跨 block 同步只能靠 cooperative_groups 或重新 launch kernel。SIMT:lockstep 是默认假设
"SIMT" = Single Instruction, Multiple Thread。一组 32 个 thread (一个 warp) 共用一个指令指针。所有 thread 在同一时刻执行同一条指令,只是各自处理不同数据。
这跟 CPU 的 SIMD 很像,但 CUDA 把它包装得像每个 thread 是独立的——所以你能写:
// 每个 thread 处理一个数组元素
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
out[idx] = a[idx] + b[idx];
}
看起来是"per-thread 独立",但实际硬件是 warp 级 lockstep。如果 warp 里 32 个 thread 走了不同分支(warp divergence),硬件会先执行 if 分支(mask 掉走 else 的 thread),再执行 else 分支——总时间 = 两条路径之和。这是 CUDA 性能调优的常见陷阱。
BLOCK_SIZE 总是取 32 的倍数——否则最后一个 warp 有"空 thread",浪费算力。
Occupancy:SM 上同时驻留多少 thread
每个 SM 有固定的资源:寄存器、SRAM、最大 thread 数。一个 block 用得越多,同一 SM 能装的 block 越少。占有率定义:
occupancy = 实际驻留的 warp 数 / SM 支持的最大 warp 数
A100 每 SM 最多 64 warp (2048 thread)。如果你的 block 要 256 thread,每 thread 要 64 寄存器:
- 寄存器约束:256 × 64 = 16384 寄存器/block,A100 每 SM 65536 寄存器 → 最多 4 个 block。
- SRAM 约束:取决于
__shared__用量。 - thread 总数约束:4 block × 256 = 1024 thread → occupancy = 1024 / 2048 = 50%。
低 occupancy 不一定坏(FlashAttention 就有意降低 occupancy 来增大 per-thread SRAM),但通常需要刻意设计。
04Attention 复习:朴素 PyTorch 实现
把"为什么 attention 是 memory-bound"用具体数字算一遍。
5 行朴素代码
# shapes: Q, K, V each [B, H, N, d] (batch, heads, seq, head_dim)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d) # [B, H, N, N] ← 巨大
probs = torch.softmax(scores, dim=-1) # [B, H, N, N]
out = probs @ V # [B, H, N, d]
看似清爽,但这里有一个 N×N 的临时矩阵。对 N = 4096,每个 fp16 元素 2 字节:
scores 大小 = B · H · N · N · 2 字节
若 B=1, H=32, N=4096:
= 1 · 32 · 4096 · 4096 · 2
= ~1.1 GB ← 一次性占的 HBM
还要再写一份 probs (softmax 结果) → 再 1.1 GB
每步的 FLOPs 与 HBM 流量
| 步骤 | FLOPs | HBM read | HBM write |
|---|---|---|---|
| scores = Q @ K.T | 2 B H N² d | 读 Q, K: 4 B H N d | 写 scores: 2 B H N² |
| softmax | ~5 B H N² | 读 scores: 2 B H N² | 写 probs: 2 B H N² |
| out = probs @ V | 2 B H N² d | 读 probs + V: 2 B H N² + 2 B H N d | 写 out: 2 B H N d |
对 B=1, H=32, N=4096, d=128, fp16 (2 byte/element),HBM 总流量:
读+写 ≈ 4 B H N d (Q+K) ~67 MB
+ 6 B H N² (scores/probs) ~3.2 GB
+ 2 B H N d (V) ~33 MB
+ 2 B H N d (out) ~33 MB
─────────────────
~3.4 GB
对比同样 shape 下 FlashAttention 的 HBM 流量 = 读 Q,K,V + 写 O ≈ ~170 MB。差距 ~20×——这就是 FlashAttention 的"快 3 倍"在算什么。
注意 FLOPs 是一样的。FlashAttention 没有省一次乘加,它只省 HBM 搬运。这就是 IO-aware 的含义。
05FlashAttention · The Idea
FlashAttention (Dao et al., NeurIPS'22) 的核心 trick 概括成一句话:把 attention matrix 分块,每块在 SRAM 算完直接累加进输出,不把完整矩阵写回 HBM。
第一步:分块的设想
把 Q 按 row 切成 Br-大小的块(比如 Br=128),把 K, V 按 row 切成 Bc-大小的块(比如 Bc=64)。共有 N/Br 个 Q 块和 N/Bc 个 KV 块。
对每个 (Q_i, K_j, V_j) 三元组:
- 把 Q_i, K_j, V_j 加载进 SRAM。
- 计算
S_ij = Q_i @ K_j.T(小矩阵,Br × Bc)。 - 对 S_ij 做局部 softmax,得到 P_ij。
- 累加
O_i += P_ij @ V_j。
朴素思路问题:softmax 不是局部可分的。softmax(x_i) = exp(x_i) / Σ_j exp(x_j) 里那个 Σ 要全局——你不知道 j 这一块的 exp 之前要不要乘个修正因子。
第二步:online softmax —— 让 softmax 变成可累加的
关键洞察:softmax 可以"边走边修"。维护两个 running 变量:
m= 截至目前看到的最大值(global max)l= 截至目前的Σ exp(x_j - m)(归一化分母)O= 截至目前的输出累加
每来一个新 block,先算新 block 的 local max m_new、local sum l_new,然后修正旧累加值:
m_combined = max(m_old, m_new)
l_combined = exp(m_old - m_combined) · l_old + exp(m_new - m_combined) · l_new
O_combined = exp(m_old - m_combined) · O_old + exp(m_new - m_combined) · (P_new @ V_new)
关键是那个 exp(m_old - m_combined) 因子——它把"按旧 max 归一化的旧 O"重新校正到新 max 下的尺度。最终 O 在所有 block 都过一遍后,等于全局 softmax 的结果,分毫不差。
FlashAttention 的 tile schedule
HBM 流量量级
FlashAttention 论文给的渐进结果:
| 方法 | HBM 读写 | N=4096, d=128 数量级 |
|---|---|---|
| 朴素 | Θ(N · d + N²) | ~3.4 GB |
| FlashAttention | Θ(N² · d² / M) = Θ(N · d) 当 d²/M 小 | ~170 MB |
实际加速比:常见 3-7×(GPT-2/3 量级 attention 层)。
06FlashAttention v1 → v2 → v3
三个版本不是"更新换代"那么简单。每一代解决不同硬件层面的瓶颈。
| v1 (NeurIPS'22) | v2 (2023) | v3 (2024, Hopper-only) | |
|---|---|---|---|
| 外层循环 | Q tile (按行) | K/V tile (按列) | 同 v2 |
| 内层循环 | K/V tile | Q tile | 同 v2 + 异步 pipeline |
| warp 划分 | 每 warp 独立计算 + reduce | warp 协同,更少同步 | WGMMA: 一个 warp-group 一次大矩阵乘 |
| 异步搬运 | 无 | 无 | TMA 异步加载 tile |
| FP8 支持 | 无 | 无 | 有 (Hopper 原生) |
| 典型加速 (over v1) | 1× | ~2× | 再 ~1.5-2× (在 H100 上) |
为什么 v2 要换循环顺序
v1 的痛点:K/V tile 在不同 Q tile 之间被反复加载。如果 Q 有 32 个 tile,每个 K tile 要被 load 32 次。v2 反过来:外层 K/V,内层 Q。这样每个 K tile 只 load 一次,Q tile 在 SRAM 里多次复用(Q 是只读,不像 O 还要累加)。
v2 还重排了 warp 内的 work partition。v1 让每个 warp 算一行 attention,多个 warp 间要做 reduce;v2 让 warp 协作算同一行的不同列,reduce 在 warp 内用 shuffle 完成,避免 block 内同步。
v3 吃 Hopper 新硬件
- TMA (Tensor Memory Accelerator):硬件级的"DMA",加载 tile 时不占用 SM 算力,跟计算重叠。等于流水线。
- WGMMA (Warp-Group MMA):一个 warp-group (4 warp) 一次做一大块矩阵乘,吃满 fp16/fp8 Tensor Core。
- FP8 路径:Hopper 原生 fp8 Tensor Core 算力是 fp16 的 2 倍。v3 提供 fp8 attention 路径(精度损失可控)。
vLLM 怎么选
vLLM 通过 flash-attn 依赖间接使用 FA v2/v3。在 H100 上跑会自动用 v3 路径;A100 上是 v2。vLLM 自己不维护 FA 的实现——这是个上游依赖。但 vLLM 控制了:
- 什么时候调用 FA(prefill、prefix 重计算、chunked prefill 等)。
- 什么时候不用 FA 用自己的 paged kernel(decode 阶段,因 KV 非连续)。
VLLM_ATTENTION_BACKEND=XFORMERS 偶尔会胜出——所以 vLLM 还保留多个 backend 让你选。
07PagedAttention Kernel · 跟 FlashAttention 的区别
M2 已经讲过 PagedAttention 的memory layout。这一节聚焦它的 kernel——和 FlashAttention 的根本区别。
FlashAttention 默认假设:KV 是连续的
FA 的 tile schedule 是按 row offset 走的——K[j_start : j_end]。这要求 K, V 在显存里是整段连续。对单个 prefill 请求成立(一个 sequence 的 KV 一气分配出来),但对 vLLM 的 decode 完全不成立——一个 sequence 的 KV 被切成 16-token block 散落在 BlockPool 里。
PagedAttention kernel 怎么处理
block_table 间接查找。每个物理 block 内部还是连续的 16 个 token,所以 SRAM 里的小 attention 计算跟 FA 没区别——区别只在外层怎么找到这些 block。核心差异 = "gather" pattern
- FlashAttention:
k_tile = K[start:end]——连续 load,硬件吃满 coalesced access。 - PagedAttention:
phys = block_table[i]; k_tile = K_pool[phys]——间接 load。每次要先读 block_table 一项(小,cache 友好),再去 K_pool 里跳着读。
这"跳着读"看起来很贵。实际上:
- block_table 通常很小(每个 sequence 几十项),整表能放进 L2 / L1。
- 一个 block 内部 16 个 token 的 K 连续,大批量数据还是顺序读。
- 只是"块与块之间"非顺序——损失在 L2 prefetch 的有效性,不在 HBM 带宽本身。
实测 PagedAttention kernel 比连续 KV 的 FA 慢 ~5-15%(vLLM 论文报告)。这个开销是为了换取显存利用率从 ~40% 到 ~96%——稳赚。
为什么 PagedAttention 主要在 decode 用
区别在每个 kernel call 有多少 query:
| prefill | decode | |
|---|---|---|
| 每次 forward 的 query 数 | N (整段 prompt, ~千) | 1 (只生成一个新 token) |
| KV 是否连续 | 是 (一个请求的 KV 一次性写) | 否 (KV 跨 sequence 累积) |
| 合适 kernel | FlashAttention | PagedAttention |
| 瓶颈 | compute-bound | memory-bound |
当 sequence 跨 batch 混着 prefill 和 decode(chunked prefill 场景),vLLM 会选择性 dispatch——这是 attention backend 抽象层的工作。
08vLLM Kernel 生态全景
vLLM 不写自己的所有 kernel。生态分工:
| 组件 | 来源 | 位置 / 状态 |
|---|---|---|
| 朴素 attention (prefill, 长 ctx) | FlashAttention v2/v3 (Dao-AILab) | 外部依赖 flash-attn |
| 更花的 attention (FP8, sparse) | FlashInfer (CMU) | 外部依赖 flashinfer |
| Paged attention (decode) | vLLM 自写 | csrc/attention/paged_attention_v{1,2}.cu |
| Triton attention (跨架构) | vLLM 自写 | vllm/attention/ops/ 或 vllm/v1/attention/ops/ |
| RMSNorm | vLLM 自写 | csrc/layernorm_kernels.cu |
| RoPE (rotary position) | vLLM 自写 | csrc/pos_encoding_kernels.cu |
| SwiGLU / GeGLU 激活 | vLLM 自写 | csrc/activation_kernels.cu |
| GEMM (大矩阵乘) | cuBLAS / cuBLASLt | NVIDIA 闭源 |
| 量化 GEMM (W4A16, FP8) | Marlin / Machete (社区) + vLLM 集成 | csrc/quantization/ |
| MoE expert kernel | vLLM + 社区 | vllm/model_executor/layers/fused_moe/ |
| AllReduce / AllGather | NCCL | NVIDIA 开源 |
"vLLM 自写"最多的是 paged attention 系列——因为只有它需要 block table indirection,没有库直接提供。其次是一些小算子(RMSNorm、RoPE、激活),vLLM 把它们 fuse 起来减少 kernel launch。
Backend 选择逻辑
vLLM 在 runtime 决定用哪个 attention backend。可以用环境变量强制:
# 强制 FlashAttention (v2 或 v3,看硬件)
VLLM_ATTENTION_BACKEND=FLASH_ATTN python -m vllm.entrypoints.openai.api_server ...
# 强制 vLLM 自己的 paged kernel
VLLM_ATTENTION_BACKEND=PAGED_ATTN ...
# 用 FlashInfer
VLLM_ATTENTION_BACKEND=FLASHINFER ...
# AMD 卡上:用 Triton 路径
VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 ...
不设环境变量时,vLLM 按"硬件 + 操作"自动选——见 vllm/attention/selector.py(v1 引擎下是 vllm/v1/attention/backends/)。规则大致:
- H100 + 长 ctx + prefill:FlashAttention v3。
- A100 + prefill:FlashAttention v2。
- 任意硬件 + decode:自家 paged kernel(要 indirection)。
- AMD MI300:Triton 路径(因为 CUDA kernel 不能跨编译)。
09代码深读 · paged_attention_v1.cu
这是本月最硬核的一段。看不懂不要怕,目标只是认出结构。
整体导航
paged_attention_v1.cu· 老版本 paged attention kernel,一个 thread block 处理一个 (sequence, head) 对。paged_attention_v2.cu· 引入 partition——单个 sequence 太长时切多段并行处理,最后再 reduce。attention_kernels.cuh· 共享的 utility(warp reduce、float vec load 等)。attention_dtypes.h· 数据类型 traits (fp16, bf16, fp8 的 SIMD 包装)。attention_generic.cuh· 通用宏。
kernel 签名(简化版)
实际签名很长。下面是按可读性整理的代表形式:
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] 输出
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] 每序列 1 个 query
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads,
const float scale, // 1 / sqrt(head_size)
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] · 每序列实际长度
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // 可选 ALiBi 偏置
const int q_stride,
const int kv_block_stride,
const int kv_head_stride
);
thread block 组织
关键决策:
- grid 维度:
(num_seqs, num_heads)——每 (sequence, head) 对一个 thread block。一个 block 内部处理这一对的整个 attention 计算。 - block 大小:通常 128 thread(4 warp)。
- thread 内分工:所有 thread 协作覆盖
head_size维度。head_size=128 时每 thread 处理 1 个 dim。 - iteration 单位:每次迭代处理一个 KV block (16 token)。block 内 warp 协作算这 16 个 query-key 点积。
主循环骨架(伪代码)
// 1. 每个 thread block 负责 (seq_idx, head_idx)
const int seq_idx = blockIdx.x;
const int head_idx = blockIdx.y;
const int thread_idx = threadIdx.x;
// 2. 把这个 head 的 query 加载到 SRAM (寄存器)
__shared__ scalar_t q_shared[HEAD_SIZE];
if (thread_idx < HEAD_SIZE) {
q_shared[thread_idx] = q[seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + thread_idx];
}
__syncthreads();
// 3. 初始化 online softmax 累加器
float m_acc = -INFINITY; // running max
float l_acc = 0.0f; // running sum
float o_acc[HEAD_SIZE/NUM_THREADS] = {0.0f}; // running output (per-thread slice)
// 4. 主循环:遍历这个 sequence 的所有 KV block
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int num_blocks = (context_lens[seq_idx] + BLOCK_SIZE - 1) / BLOCK_SIZE;
for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
// 4a. 查 block table 拿物理块号
const int physical_block = block_table[block_idx];
// 4b. 计算这一块的 q @ K_block.T (BLOCK_SIZE 个分数)
float scores[BLOCK_SIZE];
for (int token_offset = 0; token_offset < BLOCK_SIZE; ++token_offset) {
// 每个 thread 负责 head_size 的一个 chunk
float partial = 0.0f;
for (int d = thread_idx; d < HEAD_SIZE; d += NUM_THREADS) {
partial += q_shared[d] *
k_cache[physical_block * stride + token_offset * BLOCK_SIZE + d];
}
// warp reduce 把所有 thread 的 partial 加起来
partial = warp_reduce_sum(partial);
scores[token_offset] = partial * scale;
}
// 4c. online softmax 更新
float block_max = -INFINITY;
for (int t = 0; t < BLOCK_SIZE; ++t) block_max = max(block_max, scores[t]);
float m_new = max(m_acc, block_max);
float rescale = exp(m_acc - m_new);
float l_new = rescale * l_acc;
for (int t = 0; t < BLOCK_SIZE; ++t) {
scores[t] = exp(scores[t] - m_new); // 归一化到新 max
l_new += scores[t];
}
// 4d. 用 scores 加权累加 V_block 到 o_acc,同时 rescale 旧 o_acc
for (int d = thread_idx; d < HEAD_SIZE; d += NUM_THREADS) {
float new_val = 0.0f;
for (int t = 0; t < BLOCK_SIZE; ++t) {
new_val += scores[t] * v_cache[physical_block * stride + d * BLOCK_SIZE + t];
}
o_acc[d / NUM_THREADS] = rescale * o_acc[d / NUM_THREADS] + new_val;
}
m_acc = m_new;
l_acc = l_new;
}
// 5. 最后归一化 + 写回 HBM
__syncthreads();
for (int d = thread_idx; d < HEAD_SIZE; d += NUM_THREADS) {
out[seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + d] = o_acc[d / NUM_THREADS] / l_acc;
}
以上是代表性骨架,跟仓库里实际代码细节不完全对齐(fp16 vs fp32 累加、向量化 load、ALiBi 支持等省略了)。但循环结构和 online softmax 是真的。
关键点说明
- __shared__ 修饰:定义在 SRAM 上的 buffer。query 加载一次后 block 内所有 thread 复用。
- warp_reduce_sum:用
__shfl_sync在 warp 内做 reduction,不需要 SRAM round-trip。 - per-thread o_acc:每个 thread 持有 head_size 的一个 slice(在自己的寄存器里)。最后没有跨 thread reduce——因为 output 也是按 dim 分给 thread 写的。
- online softmax 的 rescale:
rescale = exp(m_old - m_new)把旧累加值校正到新 max 的尺度。
csrc/attention/paged_attention_v1.cu · paged_attention_v1_kernel__shared__修饰的变量 = SRAM 上的 buffer。看用了多少 KB。- 外层 for 循环遍历 KV block,里面 for 遍历 token in block。
block_table索引把逻辑 block → 物理 block 偏移。- online softmax 的 max 和 sum 累加变量(通常叫
qk_max、exp_sum)。 __shfl_sync调用——warp reduce 在哪几个点发生。
csrc/torch_bindings.cpp 或 csrc/pybind.cppTORCH_LIBRARY_FRAGMENT 或 PYBIND11_MODULE,找到 paged_attention_v1 的注册行。
tensor.data_ptr(); 同时 PyTorch 替你管 stream)10CUDA Graph · 减少 launch overhead
这一节是 vLLM 性能的"隐藏 boost"——很多人不知道它的存在,但 decode 阶段 5-15% 的提速就靠它。
问题:每个 kernel launch 要多少时间
从 CPU 发起一次 CUDA kernel launch,typical cost:
- CPU 侧:5-10 µs(参数 marshalling、driver 调用、写 command buffer)。
- GPU 侧:0.5-1 µs(warm-up + 调度 block 到 SM)。
对一个吃几百微秒的 kernel,这点 overhead 微不足道。但decode 阶段不一样:
一次 decode iteration ≈ 50 个 kernel (RMSNorm + Q/K/V proj + attention + O proj + FFN × N层)
launch overhead ≈ 50 × 7 µs ≈ 350 µs (CPU 端)
对一个 ~10 ms 的 decode step,350 µs / 10 ms = 3.5% 开销
对 7B 模型在 H100 上跑 decode (~3 ms/step),开销 ≈ 11%
模型越小、kernel 越多、单 kernel 越快——launch overhead 比例越高。
解决:CUDA Graph capture + replay
vLLM 怎么用 CUDA graph
关键问题:CUDA graph 要求每次 replay 的 shape 一致(kernel launch 参数是 capture 时定死的)。但 vLLM 的 batch size 动态变化——这就要"shape bucketing"。
vLLM 在 engine 初始化时,对一组预定义的 batch size(比如 1, 2, 4, 8, 16, 32, ..., max_num_seqs)各 capture 一份 graph。decode 时根据当前 batch size 取最近的桶 padding 进去,replay。
capture_model / cuda_graph / graph_capture。看:
- 什么时候触发 capture (启动时?warmup 阶段?)
- 哪些 batch size 被 capture (
cudagraph_batch_sizes列表) - 怎么 pad 到下一个桶
--enforce-eager标志怎么禁用 graph
陷阱
- 非确定性输入 shape:要么 padding 要么 fall back 到 eager 模式。
- capture 耗显存:每个 batch size 桶 capture 一份要 ~50-200 MB workspace。桶多了显存吃紧。vLLM 默认 capture 一二十个桶。
- 第一次 capture 慢:vLLM 启动时间 +20-60 秒。这是为什么有
--enforce-eager的 debug 选项。
11Triton · 比 CUDA 友好的中间层
CUDA 上手太陡:手动管 shared memory、手动 warp reduce、手动 vectorize load。Triton 是 OpenAI 出的 Python-like DSL,编译成 CUDA / ROCm,把这些自动化了。
对比:vector add 在 CUDA vs Triton
CUDA 版(节选):
__global__ void vector_add_kernel(const float* a, const float* b, float* c, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
c[idx] = a[idx] + b[idx];
}
}
// host 侧 launch
int block_size = 256;
int grid_size = (n + block_size - 1) / block_size;
vector_add_kernel<<<grid_size, block_size>>>(a, b, c, n);
Triton 版(同功能):
import triton
import triton.language as tl
@triton.jit
def vector_add_kernel(
a_ptr, b_ptr, c_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr, # 编译时常量
):
pid = tl.program_id(axis=0) # 当前 "block" 的 id
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # 这个 block 处理哪些 offset
mask = offsets < n_elements # 边界保护
a = tl.load(a_ptr + offsets, mask=mask) # 一次 load BLOCK_SIZE 个 element
b = tl.load(b_ptr + offsets, mask=mask)
tl.store(c_ptr + offsets, a + b, mask=mask)
# 调用
def vector_add(a, b):
c = torch.empty_like(a)
n = a.numel()
BLOCK_SIZE = 1024
grid = (triton.cdiv(n, BLOCK_SIZE),)
vector_add_kernel[grid](a, b, c, n, BLOCK_SIZE=BLOCK_SIZE)
return c
关键差异
| CUDA | Triton | |
|---|---|---|
| 抽象粒度 | per-thread | per-block (block 内 thread 编译器代办) |
| load 单位 | 每 thread 一个 element | 一个 vector (BLOCK_SIZE 个) |
| shared memory | 手动 __shared__ | 编译器自动决定 |
| warp reduce | 手动 __shfl_sync | tl.sum(x, axis=0) 一行 |
| autotune | 手写 + 跑 benchmark | @triton.autotune 自动 |
| 跨架构 | NVIDIA only | NVIDIA + AMD CDNA (实验性) |
| 调试 | printf, cuda-gdb | TRITON_INTERPRET=1 在 CPU 上 step |
softmax 在 Triton 里的样子
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr,
):
row_idx = tl.program_id(0) # 当前 block 处理第几行
row_start_ptr = input_ptr + row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
mask = col_offsets < n_cols
# 一次 load 一整行 (前提:n_cols ≤ BLOCK_SIZE)
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# 在 SRAM 里完成 softmax (编译器自动用 shared memory)
row_max = tl.max(row, axis=0) # 一行求 max
numerator = tl.exp(row - row_max) # 减 max 防 overflow
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# 写回
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
对比纯 CUDA 写一个 fused softmax kernel:要手动 __shared__、手动 warp reduce __shfl_xor_sync、手动 block-level reduce、手动管 bank conflict——通常 100+ 行。Triton 这 15 行做的事情等价。
Autotune:让编译器选 block size
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64}),
triton.Config({'BLOCK_SIZE': 128}),
triton.Config({'BLOCK_SIZE': 256}),
triton.Config({'BLOCK_SIZE': 512}),
triton.Config({'BLOCK_SIZE': 1024}),
],
key=['n_elements'], # 输入大小变化时重跑 autotune
)
@triton.jit
def vector_add_kernel(...):
...
第一次调用某个 (n_elements 区间) 时 Triton 跑一遍所有 config 测速,选最快的存下。后续同 key 直接复用。这是 CUDA 没有的内建机制。
必做:Triton 官方 tutorial 1-3
- 01 vector add · 学 program_id / load / store / mask。
- 02 fused softmax · 学 block 内 reduce、kernel fusion。
- 03 matmul · 学 2D block tiling、accumulator、autotune。
BLOCK_M, BLOCK_N, BLOCK_K 三个参数分别决定什么?哪个最影响 SRAM 占用?vLLM 怎么用 Triton
vLLM 在以下场景偏好 Triton:
- AMD ROCm 支持:CUDA 不能跨编译到 AMD,Triton 可以。
- 实验性 attention:新 architecture(比如 sliding window、cross attention)先用 Triton prototype,稳定了再 CUDA-化(如果有性能差距)。
- 动态 shape:autotune 让一份代码自动适配不同尺寸。
- 非主流 dtype:fp8、int4 等组合,Triton 改起来快。
vllm/attention/ops/ 或 vllm/v1/attention/ops/ 下的 .py 文件csrc/attention/paged_attention_v1.cu),Triton 版"省"掉了什么?(答:手动 shared memory 管理、warp reduce、grid 配置——Triton 编译器代办)12动手 · 跑 benchmark 画图
本月的硬核作业。这一份数据将出现在你的简历/blog 上。
实验设计 · benchmark_throughput.py
vLLM 自带两个 benchmark 脚本,用法不同:
| 脚本 | 测什么 | 典型用法 |
|---|---|---|
benchmarks/benchmark_throughput.py | 离线吞吐 (offline, 不限延迟) | 给定一堆请求,跑完总耗时 |
benchmarks/benchmark_serving.py | 在线 serving (有 QPS 速率) | 模拟生产 traffic,看 latency 分位 |
benchmarks/benchmark_latency.py | 单请求延迟 | 固定输入测端到端 |
benchmark_throughput 的关键 flag
python benchmarks/benchmark_throughput.py \
--model meta-llama/Meta-Llama-3-8B \
--backend vllm \
--input-len 1024 \ # 每请求的 prompt 长度
--output-len 128 \ # 每请求的输出长度
--num-prompts 1000 \ # 总请求数
--dtype float16 \
--max-model-len 4096 \
--gpu-memory-utilization 0.9 \
--max-num-seqs 256 \ # 最大并发 sequence
--enforce-eager \ # 关 CUDA graph (对照用)
--enable-prefix-caching # 开 prefix cache (对照用)
benchmark_serving 的关键 flag
# 启动 server
python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Meta-Llama-3-8B --port 8000 &
# 跑 benchmark client
python benchmarks/benchmark_serving.py \
--backend vllm \
--base-url http://localhost:8000 \
--dataset-name sharegpt \ # 或 random / sonnet
--dataset-path ShareGPT_V3.json \
--num-prompts 1000 \
--request-rate 5 \ # QPS
--save-result \
--result-dir ./results/
变量矩阵 · 至少跑 4 张图
| 图 # | X 轴 | Y 轴 | 固定 |
|---|---|---|---|
| 1 | max-num-seqs (1, 8, 32, 64, 128, 256) | throughput tokens/s | input=1024, output=128, num-prompts=1000 |
| 2 | input-len (128, 512, 1024, 2048, 4096) | TTFT p50 / p99 | output=128, request-rate=5 |
| 3 | request-rate (1, 2, 5, 10, 20, 50) | throughput + p99 latency (双轴) | input=1024, output=128 |
| 4 | backend (FLASH_ATTN, FLASHINFER, PAGED_ATTN) | throughput (bar plot) | 统一其他配置 |
matplotlib 样本代码
import json, matplotlib.pyplot as plt
import numpy as np
# 假设跑完后 results 字典:results[max_num_seqs] = {'tput': ..., 'ttft_p50': ..., ...}
xs = sorted(results.keys())
tputs = [results[x]['tput'] for x in xs]
fig, ax = plt.subplots(figsize=(7, 4.5))
ax.plot(xs, tputs, 'o-', linewidth=2, markersize=8, color='#c2410c')
ax.set_xlabel('max-num-seqs (concurrency cap)')
ax.set_ylabel('Throughput (tokens/s)')
ax.set_title('Llama-3-8B · A100 80GB · throughput vs concurrency')
ax.grid(alpha=0.3)
ax.set_xscale('log', base=2)
# 标注 saturation point
ax.axvline(x=64, linestyle='--', color='gray', alpha=0.5)
ax.text(64, max(tputs)*0.7, 'KV 显存撑爆点\npreempt 开始', ha='center')
plt.tight_layout()
plt.savefig('throughput_vs_concurrency.png', dpi=120)
每张图应该长什么样(预期形状)
对照实验 · 必做加分
- CUDA graph 影响:
--enforce-eager关掉 graph,对照默认。decode 重的负载下应该有 5-15% 差距。 - Attention backend 差异:
VLLM_ATTENTION_BACKEND=FLASH_ATTNvsFLASHINFERvsTRITON_ATTN_VLLM_V1。一般 5-30%。 - Prefix caching 影响:sharegpt 数据集有共享 system prompt,开 / 关 prefix caching 差距 1.5-3×。
- Quantization 影响:fp16 vs AWQ-int4 vs fp8。throughput 通常 1.2-2× (但要看长 ctx 还是短)。
怎么解读异常
| 现象 | 常见原因 | 排查方向 |
|---|---|---|
| throughput 在 max-num-seqs 大时下降 | preempt 频繁、KV swap 抖动 | 看日志 preempted 次数 |
| TTFT 在 batch 大时暴涨 | prefill 排队,且 chunked prefill 关 | 开 --enable-chunked-prefill |
| 不同 backend 差距很大 | fp 精度不同、kernel cover 范围不同 | 分别跑短 / 长 ctx 看 |
| 开 graph 反而慢 | batch size 跨桶 padding 浪费 | 对照 cudagraph_batch_sizes 列表 |
13常见反模式
读到一些自信但错的观点时,能站住脚很重要。下面四个最常见。
"Triton 比 CUDA 慢,生产要用 CUDA"
常常是错的。对中等复杂度的 kernel(element-wise + 一个 reduce + 适度 fusion),Triton 编译出来的 PTX 跟手写 CUDA 通常在 5% 以内。autotune 还能把好的 BLOCK_SIZE 选出来——手写 CUDA 没人会去遍历所有组合。
CUDA 真有优势的地方:极致的 fp8/fp16 mixed precision、warp 内特殊指令(mma.sp、cp.async.bulk)、为某代特定 GPU 优化的微调。FlashAttention v3 用 raw CUTLASS/CUDA 是因为 WGMMA、TMA 这些 Hopper-only 指令 Triton 还没原生支持。
结论:medium kernel → Triton;最热点的、跨多硬件代差的 kernel → CUDA。
"FlashAttention 总是赢"
错。短 context (N < ~256) 时,朴素 attention 本来就能塞进 SRAM 跑完,FA 的 tile schedule 开销不划算。极小 batch、超短 prompt 的边缘场景偶尔 XFormers 更快。
另外 FA 在 decode(query 数 = 1)是低效的——FA 的并行化基于 Q tile 数量。decode 阶段 Q 就一行,吃不满 SM。所以 vLLM 在 decode 用自家 paged kernel(并行化基于 KV block 数)。
"eager mode 没区别,反正最终都跑同样的 kernel"
错。eager mode 每个 op 都从 CPU 单独 launch kernel。对 LLM 那种"几百个小 kernel 一路过"的 workload,CPU launch overhead 攒起来 10%+。CUDA graph 能把这 10% 抢回来。
另一个 eager 看不到的好处:graph 内部 kernel 排队不抢 GPU 占用率,调度更平滑。pure latency-sensitive 场景(单请求、小模型)受益最明显。
"FP8 一定降速度"
错。Hopper 之前的硬件(A100 及之前)确实没 fp8 Tensor Core,模拟 fp8 反而慢。但 H100 起 fp8 是原生 Tensor Core 路径,算力比 fp16 快 2 倍。
fp8 的真问题是精度——某些激活的 outlier 会导致 quantization error 放大。要配合 calibration(per-tensor 或 per-channel 的 scale),加上小心选哪些层 quant。生产里 W8A8 经过验证的模型在 throughput 上能赚 ~1.5×,质量下降可控。
注意 H100 之外的卡(A100)跑 fp8 是软件模拟,会慢——这种情况要先看你的硬件。
14速补材料
必读
- CMU 15-418 Parallel Computing · 课程主页
- Lecture 2-3:缓存层级 + 内存延迟。讲得比 CS162 OS 课深。
- Lecture 5-7:GPU 架构(SIMT、warp)。
- Lecture 14:roofline 模型。如果时间紧张,就看这一节。
- CUDA C++ Programming Guide · NVIDIA 官方
- Chapter 4 (Hardware Implementation) · 讲 SM 内部结构、warp 调度。
- Chapter 5 (Performance Guidelines) · 讲 memory coalescing、occupancy。
- Chapter 6 (Math Functions) · 速度精度表,写 kernel 时要查。
- 一个下午翻完,每页都有用。
- PMPP (Programming Massively Parallel Processors) · Hwu, Kirk, Hajj 著(第 4 版 2022)
- Chapter 3-5:CUDA 基础。比官方 guide 教学性强。
- Chapter 8 (parallel patterns: reduction, scan):写 reduce 类 kernel 必读。
- Chapter 12 (sparse matrix):理解 PagedAttention 的 indirection 用得上。
- Triton 官方 docs · triton-lang.org
- Tutorial 1-3 必做(vector add, softmax, matmul)。
- Tutorial 6-7(FP8 GEMM)选做。
选读
- FlashAttention 论文(v1, NeurIPS'22) · 读懂思想 + 附录 A 的 online softmax 推导。
- FlashAttention-2 论文(2023) · 看 work partition 章节。
- FlashAttention-3 论文(2024) · Hopper 特化。
- PagedAttention 论文(SOSP'23) · 你 M2 已经读过。M4 重看 kernel 设计部分。
- NVIDIA H100 Whitepaper · TMA、WGMMA、新 Tensor Core 的官方说明。
- "Making Deep Learning Go Brrrr From First Principles" · Horace He 的 blog post,关于 PyTorch 视角的 memory-bound vs compute-bound,必读。
15本月 PR 候选
Month 4 的 KPI 是一个 kernel / 性能相关的 PR。候选方向,按门槛升序:
- benchmark 补全 · 给某个少 cover 的 workload 加 benchmark 脚本 + baseline 数据。改动小,价值高。
- 给 Triton kernel 加 autotune config · 已有 kernel 但 autotune 配置不全,加几个 BLOCK_SIZE 候选,跑 benchmark 提交报告。
- 修一个 numerical issue · fp16 vs fp32 累加导致的精度问题,搜 issue tracker
precision标签。 - backend selector 修小问题 · 例如某种硬件 / dtype 组合没正确路由到最优 backend。
- 新 GPU 架构适配 · H200、B100、AMD MI300 的 kernel cover 补全。门槛最高,需要相应硬件。
16本页自检
Month 4 结束时这些应该全部 ✓
17延伸阅读
这一章你下到了硬件层。再往前走的方向:
| 主题 | 论文 / 资源 | 为什么读 |
|---|---|---|
| 原始 FlashAttention | Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, NeurIPS 2022 | "IO-awareness" 概念的开创。M4 思想根基。 |
| FlashAttention-2 | Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning, 2023 | 看 work partition 章节,理解 v1 → v2 的真正变化。 |
| FlashAttention-3 | Shah et al., FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision, 2024 | Hopper TMA + WGMMA 的实战示范。fp8 attention 范式。 |
| PagedAttention | Kwon et al., Efficient Memory Management for LLM Serving with PagedAttention, SOSP'23 | kernel 设计章节再读一遍,M2 的时候你关注 memory layout,这次关注 gather pattern。 |
| FlashInfer | Ye et al., FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving, 2025 | 第三方 attention 库,跟 vLLM 配合用。看它怎么处理 grouped-query attention、sparse attention。 |
| SGLang RadixAttention | Zheng et al., SGLang: Efficient Execution of Structured Language Model Programs, NeurIPS'24 | prefix tree 共享的高级版。kernel 视角看 KV 共享的极限。 |
| Online Softmax | Milakov, Gimelshein, Online normalizer calculation for softmax, 2018 | FlashAttention 数学根基的原始论文。3 页,读完心安。 |
| Triton 论文 | Tillet et al., Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations, MAPL'19 | Triton 的设计动机。读了之后 DSL 抽象一切都顺。 |
| Roofline model | Williams, Waterman, Patterson, Roofline: An Insightful Visual Performance Model for Multicore Architectures, CACM'09 | 所有 perf 分析的祖宗 paper。GPU 圈大家都用它的概念但很少回头读原文,6 页。 |
| "GPU MODE" 录制课 | github.com/gpu-mode/lectures | 社区驱动的 GPU 编程课,多个 GPU vendor 的工程师讲。lecture 5-12 是 CUDA / Triton 实战,特别好。 |
读论文的顺序建议:FlashAttention v1 → online softmax 原 paper → PagedAttention(再读 kernel 章节)→ FlashAttention v2 → roofline → 选你感兴趣的方向继续。一周一篇够了。