home/tutorial/M4 Kernel

Month 4 · Kernel & 性能

下到 CUDA 层。理解为什么 FlashAttention 这种"重写 attention"的东西能快 3 倍。

Month 1-3 一直在 Python 层。这个月下到 csrc/ 和 Triton。 目标不是能徒手写 CUDA 算子,而是能读懂、能调试、能给已有 kernel 加 benchmark。 这一层最难补,但回报是面试时立刻显出底子。

驱动问题
Attention 的数学公式 softmax(QK^T / sqrt(d))V 谁都会写。 PyTorch 一行 F.scaled_dot_product_attention 就能跑。 那 FlashAttention 凭什么快 3 倍?它没改公式,到底改了什么?
先猜,再展开答案

关键洞察:现代 GPU 的瓶颈不是算力,是显存带宽。

朴素 attention 要把 N×N 的 attention matrix 实例化在 HBM(GPU 主显存)。 对 N=4096,这个矩阵 16M 个 float,一来一回搬 64MB 数据。 SM (compute unit) 大部分时间在等显存,不在算。

FlashAttention 的做法:

  1. 把 Q、K、V 切成 tile,每次只把一对 tile 加载到 SRAM(GPU 内 L1 级缓存,几十 KB,巨快)。
  2. 在 SRAM 里算完这对 tile 的部分 attention,累加到输出。
  3. 用"online softmax"技巧让 tile-by-tile 计算等价于一次性算。
  4. 从头到尾 attention matrix 不实例化,只有最终 O 矩阵搬回 HBM。

本质 = IO-aware,工艺级地少搬数据。算力没省,但显存搬运省了一个数量级。

01GPU 内存层级 · 必知

这是最基础的硬件背景。看完这张图,FlashAttention 才能 click。

寄存器 每 thread 256 个 SRAM (shared memory) 每 SM ~100KB · 19 TB/s L2 cache ~40MB · 5 TB/s HBM (显存) A100: 40-80GB · 2 TB/s 所有 PyTorch tensor 默认在这 Host RAM PCIe 32 GB/s · vLLM swap-out 的去处 ↑ 快 ↓ 大
每往下一层,容量大 10×、带宽小 10×。把数据"留在上面" = 性能。

记住的数字:SRAM 比 HBM 快 10×。任何能减少 HBM 访问的优化,回报都是显著的。 FlashAttention 是工程界 IO-aware 的代表作。

02FlashAttention 的核心 trick · online softmax

softmax 看起来不可分(需要先求总和再归一):

softmax(x_i) = exp(x_i) / Σ exp(x_j)

看似要等所有 x_j 算完才能算分母。online softmax 的技巧:分块算时,每块结束维护一个 running max + running sum,下一块来了就校正之前的累加值。数学等价、可分块。

💡 不用从头推 online softmax
FlashAttention 论文附录 A 有完整推导。你只需要知道"能等价做到",不必背公式。 真要写算子时再深入。

03vLLM 用了哪些 kernel

vLLM 不写自己的所有 kernel。生态分工:

组件来源位置
朴素 attention (prefill)FlashAttention v2/v3 (Dao-AILab)外部依赖 flash-attn
Paged attention (decode)vLLM 自写csrc/attention/paged_attention_*.cu
Triton attentionvLLM 自写 (跨架构)vllm/attention/ops/vllm/v1/attention/
RMSNorm / RoPE / activationvLLM 自写小算子csrc/layernorm_kernels.cu
GEMM (大矩阵乘)cuBLAS / cuBLASLtNVIDIA 闭源
MoE expert kernelvLLM + 社区vllm/model_executor/layers/fused_moe/

"vLLM 自写"的最多的是 paged attention——因为它要支持非连续 KV 地址(block table 索引),这是其他 attention 库不直接支持的。

04读源码 · csrc/attention

这是本月最硬核的一段。看不懂不要怕,目标只是认出结构。

读 ① · 1 小时 · 入口
目录扫一眼。注意有 paged_attention_v1.cupaged_attention_v2.cu。v2 在 v1 基础上加了 partition (分片处理超长 sequence)。
v2 比 v1 多解了什么问题?(答案在 source 顶部的 comment block)
读 ② · 2 小时 · 一个完整 kernel
csrc/attention/paged_attention_v1.cu · paged_attention_v1_kernel
CUDA C++ 函数。注意几个东西:
  • __shared__ 修饰的变量 = SRAM 上的 buffer。
  • 外层 for 循环遍历 KV block,里面 for 遍历 token in block。
  • block_table 索引把逻辑 block → 物理 block 偏移。
  • online softmax 的 max 和 sum 累加变量。
同一个 thread block 里的 thread 在协作做什么?(答:每个 thread 处理一个 head dim,最后归约出 attention 输出)
读 ③ · 1 小时 · 绑定到 Python
csrc/attention/attention_kernels.cupybind.cpp
CUDA function 怎么注册成 Python 可调用。看 TORCH_LIBRARY_FRAGMENTPYBIND11_MODULE
从 Python 调一次这个函数,传入参数怎么变成 GPU 上的 tensor 地址?

05Triton · 比 CUDA 友好的中间层

CUDA 上手太陡。Triton 是 OpenAI 出的Python-like DSL,编译成 CUDA。 vLLM 用 Triton 写跨架构 kernel(AMD ROCm 也能跑)。

动手任务(不可跳):

做 · Triton 官方 tutorial 1-3
  • 1: vector add — 学 program_id / load / store。
  • 2: fused softmax — 学 SRAM 上做归约。
  • 3: matmul — 学 block tiling、autotune。
做完 3 个,你已经能写 80% 的 element-wise 和 reduce 类 kernel。

做完 tutorial 回头看 vLLM 的 Triton kernel:

读 · vLLM Triton attention
vllm/attention/ops/paged_attn.py (或 v1 路径下)
Python 写的 Triton kernel。比 CUDA 版可读 5 倍
对比同功能 CUDA 版,Triton 版"省"掉了什么?(答:手动 shared memory 管理、thread 协作、grid 配置——Triton 编译器代办)

06动手 · 跑 benchmark 画图

本月的硬核作业。这一份数据将出现在你的简历/blog 上。

实验设计

  1. 在云上 A100 / A10 跑 benchmarks/benchmark_throughput.pybenchmark_serving.py
  2. 变量:batch size (1, 8, 32, 64)、prompt 长度 (128, 1024, 4096)、output 长度 (64, 256)。
  3. 记录:throughput (tokens/s)、TTFT、p50/p99 latency。
  4. 用 matplotlib 画图:
    • throughput vs batch size(应该 sub-linear 增长)
    • TTFT vs prompt length(应该线性)
    • 显存占用 vs 并发数

对照实验(可选但加分)

✓ 你将看到的现象
  • Throughput 随 batch size 增长但不线性:到某个点 KV 显存撑爆,preempt 开始,曲线压平。
  • TTFT 在大 batch 下显著上升(prefill 排队)。
  • FlashAttention vs XFormers 差异在 ~10-30%。
  • CUDA graph 在 decode 重的负载下给 5-15% 提速。

07本月 PR 候选

⚠ 写 CUDA PR 前注意
Kernel PR 是最难 merge的——reviewer 严、CI 慢、需要多种 GPU 测。 推荐先从 Triton kernel / benchmark 切入,建立信任,再碰 CUDA 主战场。

08速补材料

09本页自检

Month 4 结束时这些应该全部 ✓