Triton L2缓存命中优化矩阵乘法(fp16&int8)详解及性能测试¶
Ch01.203 Triton L2缓存命中优化矩阵乘法(fp16&int8)详解及性能测试¶
📊 Level ⭐⭐ | 17.6KB |
entities/triton-l2缓存命中优化矩阵乘法fp16int8详解及性能测试.md
文章概要¶
"L2缓存命中优化矩阵乘法"是Triton官方提供的第三个教程,本文将结合硬件特性对此部分内容进行详解。同时笔者也简单的做了下int8 matul的魔改,并进行了量化/非量化性能测试及分析。
基础概念¶
Triton虽然是python前端,但是编程思维和python还是有比较大的区别的,更倾向于cuda编程。但和cuda编程相比,Triton开发不需要手动管理线程块、网格、线程束等结构。此外,Triton 编译器会自动对代码进行优化,包括内存访问模式优化、指令调度、并行性优化等。Triton编译器能够根据不同的 GPU 硬件架构和输入数据大小,生成高效的机器代码,减少了开发者手动优化的工作量。 使用Triton编程最主要的目标是用来做 算子融合 。模型在训练时候,除了gpu的计算时间外,从显存(HBM)把数据搬运到SRAM也占用了很多时间。举个例子,算一个最简单的标准化(x-mean)/var,需要涉及到如下步骤:1.读x+读mean 2.写x-mean 3.读x-mean,读var 4.写(x-mean)/var,看起来十分的啰嗦,并且要保存中间的临时变量,占用额外显存。如果想让线程块端到端的计算结果(只读一次写一次),就需要做算子融合了。flash attention本质上就是一种算子融合,加速效果和显存节约量都比较显著,尤其是在长序列的时候。之前在训练Steel-LLM(https://github.com/zhanshijinwat/Steel-LLM)的时候,笔者专门消融过算子融合带来的训练加速效果,可以看我的往期文章(https://zhuanlan.zhihu.com/p/694223107),即使仅对RMSNorm做算子融合,训练也有10%左右的吞吐提升,显存节约了4g。 接下来,了解一点GPU相关的基础知识(具体数值是A100显卡的),以便后边更好的理解Triton编程。先来看看GPU SM(Streaming Multiprocessor,流式多处理器),其是GPU上的基础硬件单元,由如下几部分组成:
- CUDA核心:用于执行计算指令的基本单元。在一个 SM 中包含多个 CUDA 核心,它们可以并行地执行相同或不同的指令。例如,在进行矩阵乘法运算时,多个 CUDA 核心可以同时对矩阵的不同元素进行计算,大大提高了计算速度。
- 寄存器:GPU 中速度最快的存储单元,SM 中的寄存器文件用于存储线程在执行过程中的临时数据。每个线程都有自己独立的寄存器空间,线程可以快速地读写寄存器中的数据。
- 共享内存:线程块(SM中可有多个线程块,例如A100的每个SM最多可以支持32个线程块)内的线程可以共享这块内存区域。线程可以通过共享内存快速地交换数据,实现线程间的协作和通信。例如,在矩阵乘法中,一个线程块内的线程可以将矩阵的一部分数据加载到共享内存中,然后共同对这些数据进行计算,减少了对全局内存的访问次数。
- L1 cache:位于 SM 内部,用于缓存从全局内存或更高层次存储中频繁访问的数据。当线程需要访问数据时,首先会在 L1 Cache 中查找,如果找到则直接获取,避免了从全局内存读取数据的高延迟。
- 线程调度器:负责管理和调度线程块的执行。它会根据线程块的资源需求(如寄存器、共享内存等)和执行状态,将线程块分配到合适的 CUDA 核心上执行。当一个线程块因为等待内存数据等原因暂停执行时,线程调度器会切换到另一个就绪的线程块继续执行,以提高 SM 的利用率。 对于GPU来讲主要下几种存储层次结构,从上到下读取读取速度依次递减、但容量逐渐递增: 当 GPU 线程需要读取数据时,首先会在 L1 Cache 中查找,如果找不到会进一步到 L2 Cache 中查找,还是找不到的话才回去HBM中找。如果我们能够通过一些数据组织策略,让线程能更大概率的在L1 Cache/L2 Cache中找到想要的数据,那么就可以提高数据读取速度。这就呼应了我们的标题,为啥要对矩阵乘法做L2 Cache的命中优化。 在讲解L2缓存命中优化矩阵乘法之前,需要再强调一下block的概念。Triton是围绕block的抽象逻辑进行编程的,block内执行任务靠多线程,同时也有线程间的共享内存,但其和gpu的线程块又不完全等价。gpu的线程块更"物理"一些,在cuda编程时需要手动管理线程块的维度、索引等细节。而Triton的blcok更"逻辑"一些,Triton 编译器会自动处理很多硬件相关的调度和优化,开发者只需关注 block 的计算逻辑和数据处理。 See also Harness Production Agent Engineering Deficit
L2缓存命中优化矩阵乘法¶
读者可以大致先过一下官方教程,看下代码整体结构: Matrix Multiplication:https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html 首先我们来看一下Triton实现的矩阵乘法函数matmul_kernel是如何调用的:
def matmul(a, b, activation=""):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](../ch01-a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), ACTIVATION=activation)
return c
分组计算:利用 L2 Cache 的关键¶
按行主序计算:假设A、B、C都是99的,如果想算出C的第一行,A矩阵一共需要用到9个block,B矩阵一共需要用到81个block,共计90个block。 按group计算:同样是计算出C的9个block,A需要用到27个block,B也需要用到27个block,共计54个block。和行主序计算矩阵C相比,读取数据量其实并没有降低,只不过读取数据的unique数降低了,A矩阵的某些行/B矩阵的某些列会被重复利用。 L2 Cache 优化原理:按group划分的矩阵乘法在概率上能更好的利用时间局部性和空间局部性。从L2 Cache读取数据的速度会比从HBM读数据更块,而那些被重复复用的数据更容易被保存在L2 Cache上,当线程能够从L2 Cache拿到数据时,就不去HBM上找了,进而提高数据读取速度。* 当M大于15000时,行主序的triton matmul性能骤降,是因为L2 Cache此时被打满了,行主序的matmul不利于L2 cache的命中率,更多的数据需要从HBM读导致性能下降。
int8精度下的Triton Matmul¶
关键差异:累加器需要用int32格式,不然int8相乘累加之后非常容易溢出被截断,进而损失精度。 量化方案:
scale_a = (a.abs().max() / 127.0).clamp(min=1e-8)
scale_b = (b.abs().max() / 127.0).clamp(min=1e-8)
a_int8 = (a / scale_a).round().clamp(-127, 127).to(torch.int8)
b_int8 = (b / scale_b).round().clamp(-127, 127).to(torch.int8)
性能测试结论¶
在 RTX 4090 上的测试结果: 1. fp16 场景:做 group 优化的 triton matmul 一直能和 torch 实现的 matmul(cuBLAS实现)性能相近。当 M 大于15000时,行主序的 triton matmul 性能骤降,因为 L2 Cache 此时被打满了。 2. int8 场景:行主序的 triton matmul 和做了 group 优化的 matmul 相比性能虽有下降,但劣化程度比 fp16 更低。这是因为在低精度情况下,L2 cache 能存更多数量的数据,进而提高 L2 Cache 命中率。 3. int8 vs fp16 收益:int8 类型在计算和传输上都有优势,矩阵的 dim M 越大,int8 收益越大。deepseek-v3 这种超大号的模型使用 fp8 收益会非常显著,但小模型上的意义就不是太大了。小模型对低精度也更敏感。
深度分析¶
1. L2 Cache 命中优化的本质:数据访问模式¶
本文的核心贡献是揭示了矩阵乘法中数据访问模式对性能的决定性影响。按 group 计算比按行主序计算更快,不是减少了数据总量,而是提高了数据复用率——相同的 A 矩阵行和 B 矩阵列在 L2 Cache 中被多个计算任务重复使用,从而减少了从 HBM 读取数据的次数。 这个发现对 GPU 编程有普遍的指导意义:GPU 优化的核心不是计算密度,而是内存访问模式。即使计算量相同,不同的数据访问模式可能导致数倍的性能差异。在设计 GPU kernel 时,首先要考虑的不是如何充分利用计算单元,而是如何让数据尽可能长时间地留在 Cache 中。
2. Triton 的抽象层次:平衡易用性与性能¶
Triton 的设计哲学是"让开发者只关心计算逻辑和数据处理,而把硬件调度和优化交给编译器"。这个抽象层次是经过权衡的——比 CUDA 更高(不需要手动管理线程块、网格、线程束),但比 PyTorch 的 torch.matmul 更低(允许精细控制数据布局和内存访问模式)。 triton.autotune 装饰器是这个设计哲学的典型体现:开发者只需定义多组超参数组合,Triton 在运行时自动搜索当前硬件条件下的最优配置。这种运行时自动调优的机制,解决了"不同硬件需要不同超参数"的工程难题。
3. int8 量化:精度与性能的博弈¶
int8 量化在矩阵乘法中的应用揭示了一个深刻的工程矛盾:精度损失 vs 计算/内存收益。 文章中的关键观察是:
- int8 累加容易溢出,需要用 int32 累加器来缓解
- int8 的收益在大矩阵场景下更显著,因为大矩阵对内存带宽的需求更高
- 小模型对低精度更敏感,因为小模型的表示能力本身有限,量化带来的信息损失占比更大 这说明 int8 量化不是一个"通用优化",而是需要根据模型规模、任务类型、数据分布来综合评估的策略。DeepSeek-V3 这样的大模型可以用 FP8/FP4,是因为它们的表示冗余度足够高,可以承受量化误差;而小模型则可能因为量化而导致显著的能力下降。
4. Flash Attention 与算子融合的联系¶
文章指出"flash attention 本质上就是一种算子融合"——这个观点揭示了现代 GPU 优化中一个重要的范式:融合相关的计算步骤,减少中间结果的内存访问。 传统的模型训练中,RMSNorm 这样的简单操作也需要多次读写 HBM,而算子融可以将"读x+读mean → 写x-mean → 读x-mean,读var → 写(x-mean)/var"压缩成"读一次、写一次"的高效实现。这种优化带来的收益(10% 吞吐提升 + 4GB 显存节约)在训练场景中是相当可观的。
实践启示¶
对 GPU 性能优化的启示¶
- 先优化内存访问模式,再优化计算:在进行 GPU kernel 优化时,首先应该分析数据访问模式,而不是一味地增加计算量。按 group 计算矩阵乘法的成功案例说明,优化 Cache 命中率的回报通常高于优化计算密度。
- 利用运行时自动调优:Triton 的
autotune机制是一个值得借鉴的工程实践——不要在代码中硬编码"最优"超参数,而是让系统在运行时自动搜索最优配置。这对于需要适配多种硬件的平台型产品尤为重要。 - 小规模测试可能产生误导:文章指出,在小规模测试中,按 group 计算和按行主序计算的性能差异不明显,因为 L2 Cache 足够容纳所有数据。这提醒我们,性能测试必须在真实规模下进行,小规模基准测试的结果可能完全无法推广到生产环境。
对模型训练工程的启示¶
- 算子融合是性价比最高的优化之一:相比硬件升级或算法改进,算子融合的投入产出比通常更高。即使只对 RMSNorm 这类"小操作"做融合,也能带来 10% 的吞吐提升和 4GB 的显存节约。
- 量化策略需要分场景评估:int8 量化在大模型上收益明显,在小模型上则要谨慎。这提示我们在实际项目中,应该建立量化敏感性评估流程,对每个模型/每个任务进行量化前后的效果对比,而不是一刀切地应用量化。
- 关注 L2 Cache 的容量边界:文章中的性能骤降点(M > 15000)对应的是 L2 Cache 被打满的临界点。在实际应用中,应该了解目标硬件的 L2 Cache 容量,合理设计数据分块大小,避免性能突然下降。
对 AI Infrastructure 工程的启示¶
- Triton 是构建 AI 基础设施的好选择:相比 CUDA,Triton 的 Python 前端大大降低了开发门槛,同时编译器自动优化的特性使得代码可以在不同硬件上高效运行。对于需要快速迭代的自定义算子,Triton 是比 CUDA 更务实的选择。
- 混合精度策略的工程价值:文章展示了 fp16 + int8 混合使用的可能性——fp16 用于计算密集型操作,int8 用于内存带宽瓶颈操作。这种混合精度策略可以在保持模型质量的同时,获得接近 int8 的性能收益。
- 性能测试要覆盖多种矩阵规模:文章中 M > 15000 时的性能骤降提醒我们,性能特性可能随输入规模非线性变化。在做性能基准测试时,应该覆盖各种可能的输入规模,而不仅仅是常用的那几个。