系列的前两篇文章写了线性注意力基本的理论和工程(Kernel),但是相关的内容只是针对开发环境的一些讨论,实际上在大规模部署时,我们为了更好地利用资源,往往需要做很多优化,这篇文章就从这些落地优化的角度聊一下线性注意力,以及为什么我们会认为现在线性注意力的infra只是在初级阶段。
量化
LLM.int8() 和 Massive Activations
早在ChatGPT爆火前,就有一些研究讨论了Full-Attention架构在量化时遇到的问题,在LLM.int8()这篇论文中,就提到在Full-Attention的激活值中有部分离群点会影响量化效果,因此建议把离群点和其他值分开处理:
这一现象被称为Massive Activations,目前业界对这个现象进行了大量的研究,其中两个主要因素就是Softmax 和 RoPE。
Softmax 归一化:Softmax有着强制要求“总和为1”的特性,但是在实际的LLM运算中,我们往往有更复杂的需求(比如有时我们输入模型的问题比较简单,一个60层的网络可能在前40层就已经完成了计算,对于后面层我们需要的是一个什么都不做的操作,那么在残差网络的帮助下我们其实需要一个0输出),为了满足这种需求就会出现类似Attention Sink的现象,导致注意力集中于单一无关标记,同时产生massive activation。
RoPE位置编码:在文章Massive Values in Self-Attention Modules are the Key to Contextual Knowledge Understanding就提到:RoPE在Query/Key的通道维度对上应用了带有频率性角度参数的旋转变换,该变换依赖相对位置,同时在某些维度的投影上显著放大或缩小分量,从而导致部分维度“成簇集中”,出现了异常的spike值。 实际上,该文章也在实验中发现凡是采用 RoPE 的模型(如 LLaMA、Qwen、Gemma、Falcon 等)都会展示这种 massive values stripes,而不使用 RoPE 的模型(如 GPT-2、OPT、Jamba)则没有这一现象。
这段时间主要在做推理加速相关的工作,也做了一些实验,准备写点文章做一些记录。这篇文章就从“移动计算”和“移动存储”的视角开个头,聊聊并行策略的动态切换和KV Cache 流动管理的一些实践,梳理一下在 Agentic Workflow 日益普及的当下,我们能否通过对算力、存储与带宽的调度,在大规模集群上更好地提升推理效率。
在传统的大数据时代,“计算中心”还是“数据中心”一直是个有趣的问题。随着技术的发展,大家也逐渐总结出了 “Move Compute to Data” 的实践经验:因为 SSD 很贵、IO 很贵、带宽也很贵。相比之下,不如把代码调度过去,使用本地的 CPU 进行计算。因此,调度器的核心任务是保证 Data Locality,尽量把计算分发到数据所在的硬盘旁边。
但在 LLM 的时代,我们面对的是超大模型在 GPU 上的推理,移动计算已经变成了移动模型(权重)这个巨无霸。相比之下,似乎还是移动数据(KV Cache)更现实一些——但什么才是最合适的解法呢?
1. 需求的演变:从 Chat 到 Agent
系统架构的演进,必然是随着业务不断演变的。
第一阶段:单轮指令
- 特征:用户发送一句指令任务(如翻译、总结),模型执行并回复。请求之间几乎独立。
- 瓶颈:纯粹的算力(Prefill)或显存带宽(Decoding)。
- 调度:最简单的加权轮询。此时 KV Cache 的存在感很低,除了每台机器都有的 System Prompt,几乎没有状态复用的需求,我们可以随意把请求调度到任一台机器上执行。
第二阶段:多轮对话
- 特征:多轮对话可以通过 Prefix Caching 复用前面上文。Context 越来越长,每次对话对应一次交互。
- 瓶颈:显存容量 –> Prefill 时间
- 调度:开始引入 Affinity(亲和性) 调度——为了命中 Cache,我们尽量把请求发给存储了该用户历史数据的节点。也就是 “Move Compute to Data”,因为此时 Prefill(重算数据)太贵,而搬运 KV Cache 也还没在大规模集群中普及。但是这也会导致热点问题。
第三阶段:Agentic Workflow
- 特征:系统提示词、工具定义、思维链、上下文可以在并行的分支任务中共享,多轮对话可以并行执行,但对应一次交互(用户从感知多次交互变成感知任务完成)。
- 瓶颈:极其复杂的依赖关系,以及直接复用 KV Cache 带来的负载不均衡。
- 调度(面临的问题):如果 “Move Compute to Data” 会导致严重的热点问题——存有热门 Context 的节点会被打爆,而其他空闲节点却因为没有数据而帮不上忙。
可以看到,随着需求的变化,我们不再只关注单次请求的 TTFT/TPOT,而是开始关注整个 Agent 任务的 任务完成时间 以及系统的 总吞吐量。
本文重点参考了文章Mamba: The Hard Way和开源项目flash-linear-attention。
Prefill与Decoding
我们都知道,在Attention的计算中,Prefill和Decoding是两个不同的场景,具体特性如下:
| 特性 | Prefill | Decoding |
|---|---|---|
| 输入 | 长序列(长度 $L$) | 1 个新 token + 历史状态 |
| 常见瓶颈 | Compute bound(Tensor Core 利用) | Memory/Latency bound(状态读写 + 小矩阵计算) |
在回忆一下理论篇的介绍,特别是关于Mamba章节中的推导,常见的Linear Attention有两种表示格式:
矩阵格式(Attention视角):
$$ y_i = \sum_{j=0}^i (CausalMask(Q_i K_j^T)) V_j $$
递推格式(SSM视角):
$$ h_t = A_t h_{t-1} + B_t x_t $$ $$ y_t = C h_t $$
其中Decoding的算子可以比较直接的使用递归格式进行计算,因此我们本文重点还是看Prefill的实现。
Linear Attention常见算法
在Linear Attention的计算中,有一些常见的思路,本章结合flash-linear-attention的实现,对这些思路进行讲解。
Prefix Scan / Cumsum 前缀和
算法简介
先讲一下什么是Prefix Scan算法,简单来说Prefix是针对有结合性的算子提出的一种优化方式
假设有 $y_t = x_1 ⊗ x_2 ⊗ x_3 … ⊗ x_t$, 如果⊗支持结合律,那么y_t就可以用一个简单的Reduce进行求解
准备写一些关于线性注意力的文章,对相关理论和工程(Kernel)做一些梳理,这一篇是关于基础理论的。
Softmax Attention到线性注意力
Softmax Attention与O(N^2)复杂度
标准的Transformer使用的是Softmax Attention。给定查询(Query)、键(Key)、值(Value)矩阵 $Q, K, V \in \mathbb{R}^{N \times d}$,其中 $N$ 是序列长度,$d$ 是特征维度(通常 $d \ll N$)。Attention的计算公式为:
$$ Attention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V $$
让我们仔细分析一下这个计算过程的维度变化:
- 计算相似度矩阵:$QK^T$。
- $Q$ 是 $N \times d$,$K^T$ 是 $d \times N$。
- 相乘得到 $N \times N$ 的矩阵。这一步的计算复杂度是 $O(N^2 d)$。
- 应用Softmax:对每一行进行归一化,维度不变,仍为 $N \times N$。
- 加权求和:乘以 $V$。
- $N \times N$ 的矩阵乘以 $N \times d$ 的矩阵 $V$。
- 结果是 $N \times d$。这一步的计算复杂度也是 $O(N^2 d)$。
瓶颈所在: 由于Softmax是非线性的,我们必须先完整地计算出 $N \times N$ 的Attention Matrix。
- 计算复杂度:$O(N^2 d)$。随着序列长度 $N$ 的增加,计算量呈平方级增长。
- 显存占用:$O(N^2)$(如果不使用FlashAttention等优化)。需要存储 $N \times N$ 的矩阵。
- KV Cache:在推理时,为了避免重复计算,我们需要缓存所有的 $K$ 和 $V$,显存占用为 $O(Nd)$。
之前大部分情况下这个问题不大, 但是随着Agent能力的发展,Context越来越长,这个 $O(N^2)$ 的复杂度就变得有些过高了。
背景
从年初 R1 火起来以后,看到了很多关于智能本质的讨论,目前大部分讨论似乎都认为:语言才是智能的基础,是通往 AGI 的路径。
比如张小珺[1-3]今年采访了很多大佬,就有两个这样的例子:
-
杨植麟[1]:认为在现有范式下,多模态能力往往难以提升模型的“智商”,甚至可能损伤模型已有的语言智能。他在接受采访时提到:“如果你想给模型加多模态能力,你需要确保这个多模态能力不要损伤它的‘脑子’。多模态能做到不损伤已经很好了。”他指出,在多模态模式下希望模型和纯文本模式共用一个“大脑”——也即多模态部分应尽量借用文本模型已有的智力,而不是开启另一套全新参数,否则可能丢掉原来文本部分学到的能力。
-
姚顺雨[2]也强调语言在通向通用智能方面更具潜力。他起初从事计算机视觉研究,但后来直觉告诉他语言才是更核心、更有潜力的方向,因此读博后转向了语言模型的研究。他指出,语言是人类为了实现认知泛化而发明的最重要工具,具有生成和推理的闭环特性,这使其成为构建通用智能系统的关键媒介。
不过还是有一些其他的观点,认为多模态能力的潜力可能被训练范式所限制,而非多模态本身毫无助益。例如阶跃星辰的张祥雨[3]就提到:
-
图文数据效果不好的原因是噪声数据:如果简单用图文混合数据进行训练,但不解决思维链(CoT)推理或任务复杂度的问题,模型的学习可能是有害的。在缺乏正确推理指导的情况下,模型每一步可能得不到有效信息,产生错乱的梯度,训练效果要么毫无提升,要么甚至更糟。这与他在一个万亿参数多模态模型项目中的发现一致:模型规模增大后,其数学和逻辑推理能力不仅没有提升,反而在达到平台期后开始下降。原因在于模型倾向于跳步直接给出答案而非踏实推演,从而累积误差。简单扩大模型或混合数据并不能自然地融合视觉与文本能力,背后缺少关键环节。
上一篇文章我们聊了一下Ray与LLM强化学习框架设计,探讨了其架构的演进,但是没有提到为什么框架会往这个方向逐渐演进而不是一开始就使用现在的设计。这里面自然有实践中不断优化的结果,但是也是和整个LLM RL需求的变化密切相关的。
因此,本文会主要讨论一下LLM强化学习中,算法与系统框架是如何相互影响、协同演化的。首先分析两个相对成熟的协同设计案例,然后讨论几个正在不够成熟、但是在笔者看来很有潜力的优化方向。
算法与框架协同演化的典型案例
先讲两个目前已经相对得到了共识的问题:
案例一:推理模型驱动的分离式架构设计
在前一篇文章中提到过,之前的强化学习还是希望尽量on-policy的,因此早期的强化学习系统倾向于采用on-policy算法配合co-located的架构设计。这种选择有其合理性:算法层面,on-policy确实具备样本效率优势;系统层面,在CoT和Test-time Scaling兴起之前,模型输出长度差异相对较小,推理引擎产生的计算空泡还算可控。尽管资源利用率的问题客观存在,但on-policy算法的优势在一定程度上抵消了这种系统层面的低效。
不过,从去年o1发布开始,业界开始重视Test time scaling,加上R1的发布又给推理模型点了一把火,这种范式的改变打破了之前的平衡:模型在推理阶段生成的文本长度显著增加,且不同样本间的长度差异极为悬殊。在这种新的计算模式下,on-policy的算法优势已无法弥补系统资源的巨大浪费——所有并行环境必须等待最长推理任务完成才能进入下一轮迭代,这种同步约束造成了严重的吞吐量瓶颈和大量计算资源闲置。
最近LLM强化学习框架发展特别快,Ray作为被ChatGPT带火的框架,在LLM各个训练阶段中,RL阶段的应用应该是最多的。写篇文章记录一下这块发展的脉络和一些看法。
从Google Pathways说起
讨论Ray和RL系统,得从Google的Pathways系统开始:2021年Google提出了Pathways作为下一代AI架构和分布式ML平台,在相关文献中详细讨论了Single-Controller + MPMD的系统设计。
Single-Controller(单控制器)是指用一个中央协调器来管理整个分布式计算流程的架构模式。在这种设计中,有一个主控制节点负责整个计算图的执行,包括任务分发、资源调度、状态监控等。
Multiple-Controller(多控制器)则是指使用多个分布式控制节点来协同管理计算任务的架构模式。在这种设计中,没有单一的中央协调器,而是由多个控制器节点分别负责不同的子系统或计算子图,通过分布式协调协议来实现全局一致性。
在Ray中的Driver Process就可以被作为一个典型的Single Controller来启动不同的任务程序,而通过torchrun运行的PyTorch DDP分布式计算则是在每个node上各自执行自己的程序则属于典型的Multiple Controller范式。
做Ray Platform也快2年了,遇到过各种的问题,整理一些踩过的坑看一下。
先从我们自己最常用的Ray Data开始,看看最常见的OOM/OOD问题,这个问题很多时候都是和反压相关的。
说是Ray Data,不过这里的反压不止一层,大概包括下面几个地方:
- Ray Core Generator:针对Ray Generators的控制,防止后台生成的数据过多导致OOM/OOD。
- Streaming Executor + Resource Allocator:
- 针对正在执行的任务,控制生成结果的速度,避免单个任务生成的数据过多导致OOM/OOD。
- 针对单个Operator,控制提交任务的数量,避免在资源紧张时提交新任务。
- Backpressure Policies: 其他关于任务提交的反压规则。
下面我们逐层分析这些机制的实现。
Ray Core Generator:对象数量反压
Ray Generator 类似Python Generator,用来作为迭代器进行遍历,但是和Python Generator有一个很大的不同在于:Ray Generator使用ObjectRefGenerator在后台持续执行。也就是说如果Ray Data的单个read_task需要读取一个很大的文件时,没法通过控制拉取任务产出的速度来控制任务的内存占用。(不管下游是否主动拉取,都会持续读取新的数据block。)
准备对DeepSeek的开源项目整理一些文档,也顺便强化一下记忆,先从FlashMLA开始。
FlashMLA是DeepSeek开源的MLA算子实现,这个实现主要给inference decoding用的,Training和prefill应该是另外一个算子。
先拿下面的图表示一下MLA算子是在计算一个什么东西,这篇文章就不讲具体的推导了,反正这个算子大概就是下面的2个GEMM算子的融合。需要注意的是:
- 这里矩阵K和矩阵V的共享一部分参数。
- 图里只画显示了一个Query Head和一对KV Head的计算。在实际计算中还要num_kv_head和batch_size两个维度。
- 两个GEMM中间其实还有一个sotfmax,不过这里可以通过online softmax算法把这块逻辑独立处理分块处理,所以不影响主流程。
Kernel的调用主要分两部分
- 调用
get_mla_metadata来计算一些metadata,用来优化kernel的执行 - 调用
flash_mla_with_kvcache进行计算
在进入调用前,先大概说一下FlashMLA计算的拆分逻辑。这块和FlashDecoding很像,并没有要求一个thread-block必须处理一个完整的sequence,而是通过一个负载均衡算法,把所有的sequence放到一起,然后拆分成一个个的sequence-block,然后每个thread-block就去处理分配给它的那些block的计算,最后再把这些thread-block的结果用合并,得到正确的输出。