线性代数在 LLM 里到底做了什么?
从表示、变换、交互到低秩结构
最近在系统回顾 LLM 相关材料的时候,我发现一个很微妙的问题:很多概念我好像都“知道”,也能在熟悉的语境里使用;但一旦要求我从头解释,解释到它为什么是这个形式、解决了什么问题、在 LLM 里具体扮演什么角色,语言就会变得不够利落。
要重新理解 LLM,线性代数几乎是绕不开的第一层。LLM 里最基本的对象一开始就是向量和矩阵:hidden state 是向量,linear layer 是矩阵乘法,attention 里有 dot product,LoRA 里有 low-rank update。这些说法都对,但它们更像是入口。真正需要想清楚的是:为什么信息可以被表示成向量?矩阵乘法到底在变换什么?attention 为什么能让不同 token 交互?rank 和 SVD 为什么会自然地连接到压缩和 LoRA?
这篇blog想顺着这些问题,重新整理一次线性代数在 LLM 里的角色。我更想做的是把那些熟悉但容易说得含糊的概念拆开:从表示开始,看向量如何承载信息;再看矩阵如何在每个位置内部改写表示;然后看 attention 如何用匹配和加权混合让 token 之间通信;最后再看训练出来的矩阵里,为什么会出现可以被低秩结构捕捉的方向。
如果最后能把“LLM 里到处都是矩阵乘法”这句话讲得更具体一点,这篇文章就达到目的了。矩阵乘法当然重要,但我更关心的是它背后的那套线性代数语言:表示、变换、交互,以及训练后逐渐显现出来的结构。
简单回顾:LLM 的基本计算流程
先简单回顾一下一个 decoder-only Transformer 的一次 forward pass。从最外层看,模型把一串 token ids 变成下一个 token 的概率分布:
\[\text{token ids} \rightarrow \text{embeddings} \rightarrow \text{transformer blocks} \rightarrow \text{logits} \rightarrow \text{next-token probabilities}\]这条流程里反复被更新的对象,是一整段序列的 hidden states。模型先通过 embedding table 把每个 token id 查成一个 $d$ 维向量;位置相关的信息也会进入这条状态流或 attention 计算中。假设序列长度是 $n$,模型宽度是 $d$,那么最初的状态矩阵可以写成:
\[H^{(0)} \in \mathbb{R}^{n \times d}\]接下来,LLM 的主干计算可以看成这个状态矩阵被每个 transformer block 反复更新。第 $\ell$ 层接收 $H^{(\ell)} \in \mathbb{R}^{n \times d}$,输出下一层的 $H^{(\ell+1)}$:
\[H^{(\ell+1)} = \mathrm{Block}_{\ell}(H^{(\ell)})\]在 block 内部,attention 和 MLP 会把状态矩阵投影到不同用途、不同维度的空间里计算;这些计算结果最后会投回模型宽度 $d$,再通过 residual connection 加回原来的状态矩阵。所以层与层之间传递的 hidden states 通常保持 $n \times d$ 的形状。
一个常见的 pre-norm block 可以按两步看:先做 attention 更新,再做 MLP 更新。令
\[X = \mathrm{Norm}(H^{(\ell)})\]attention 先把同一个状态矩阵 $X$ 投影成三组矩阵:
\[Q = XW_Q,\quad K = XW_K,\quad V = XW_V\]这里 $Q$、$K$、$V$ 仍然按位置排列。差别在于,它们来自不同的投影矩阵,服务于不同用途:$Q$ 和 $K$ 用来产生位置之间的匹配分数,$V$ 是之后要被读取和混合的信息。
接着,模型用 $QK^\top$ 一次性算出所有位置之间的匹配分数:
\[S = \frac{QK^\top}{\sqrt{d_k}} + M\]其中 $M$ 是 causal mask,用来挡住未来位置。对 $S$ 的每一行做 softmax,就得到每个位置对可见位置的读取权重:
\[A = \mathrm{softmax}(S)\]然后用这些权重混合 values:
\[C = AV\]实际模型里通常还有 multi-head attention 和一个 output projection。这里可以把它压缩成一个 attention 更新量:
\[U_{\mathrm{attn}} = C W_O\]这个更新量会通过 residual connection 加回原来的状态矩阵:
\[\widetilde{H} = H^{(\ell)} + U_{\mathrm{attn}}\]到这里,attention 做完了一件事:用矩阵投影得到比较和读取所需的表示,用 $QK^\top$ 产生跨位置的权重,再用 $AV$ 把别的位置的信息混回每个位置。
接下来是 MLP。它对每个位置的 state 做同一组变换,主要负责位置内部的特征加工。令
\[R = \mathrm{Norm}(\widetilde{H})\]一个简化的 MLP 可以写成:
\[\mathrm{MLP}(R) = \sigma(RW_1)W_2\]它先把每个 state 投影到中间维度,经过非线性函数,再投影回模型宽度 $d$。这个更新量同样加回状态流:
\[H^{(\ell+1)} = \widetilde{H} + \mathrm{MLP}(R)\]这样,一个 transformer block 就可以看成:先通过 attention 做跨位置的信息读取,再通过 MLP 做每个位置内部的特征加工,两个更新都通过 residual connection 写回同一个状态矩阵。
经过很多层这样的更新之后,模型拿某个位置的最终 hidden state 产生 vocabulary 上的 logits:
\[z_t = h_t^{(L)} W_{\mathrm{out}}\]这里 $h_t^{(L)} \in \mathbb{R}^d$ 是第 $t$ 个位置最后一层的 state,$W_{\mathrm{out}} \in \mathbb{R}^{d \times \lvert \mathcal{V} \rvert}$ 把它映射到 vocabulary 大小的空间。可以把每个 logit 理解成最终 state 和某个输出 token 方向之间的一次线性打分。softmax 之后,这些分数变成下一个 token 的概率分布:
\[p(x_{t+1}\mid x_{\le t}) = \mathrm{softmax}(z_t)\]沿着这个 forward pass 回看,线性代数贯穿了整条状态更新链。Embedding table 先把离散的 token ids 变成 $H^{(0)}$,也就是把符号放进一个可以连续计算的向量空间。每个 transformer block 接过上一层的 $H^{(\ell)}$,通过 $W_Q,W_K,W_V$ 把同一组 states 投影成三组矩阵;$QK^\top$ 产生位置之间的匹配分数;softmax 把这些分数变成读取权重;$AV$ 按权重把 value 里的信息混回每个位置。attention 的输出再投回模型宽度 $d$,作为一次更新写回状态矩阵。随后,MLP 在每个位置内部继续处理这份 state,通过 $W_1$、非线性函数和 $W_2$ 形成另一份更新量。经过很多层这样的更新之后,某个位置的最终 hidden state 再通过 $W_{\mathrm{out}}$ 被读成 vocabulary 上的 logits。宏观上,forward pass 大致就是状态矩阵 $H$ 在一系列操作之间不断被投影、比较、混合、回写和打分。
这个宏观图景只是第一层。接下来要往细处看:每一步到底在做什么,为什么这种操作合理,它改变了哪些信息,又带来了什么意义。一个 hidden state 如何用向量承载当前位置信息?矩阵乘法如何在每个位置内部重写表示?attention 如何让不同位置产生匹配并交换信息?加权混合如何把上下文写回每个位置?先从最小的对象开始,也就是一个 hidden state 为什么可以是向量。
从 token id 到向量状态
LLM 接触文本时,最开始看到的是一串 token ids。
经过 tokenizer 后,一段文本会变成整数序列:
\[x_1,x_2,\ldots,x_n\]其中每个 $x_i$ 都是 vocabulary 里的一个索引:
\[x_i \in \{1,\ldots,\lvert \mathcal{V} \rvert\}\]这些整数适合程序查表,却没有自然的几何意义。比如 token id 1234 和 token id 1235 相差 1,这个差值只说明它们在编号上相邻,不能说明两个 token 在语义、语法或用法上更接近。模型后续要做的是连续计算:加权、投影、比较、更新。整数索引本身无法承担这些操作。
Embedding lookup 给每个 token id 分配一个向量状态。设 vocabulary 大小为 $\lvert \mathcal{V} \rvert$,embedding table 是
\[E \in \mathbb{R}^{\lvert \mathcal{V} \rvert \times d}\]如果第 $i$ 个位置的 token id 是 $x_i$,模型取出第 $x_i$ 行:
\[e_i = E_{x_i} \in \mathbb{R}^d\]这个 $e_i$ 是第 $i$ 个位置进入模型主干之前的初始状态。它有 $d$ 个坐标,可以被加上更新量,可以被矩阵投影,也可以在训练中通过梯度逐步调整。对一整段长度为 $n$ 的序列,把这些向量按位置堆起来,就得到最初的状态矩阵:
\[H^{(0)} = \begin{pmatrix} e_1^\top \\\\ e_2^\top \\\\ \vdots \\\\ e_n^\top \end{pmatrix} \in \mathbb{R}^{n \times d}\]这里每一行对应一个位置的初始 hidden state。进入 transformer block 之后,模型处理的核心对象就是这份状态矩阵。在第 $\ell$ 层,整段序列的状态写成
\[H^{(\ell)} \in \mathbb{R}^{n \times d}\]第 $i$ 行可以写成 $(h_i^{(\ell)})^\top$,对应的列向量
\[h_i^{(\ell)} \in \mathbb{R}^d\]就是第 $i$ 个位置在这一层的 hidden state。随着层数增加,$h_i^{(\ell)}$ 会不断吸收当前位置的信息、上下文的信息,以及前面层已经加工出的特征。第 $i$ 行仍然对应第 $i$ 个位置,但这一行里的内容已经不再只是初始 token embedding。
这就是向量状态在 LLM 里的基本角色:它是每个位置的工作区。Embedding 先给这个位置一个初始 state;每个 transformer block 再根据上下文和当前位置的内容,往这个 state 里写入新的更新量;后续层也会从这个 state 里读出接下来计算需要的信号。
这个设计对模型很方便。每个位置都有一份同宽度的连续状态,所以 attention 的输出、MLP 的输出、residual connection 都能写回同一个对象。这个对象有坐标,所以可以被加法更新;有方向,所以可以被线性层读取;作为矩阵的一行,所以可以和整段序列一起批量计算。模型不需要把某个维度固定解释成某个具体的人类概念。训练会调整整个空间,让某些方向、某些组合方式逐渐变得有用。
看一个小例子。假设某个 hidden state 是三维的:
\[h = \begin{pmatrix} 3 \\\\ -2 \\\\ 4 \end{pmatrix}\]如果用方向
\[u_1 = \begin{pmatrix} 1 \\\\ 0 \\\\ 0 \end{pmatrix}\]那么
\[u_1^\top h = 3\]读出的就是 $h$ 在第一条坐标轴上的分量。换成一个混合了多个坐标的方向
\[u_2 = \frac{1}{\sqrt{3}} \begin{pmatrix} 1 \\\\ -1 \\\\ 1 \end{pmatrix}\]则有
\[u_2^\top h = \frac{1}{\sqrt{3}}(3 + 2 + 4) = \frac{9}{\sqrt{3}}\]读出的就是 $h$ 沿这个混合方向的响应。这个方向同时看第一和第三个坐标,也会把第二个坐标反向计入。这个例子停在一个基础事实上:给定方向 $u$,计算 $u^\top h$,就是在读 hidden state 沿这个方向有多强。
一般地,对任意
\[u \in \mathbb{R}^d\]线性读取可以写成
\[s = u^\top h_i^{(\ell)}\]从代数上看,这是对应坐标相乘再求和:
\[u^\top h_i^{(\ell)} = \sum_{r=1}^{d} u_r h_{i,r}^{(\ell)}\]如果 $u$ 是单位向量,这个值就是 $h_i^{(\ell)}$ 在 $u$ 方向上的有符号投影长度。直觉上,可以把 $h_i^{(\ell)}$ 想成当前位置的一份工作记忆,把 $u$ 想成一个读取方向。计算 $u^\top h_i^{(\ell)}$,就像是在问:这份工作记忆在这个方向上有多强的响应?
这个观察很小,但它解释了向量状态为什么适合作为模型内部表示:同一个 state 可以被很多方向读取,每个方向都给出一个不同的线性信号(图 1)。
实际模型会把很多读取方向组织成矩阵,也会把整段序列的 states 一次性送进矩阵乘法。到这一层,文章需要的对象已经齐了:一整段序列的状态矩阵 $H$,以及每一行里可以被方向读取的 hidden state。接下来,问题从“状态是什么”转向“矩阵怎样使用这些状态”。
矩阵在 LLM 里做的两类工作
有了状态矩阵 $H \in \mathbb{R}^{n \times d}$ 之后,再看“LLM 里到处都是矩阵乘法”这句话,会更具体一些。从信息流的角度看,矩阵乘法先分成两类:一类在每个 token 自己的位置上改写表示,另一类让不同 token 之间交换信息。
线性层在每个 token 内部改写表示
先看最常见的形式:
\[Y = HW,\quad W \in \mathbb{R}^{d \times m}\]把 $H$ 按行写开:
\[H = \begin{pmatrix} h_1^\top \\\\ h_2^\top \\\\ \vdots \\\\ h_n^\top \end{pmatrix}\]那么
\[HW = \begin{pmatrix} h_1^\top W \\\\ h_2^\top W \\\\ \vdots \\\\ h_n^\top W \end{pmatrix}\]第 $i$ 行的输出只由第 $i$ 行的输入决定:
\[y_i = W^\top h_i\]如果第 3 个 token 的 hidden state 改了,$HW$ 的第 3 行会变;其他行不会因为这一次乘法直接改变。所有位置共享同一个 $W$,所以每个位置都经过同一种线性变换。这个操作像是给每个 token 单独做一次相同的重新描述。
这种重新描述可以改变维度,比如把 $d$ 维 state 投到 $m$ 维;也可以改变读取方式,因为 $W$ 的每一列都可以看成一个学到的读取方向。它还可以改变用途:同一份 hidden state 可以被读成 query、key、value、MLP 中间表示,或者 vocabulary logits。
Attention 前面的
\[Q = HW_Q,\quad K = HW_K,\quad V = HW_V\]仍然属于这种 token 内变换。对每个位置来说,$W_Q$ 读出读取请求,$W_K$ 读出匹配信号,$W_V$ 读出准备输出的内容。到这一步为止,每个位置还只是在处理自己的 state。
MLP 也是这种 token 内变换:
\[\sigma(HW_1)W_2\]它对每个位置使用同一套 $W_1$ 和 $W_2$。$W_1$ 把 state 投到中间维度,非线性函数改变每个坐标的响应,$W_2$ 再投回模型宽度。up projection 和 down projection,比如 $HW_u$ 和 $HW_d$,也属于同一类:先把每个位置的表示展开到另一个维度,再压回目标维度。输出层
\[z_t = h_t^{(L)}W_{\mathrm{out}}\]读的是最后一个位置自己的 state,结果是 vocabulary 上的 logits。
所以,$HW$ 这一类矩阵乘法主要是在每个 token 内部改写表示。它让同一个位置的 state 进入新的坐标系统,产生新的可读信号,或者改变维度。它本身没有把别的位置的信息混进来。
attention 产生 token 之间的信息交互
另一类矩阵乘法出现在 attention 里。在标准 decoder-only Transformer 中,不同 token 的通信主要发生在 self-attention。$Q$、$K$、$V$ 已经由 token 内变换准备好,接下来模型需要决定两件事:每个位置应该从哪些过去位置读取信息,以及读到的信息怎样写回当前位置。
这里可以把每个位置想成同时带着三份东西。$q_i$ 是位置 $i$ 生成的读取请求,表示当前位置希望从上下文里得到哪类信号;$k_j$ 是位置 $j$ 暴露出来的匹配信号,表示它可以响应哪类读取请求;$v_j$ 是位置 $j$ 准备输出给其他位置的内容。query 和 key 做路由,value 提供被路由过去的内容。
先看 $QK^\top$ 到底乘出了什么。为了简化,先不写 multi-head,只看一个 head。此时
\[Q,K \in \mathbb{R}^{n \times d_k},\quad V \in \mathbb{R}^{n \times d_v}\]把 $Q$ 和 $K$ 按行写成:
\[Q = \begin{pmatrix} q_1^\top \\\\ q_2^\top \\\\ \vdots \\\\ q_n^\top \end{pmatrix}, \quad K = \begin{pmatrix} k_1^\top \\\\ k_2^\top \\\\ \vdots \\\\ k_n^\top \end{pmatrix}\]那么 $K^\top$ 的列就是这些 key vectors:
\[K^\top = \begin{pmatrix} k_1 & k_2 & \cdots & k_n \end{pmatrix}\]所以 $QK^\top$ 展开后是:
\[S = QK^\top = \begin{pmatrix} q_1^\top k_1 & q_1^\top k_2 & \cdots & q_1^\top k_n \\\\ q_2^\top k_1 & q_2^\top k_2 & \cdots & q_2^\top k_n \\\\ \vdots & \vdots & \ddots & \vdots \\\\ q_n^\top k_1 & q_n^\top k_2 & \cdots & q_n^\top k_n \end{pmatrix}\]这里的第 $i$ 行是
\[\begin{pmatrix} q_i^\top k_1 & q_i^\top k_2 & \cdots & q_i^\top k_n \end{pmatrix}\]这一行只使用同一个 query $q_i$,然后依次和所有 key vectors 做 dot product。第 $j$ 个位置上的数 $q_i^\top k_j$ 是位置 $i$ 对位置 $j$ 的 attention score。decoder-only 模型会用 causal mask 挡住未来位置,所以第 $i$ 行真正保留下来的是:
\[q_i^\top k_1,\ q_i^\top k_2,\ \ldots,\ q_i^\top k_i\]具体一点,假设一句话是 Alice dropped the glass because she ...。当模型处理 she 这个位置时,这一行可以给前面的 Alice、dropped、glass、because 以及当前位置自己打分;she 后面的 token 属于未来位置。mask 会在 softmax 前把未来列的 logits 变成不可选,概念上可以理解成加上 $-\infty$,所以 softmax 后这些未来位置的权重是 0。
这张 score table 可以通过一次矩阵乘法得到。不同位置之间的打分没有递归依赖,causal mask 负责约束每一行最终能选择哪些列。这个形状解释了 self-attention 为什么适合并行化:模型可以先把所有 query 和 key 算出来,再用一次大矩阵乘法同时得到所有 pairwise scores。
这些数还只是读取分数,内容还在 $V$ 里。接下来模型先做缩放和 mask:
\[L = \frac{QK^\top}{\sqrt{d_k}} + M\]然后对每一行做 softmax:
\[A = \mathrm{softmax}(L)\]softmax 是按行做的,所以 $A$ 的第 $i$ 行是一组权重:
\[a_{i1},a_{i2},\ldots,a_{ii}\]这些权重加起来是 1。它们可以理解成位置 $i$ 对过去各个位置的读取比例。比如某一层某个 head 里,位置 $i$ 对三个可见位置的权重可能是
\[a_{i1}=0.1,\quad a_{i2}=0.7,\quad a_{i3}=0.2\]这表示当前位置主要读取第 2 个位置,也少量读取第 1 个和第 3 个位置。
现在再看乘以 $V$ 以后发生了什么。把 $V$ 按行写成:
\[V = \begin{pmatrix} v_1^\top \\\\ v_2^\top \\\\ \vdots \\\\ v_n^\top \end{pmatrix}\]矩阵乘法
\[C = AV\]会把 $A$ 的每一行权重用到 $V$ 的各行上。展开看,$C$ 的第 $i$ 行是:
\[c_i^\top = a_{i1}v_1^\top + a_{i2}v_2^\top + \cdots + a_{ii}v_i^\top\]用列向量写,就是:
\[c_i = \sum_{j \le i} a_{ij}v_j\]如果沿用刚才的三个权重,混合结果就是
\[c_i = 0.1v_1 + 0.7v_2 + 0.2v_3\]这个式子把信息交换说得很清楚:新的 $c_i$ 由多个位置的 value vectors 混合得到。$a_{ij}$ 越大,位置 $j$ 的 $v_j$ 对位置 $i$ 的新内容影响越大。随后 $c_i$ 会经过 output projection,并通过 residual connection 写回位置 $i$ 的状态。到这里,位置 $i$ 的 hidden state 里就带上了其他 token 的信息。
所以 attention 里的矩阵乘法可以概括为两步。$QK^\top$ 生成一张 score table,每个 $q_i^\top k_j$ 都是位置 $i$ 对位置 $j$ 的读取打分;softmax 把这些分数变成读取权重;$AV$ 再把这些权重真正作用到 value vectors 上,把被读取位置的内容混回当前位置。$QK^\top$ 决定读谁,$AV$ 负责把读到的内容带回来。
这样看,LLM 里的矩阵乘法有清楚分工。$HW$ 形式的线性层负责改写表示:同一个 token 的 state 被送进新的坐标系统。Self-attention 负责交换信息:$QK^\top$ 先产生跨位置的读取分数,$AV$ 再按这些分数混合别的位置的信息。这个区分很重要,因为它解释了 Transformer 的信息流:大多数线性层负责改写当前 token 的表示,attention 负责让 token 之间通信。
有了这个分工,dot product 的位置也更清楚了。$q_i^\top k_j$ 是 token 间交互里的打分步骤。下一步要看的是,为什么两个向量的内积可以成为这样的匹配分数。
dot product 为什么能当作匹配分数
上一节里,$q_i^\top k_j$ 出现在 score matrix 的第 $i,j$ 个位置。它决定位置 $i$ 对位置 $j$ 的读取倾向。现在的问题是:为什么两个向量相乘再求和,能承担这样的打分任务?
dot product 同时看方向和长度
先从代数上看。对两个 $d_k$ 维向量
\[q = \begin{pmatrix} q_1 \\\\ q_2 \\\\ \vdots \\\\ q_{d_k} \end{pmatrix}, \quad k = \begin{pmatrix} k_1 \\\\ k_2 \\\\ \vdots \\\\ k_{d_k} \end{pmatrix}\]它们的 dot product 是
\[q^\top k = \sum_{r=1}^{d_k} q_r k_r\]这个式子做的事情很直接:逐个坐标相乘,再把结果加起来。如果两个向量在很多坐标上同号,而且幅度都比较大,这些乘积会累加成一个更大的正数。如果很多坐标方向相反,乘积会抵消,甚至得到负数。如果它们在主要方向上关系弱,结果会接近 0。
看一个三维例子。令
\[q = \begin{pmatrix} 2 \\\\ 1 \\\\ -1 \end{pmatrix}, \quad k_1 = \begin{pmatrix} 2 \\\\ 1 \\\\ -1 \end{pmatrix}, \quad k_2 = \begin{pmatrix} -1 \\\\ 2 \\\\ 1 \end{pmatrix}\]那么
\[q^\top k_1 = 2\cdot 2 + 1\cdot 1 + (-1)\cdot(-1) = 6\]而
\[q^\top k_2 = 2\cdot(-1) + 1\cdot 2 + (-1)\cdot 1 = -1\]同一个 query $q$ 面对两个不同的 key,给出了不同分数。$k_1$ 和 $q$ 在三个坐标上的贡献都为正,所以分数高;$k_2$ 有的坐标提供正贡献,有的坐标提供负贡献,最后得到较低的分数。这个例子展示了 dot product 作为打分函数的基本行为:它奖励方向一致的成分,压低方向冲突的成分。
从几何上看,dot product 还可以写成
\[q^\top k = \lVert q\rVert \,\lVert k\rVert \cos\theta\]其中 $\theta$ 是两个向量之间的夹角。这个形式说明,dot product 同时看两件事:方向有多接近,以及两个向量本身有多大。夹角越小,$\cos\theta$ 越大;向量 norm 越大,同样的方向对齐会产生更大的分数。attention score 使用 dot product,所以它保留了 norm 的影响;cosine similarity 会先把向量归一化,只保留方向。真实模型里 norm 的含义会受到 normalization、训练动态和具体层位置影响,不能简单解释成某个固定语义;但从机制上看,dot product 确实给了模型这两个自由度。
query/key 投影让匹配规则可学习
再往深一层看,attention score 可以写回 hidden states 之间的比较。普通的 hidden state dot product 是这个比较规则里最受限的一种形式:
\[S = HH^\top,\quad S_{ij}=h_i^\top h_j\]这也能产生一张 score table,后面仍然可以接 softmax 和 $AV$。但这套打分规则是固定的:它直接用当前表示空间里的 hidden states 做相似度比较。换成双线性形式看,普通 dot product 对应
\[h_i^\top I h_j\]也就是比较规则固定为单位矩阵 $I$。此时模型只能使用原空间里已有的表示对齐关系来决定读取倾向。
query/key 投影把这件事变成可学习的比较。attention 先把两个 hidden states 改写成不同角色:
\[q_i = W_Q^\top h_i,\quad k_j = W_K^\top h_j\]再用 dot product 打分:
\[q_i^\top k_j = (W_Q^\top h_i)^\top(W_K^\top h_j) = h_i^\top W_Q W_K^\top h_j\]令
\[B = W_Q W_K^\top\]就得到
\[q_i^\top k_j = h_i^\top B h_j\]这个式子说明,attention score 可以理解成两个 hidden states 之间的一个学出来的双线性打分。普通相似度是 $B=I$ 的特殊情况;引入 $W_Q$ 和 $W_K$ 后,模型学到的是更一般的 $B$。于是 score 从固定的表示相似度,变成模型学出来的读取兼容性:当前位置的读取请求,和候选位置的匹配信号,到底对不对得上。
这也解释了为什么 query 和 key 使用两套投影。对同一个位置 $i$,$q_i$ 和 $k_i$ 会同时存在,但它们服务于不同用途:$q_i$ 用在位置 $i$ 查询别人时,$k_i$ 用在别的位置查询 $i$ 时。一个是“我现在想找什么”,一个是“我能被什么请求找到”。两套投影把这两种角色分开。
如果 query 和 key 共用同一个投影 $W$,打分矩阵对应的比较规则会变成
\[B = WW^\top\]这种 $B$ 必然是对称且半正定的。这个约束会带来两个后果。第一,self score 会有天然优势:
\[h_i^\top WW^\top h_i = \lVert W^\top h_i\rVert^2 \ge 0\]也就是说,当前位置和自己的匹配分数会变成一个 norm square。它不一定总是最大,但它有一个结构性偏置:自己和自己相乘天然容易得到一个强候选分数。使用两套投影后,self score 变成
\[q_i^\top k_i = h_i^\top W_QW_K^\top h_i\]它衡量的是当前位置的读取请求和当前位置暴露的匹配信号是否兼容,不再只是同一个投影向量的长度。
还是看 Alice dropped the glass because she ...。模型处理 she 这个位置时,当前位置自己的 state 很重要,但它可能更需要从前面的 Alice 位置读取指代对象。she 的 query 可以表达“我现在需要一个可作为指代对象的人物实体”,Alice 的 key 可以表达“这里是一个人物实体”。如果这两个信号对得上,$q_{\text{she}}^\top k_{\text{Alice}}$ 可以高于 $q_{\text{she}}^\top k_{\text{she}}$。这说明 attention score 在做路由:当前位置可以把权重放到真正提供所需信息的位置。
第二,$WW^\top$ 是对称的,只能表达对称的比较规则。使用两套投影后,
\[B = W_Q W_K^\top\]一般不要求对称,所以
\[h_i^\top B h_j \ne h_j^\top B h_i\]这意味着当前位置对目标位置的注意力得分,和目标位置对当前位置的注意力得分可以不同。语言里的依赖关系往往带有角色差异:一个位置发出读取请求,另一个位置提供可匹配的信号。Q/K 分开后,模型不再被锁在对称的相似度比较里,可以学习更适合这种依赖关系的匹配规则。
这个形式还带来一个低秩线索。因为 $W_Q \in \mathbb{R}^{d \times d_k}$,$W_K \in \mathbb{R}^{d \times d_k}$,所以
\[\mathrm{rank}(B) = \mathrm{rank}(W_Q W_K^\top) \le d_k\]也就是说,一个 head 的 compatibility matrix 通过 $d_k$ 维中间空间来组织比较。相比一个任意的 $d \times d$ 满矩阵,这种分解把打分规则组织在更低维的结构里。这条线后面会自然连接到低秩结构、SVD、压缩和 LoRA:很多看起来很大的线性作用,可能通过更低维的结构表达出来。
从实现角度看,这个分解也很自然。先把所有位置投影成 $Q$ 和 $K$,再用 $QK^\top$ 一次性得到整张 score table。keys 可以缓存在 KV cache 里,新 token 只需要算自己的 query,再和缓存里的 keys 做矩阵乘法。更复杂的打分函数也可以存在,比如 additive attention 会用一个小网络组合 query 和 key;dot-product-of-projections 把比较规则放进线性投影和矩阵乘法里,在表达力、并行性和缓存友好性之间形成了很实用的权衡。
缩放避免 softmax 过早变尖
还有一个尺度问题。$q^\top k$ 是 $d_k$ 个乘积的和。当维度变大时,如果每个坐标的尺度差不多,求和后的数值也容易变大。一个常见的简化分析是:假设 $q_r$ 和 $k_r$ 均值为 0、方差约为 1,并且各坐标近似独立,那么
\[\mathrm{Var}(q^\top k) = \mathrm{Var}\left(\sum_{r=1}^{d_k} q_r k_r\right) \approx d_k\]所以 dot product 的标准差大约会随 $\sqrt{d_k}$ 增长。attention 里使用
\[\frac{q^\top k}{\sqrt{d_k}}\]就是为了让 score 的尺度更稳定。进入 softmax 的 logits 过大时,softmax 会过早变得很尖,少数位置权重接近 1,其他位置接近 0,梯度也会更容易变得不稳定。这里的推导是尺度上的近似分析,用来解释缩放项为什么自然出现;真实模型激活分布还会受到 normalization、权重尺度和训练动态影响(图 2)。
这样回看,dot product 能作为 attention score,原因可以说得更具体:它把 query 和 key 这两个学到的比较表示压成一个标量;代回 hidden states 后,它又是一个学出来的双线性比较规则。这个标量进入同一行 softmax,变成读取分布。下一步,attention 用这个分布去混合 values,把被选中的上下文信息写回当前位置。
从匹配分数到信息写回
读取权重混合 values
到这里,attention 还只完成了一半。$QK^\top$ 给出了一张分数表,每个元素 $s_{ij}$ 表示位置 $i$ 对位置 $j$ 的读取倾向。可是分数本身还不能直接写回 hidden state。模型还需要把同一行里的分数变成一组权重,然后用这组权重去读取真正要带回来的内容。
加入缩放和 mask 后,可以把 attention logits 写成
\[\ell_{ij} = \frac{q_i^\top k_j}{\sqrt{d_k}} + m_{ij}\]其中 $m_{ij}$ 负责可见性。对 decoder-only 模型来说,如果位置 $j$ 在位置 $i$ 的未来,$m_{ij}$ 会让这个位置在 softmax 后得到 0 权重。接下来 softmax 是按行做的:
\[a_{ij} = \frac{\exp(\ell_{ij})}{\sum_{t \in \mathcal{P}_i}\exp(\ell_{it})}\]这里的 $\mathcal{P}_i$ 表示位置 $i$ 可以看到的位置集合。这样得到的第 $i$ 行
\[a_i = \begin{pmatrix} a_{i1} & a_{i2} & \cdots & a_{in} \end{pmatrix}\]就是位置 $i$ 的读取分布。每个 $a_{ij}$ 非负,同一行的权重加起来等于 1。softmax 的意义也在这里变得具体:它让同一个 query 面对所有可见 keys 时进行相对比较。某个位置的 logit 变大,它在这一行里的权重会提高;其他位置的权重会相应被压低。attention 的选择发生在行内部。
这时 key 和 value 的分工也更清楚了。key 参与的是“能不能被当前位置找到”的匹配过程,value 参与的是“被找到以后提供什么内容”的写回过程。每个位置 $j$ 都会产生一个 $k_j$ 和一个 $v_j$。$k_j$ 用来和别的位置的 query 算分数;$v_j$ 则是位置 $j$ 准备提供给其他位置读取的向量内容。
所以 attention 的后半步是
\[C = AV\]其中
\[A \in \mathbb{R}^{n \times n}, \quad V \in \mathbb{R}^{n \times d_v}, \quad C \in \mathbb{R}^{n \times d_v}\]固定第 $i$ 行看,矩阵乘法展开成
\[c_i = \sum_{j \in \mathcal{P}_i} a_{ij} v_j\]这句话是 attention 信息流的核心:位置 $i$ 最后拿到的 context vector,是所有可见 value vectors 的加权和。权重来自 $i$ 对每个位置的读取分布,内容来自各个位置自己的 value。
看一个三维例子。假设某个位置能看到三个 value vectors:
\[v_1 = \begin{pmatrix} 1 \\\\ 0 \\\\ 2 \end{pmatrix}, \quad v_2 = \begin{pmatrix} 0 \\\\ 3 \\\\ 1 \end{pmatrix}, \quad v_3 = \begin{pmatrix} 2 \\\\ -1 \\\\ 0 \end{pmatrix}\]这一行 softmax 后得到的读取权重是
\[a_i = \begin{pmatrix} 0.1 & 0.7 & 0.2 \end{pmatrix}\]那么写回当前位置的 context vector 是
\[c_i = 0.1v_1 + 0.7v_2 + 0.2v_3 = \begin{pmatrix} 0.5 \\\\ 1.9 \\\\ 0.9 \end{pmatrix}\]这个例子只用来看清线性组合,不给单个坐标安上固定语义。重要的是整体结构:第 2 个位置权重最大,所以 $v_2$ 对 $c_i$ 的贡献最大;第 1 个和第 3 个位置仍然保留一部分影响。attention 没有在离散地选择一个 token,它是在向量空间里混合多个位置提供的信息(图 3)。
Alice dropped the glass because she …。选定的 query(高亮行)对每个可见位置打出注意力权重 $a_j=\mathrm{softmax}(q^\top k_j)$(横条),各位置的 value 据此持续汇入右侧的新状态 $c=\sum_j a_j v_j$——权重越大、流得越粗。未来位置被 causal mask 灰掉。query 自己轮换,hover 或点任意行可固定。(权重为示意值。)output projection 在 residual update 前融合 heads
单个 head 得到 $C$ 之后,还需要放回整个模型的表示空间里。多头 attention 的形状可以这样整理。
假设模型宽度是 $d$,序列长度是 $n$,输入到这一层的状态矩阵是
\[H \in \mathbb{R}^{n \times d}\]在这种写法里,矩阵的行对应序列位置。第 $i$ 行就是序列第 $i$ 个 token 在这一层的向量状态。后面写 $c_i$ 或 $o_i$,指的都是第 $i$ 行对应的 token 向量。
如果有 $h$ 个 heads,每个 head 内部会有 query、key、value 三种宽度。更一般地,可以记成 $d_q,d_k,d_v$。由于 attention score 要计算 $q_i^\top k_j$,query 和 key 必须有相同长度,所以 scaled dot-product attention 里通常取 $d_q=d_k$。很多标准 Transformer 还会进一步取
\[d_q = d_k = d_v = d_{\text{head}} = \frac{d}{h}\]后面使用这个常见设置来写。第 $r$ 个 head 有自己的三套投影矩阵:
\[W_Q^{(r)} \in \mathbb{R}^{d \times d_{\text{head}}}, \quad W_K^{(r)} \in \mathbb{R}^{d \times d_{\text{head}}}, \quad W_V^{(r)} \in \mathbb{R}^{d \times d_{\text{head}}}\]每个 head 都从完整的 $H$ 出发:
\[Q^{(r)} = H W_Q^{(r)}, \quad K^{(r)} = H W_K^{(r)}, \quad V^{(r)} = H W_V^{(r)}\]对应的形状是
\[Q^{(r)},K^{(r)},V^{(r)} \in \mathbb{R}^{n \times d_{\text{head}}}\]所以 head 的来源是 projection。第 1 个 head 用一套矩阵从 $H$ 里读出自己的 $Q,K,V$;第 2 个 head 用另一套矩阵从同一个 $H$ 里读出另一组 $Q,K,V$。进入 attention 前,$H$ 仍然保持 $n \times d$ 的整体状态矩阵;按 head 分组发生在投影输出上。
实际实现里,通常会用一个大矩阵一次性算出所有 heads 的 query:
\[W_Q \in \mathbb{R}^{d \times (h d_{\text{head}})}\] \[Q_{\text{all}} = H W_Q \in \mathbb{R}^{n \times (h d_{\text{head}})}\]这个大矩阵可以看成把多个 head 的投影矩阵横向放在一起:
\[W_Q = \begin{pmatrix} W_Q^{(1)} & W_Q^{(2)} & \cdots & W_Q^{(h)} \end{pmatrix}\]算出 $Q_{\text{all}}$ 后,再把最后一维从 $h d_{\text{head}}$ reshape 成 $h \times d_{\text{head}}$,于是得到每个 head 自己的 $Q^{(r)}$。$K$ 和 $V$ 也是同样的过程。这个 reshape 切开的是 projection output 的列。
经过各自的 score、softmax、value mixing 后,第 $r$ 个 head 输出
\[C^{(r)} \in \mathbb{R}^{n \times d_{\text{head}}}\]直觉上,一个 head 是一套独立的读取方式。它有自己的 query/key 比较规则,也有自己的 value 表示空间。不同 heads 可以学习不同的读取习惯:有的可能更关注局部结构,有的可能更关注指代关系,有的可能更关注某些格式或语法信号。机制上,它们提供的是多套并行的读取通道;至于某个 head 是否对应某种人类可命名的功能,需要具体分析模型行为。
接下来,模型把这些 head outputs 沿着特征维度拼起来。拼接发生在同一个 token 的行内:
\[c_{i,\text{cat}} = \begin{pmatrix} c_i^{(1)} & c_i^{(2)} & \cdots & c_i^{(h)} \end{pmatrix}\]把所有 token 放在一起,就是
\[C_{\text{cat}} = \mathrm{Concat}(C^{(1)}, C^{(2)}, \ldots, C^{(h)}) \in \mathbb{R}^{n \times (h d_{\text{head}})}\]concat 合理的原因在于,这些向量都是同一个位置从不同 heads 读回来的结果。它们的行索引一致,都是第 $i$ 个 token;区别只在特征维度上。concat 的作用是把多路读取结果先收集起来,形成一个更宽的向量。
因为 $d_{\text{head}}=d/h$,拼接后的宽度刚好回到模型宽度 $d$。比如 $d=12$、$h=3$ 时,每个 head 输出 $n \times 4$,三个 heads 拼起来就是
\[n \times 4 \;\;+\;\; n \times 4 \;\;+\;\; n \times 4 \quad\Longrightarrow\quad n \times 12\]拼接这一步只负责并排收集,还没有让不同 heads 的信息彼此组合。最后的 output projection 负责做这件事:
\[O = C_{\text{cat}} W_O\]其中
\[W_O \in \mathbb{R}^{(h d_{\text{head}}) \times d}, \quad O \in \mathbb{R}^{n \times d}\]在 $d_{\text{head}}=d/h$ 的设置下,$W_O$ 的形状就是 $d \times d$。这个矩阵为什么能融合 heads,可以直接从乘法展开里看到。
concat 后的矩阵可以按 head 分块写成
\[C_{\text{cat}} = \begin{pmatrix} C^{(1)} & C^{(2)} & \cdots & C^{(h)} \end{pmatrix}\]其中每个 $C^{(r)}$ 都是 $n \times d_{\text{head}}$。再把 $W_O$ 按输入来源切成 $h$ 个纵向 block:
\[W_O = \begin{pmatrix} W_O^{(1)} \\\\ W_O^{(2)} \\\\ \vdots \\\\ W_O^{(h)} \end{pmatrix}, \quad W_O^{(r)} \in \mathbb{R}^{d_{\text{head}} \times d}\]于是
\[O = C_{\text{cat}} W_O = \sum_{r=1}^{h} C^{(r)} W_O^{(r)}\]这就是跨 head 融合的具体形式。每个 head 先通过自己的 $C^{(r)}W_O^{(r)}$ 贡献一个 $n \times d$ 的结果,最后这些结果在同一个 residual width 上相加。固定序列第 $i$ 个 token 对应的那一行看:
\[o_i = \sum_{r=1}^{h} c_i^{(r)} W_O^{(r)}\]再固定输出向量的第 $m$ 个维度:
\[(o_i)_m = \sum_{r=1}^{h} \sum_{\ell=1}^{d_{\text{head}}} (c_i^{(r)})_{\ell}(W_O^{(r)})_{\ell m}\]这个展开式说明,一个输出维度可以同时接收所有 heads 的分量。第 1 个 head 读回来的局部模式、第 2 个 head 读回来的指代线索、第 3 个 head 读回来的长距离依赖,都可以通过同一个输出坐标汇合到一起。如果 $W_O$ 是单位矩阵,concat 后的各段信息会原样保留在各自位置;如果 $W_O$ 是普通可学习 dense matrix,输出的每个坐标都可以从所有 head 分量里取加权和。
用一个更小的例子看会更直接。假设只有两个 heads,每个 head 输出 2 维。对序列第 $i$ 个 token 来说,
\[c_i^{(1)} = \begin{pmatrix} x_1 & x_2 \end{pmatrix}, \quad c_i^{(2)} = \begin{pmatrix} y_1 & y_2 \end{pmatrix}\]concat 后得到
\[c_{i,\text{cat}} = \begin{pmatrix} x_1 & x_2 & y_1 & y_2 \end{pmatrix}\]这里 $(x_1,x_2)$ 来自 head 1,$(y_1,y_2)$ 来自 head 2。现在乘上一个 $4 \times 4$ 的 output projection:
\[W_O = \begin{pmatrix} w_{11} & w_{12} & w_{13} & w_{14} \\\\ w_{21} & w_{22} & w_{23} & w_{24} \\\\ w_{31} & w_{32} & w_{33} & w_{34} \\\\ w_{41} & w_{42} & w_{43} & w_{44} \end{pmatrix}\]输出仍然是这个 token 的一个 4 维向量:
\[o_i = c_{i,\text{cat}}W_O = \begin{pmatrix} z_1 & z_2 & z_3 & z_4 \end{pmatrix}\]其中每个输出分量都是四个输入分量的加权和:
\[z_1 = x_1w_{11}+x_2w_{21}+y_1w_{31}+y_2w_{41}\] \[z_2 = x_1w_{12}+x_2w_{22}+y_1w_{32}+y_2w_{42}\] \[z_3 = x_1w_{13}+x_2w_{23}+y_1w_{33}+y_2w_{43}\] \[z_4 = x_1w_{14}+x_2w_{24}+y_1w_{34}+y_2w_{44}\]这就是“融合”的具体含义。对 token $i$ 来说,输出向量 $o_i$ 的第一个分量 $z_1$ 可以同时用到 $x_1,x_2,y_1,y_2$,也就是同时用到两个 heads 的信息;$z_2,z_3,z_4$ 也是同样的结构。$W_O$ 学到的是这些分量怎样重新组合。
这里的融合仍然发生在同一个 token 的向量内部。对序列第 $i$ 个 token 对应的那一行来说,$o_i$ 只由 $c_i^{(1)},\ldots,c_i^{(h)}$ 计算出来;其他 token 的信息已经在各个 head 的 $AV$ 步骤里进入了这些 $c_i^{(r)}$。所以 $AV$ 负责跨 token 读取,$W_O$ 负责把同一 token 内多路读取结果融合成下一步使用的表示。
最后,attention 输出通过 residual connection 回到当前位置的状态里。前面的 attention 公式省略了 Transformer block 外层的 LayerNorm。在很多现代 decoder-only Transformer 里,更接近实现的形式是 pre-norm:
\[\tilde H = \mathrm{LN}(H)\] \[O = \mathrm{MHA}(\tilde H)\] \[H' = H + O\]attention 子层之后,MLP 子层通常也会用类似结构:
\[H_{\text{next}} = H' + \mathrm{MLP}(\mathrm{LN}(H'))\]LayerNorm 很值得单独讲:它作用在每个 token 自己的 feature 维度上,会改变进入 attention 和 MLP 的向量尺度,也会影响 residual stream 的稳定性。这里把它放在 block 结构里标出来,具体机制留到后面的“范数、归一化和稳定性”一节。当前这一节只看 attention 写回这一步,可以粗略写成
\[H_{\text{next}} = H + O\]这一步把“读取到的上下文信息”真正并入每个 token 的 state。前面 $QK^\top$ 的作用是决定读哪里,softmax 的作用是把读取倾向变成分布,$AV$ 的作用是把被读取的位置内容混成一个向量,$W_O$ 和 residual connection 则把这个向量写回模型继续处理的状态空间。
线性层怎样改写一个 token 的表示
上一节把 attention 的信息流讲完了:$AV$ 把其他位置的 value vectors 混到当前位置,multi-head attention 把多路读取结果并排收集,$W_O$ 再把同一个位置内部的这些结果重新组合。到这一步,位置 $i$ 的 state 已经带上了上下文信息。接下来,Transformer 还会反复使用普通线性层,在每个位置内部继续加工这份 state。
最普通的线性层可以写成
\[Y = HW\]设
\[H \in \mathbb{R}^{n \times d}, \quad W \in \mathbb{R}^{d \times m}, \quad Y \in \mathbb{R}^{n \times m}\]这里的 $H$ 是一整段序列的状态矩阵,第 $i$ 行对应序列第 $i$ 个 token 的 state。右乘 $W$ 以后,输出仍然按同样的行顺序排列。第 $i$ 行的输出是
\[y_i^\top = h_i^\top W\]这句话先给出一个很重要的边界:第 $i$ 行的输出只由第 $i$ 行的输入算出来。$HW$ 会对所有位置使用同一个 $W$,但每一行都是单独经过这套变换。第 3 行的输入会影响第 3 行的输出;第 7 行的输入会影响第 7 行的输出。一次普通线性层不会直接把第 7 行混进第 3 行。
把第 $i$ 行的某个输出坐标展开看:
\[Y_{ij} = \sum_{r=1}^{d} H_{ir}W_{rj}\]这个求和发生在 feature 维度上。固定第 $i$ 行时,模型读取的是同一个 token 的
\[H_{i1},H_{i2},\ldots,H_{id}\]然后加权得到新的 feature。相比之下,attention 的 value mixing 是
\[C_{i\ell} = \sum_{j=1}^{n} A_{ij}V_{j\ell}\]这里求和发生在位置 $j$ 上,所以第 $i$ 行会读取 $V$ 的多行。区别就在这里:普通线性层的共同维度是 feature,attention 的 $AV$ 的共同维度是 token position。前者加工当前 token 自己的向量,后者把其他 token 的 value 信息混进当前 token。
有了这个边界,再看 $W$ 本身在做什么。把 $W$ 按列写开:
\[W = \begin{pmatrix} w_1 & w_2 & \cdots & w_m \end{pmatrix}, \quad w_j \in \mathbb{R}^{d}\]于是
\[h_i^\top W = \begin{pmatrix} h_i^\top w_1 & h_i^\top w_2 & \cdots & h_i^\top w_m \end{pmatrix}\]每一列 $w_j$ 都是一条学出来的读取方向。输出的第 $j$ 个分量,就是当前位置的 hidden state 沿着 $w_j$ 这个方向的响应。一个线性层把很多这样的读取方向放在一起,于是同一个 hidden state 会被同时读成 $m$ 个线性信号。
这就是线性层的 feature mixing。它不需要改变 token 的位置,也不需要读取其他行;它只是在当前 token 的向量内部,把已有坐标重新组合成新的坐标。即使输入和输出宽度一样,$d \times d$ 矩阵也可以改变表示,因为每个输出坐标都可以由多个输入坐标共同决定。
看一个 3 维 state:
\[h_i^\top = \begin{pmatrix} 2 & -1 & 3 \end{pmatrix}\]取一个 $3 \times 3$ 矩阵
\[W = \begin{pmatrix} 1 & 0 & 2 \\\\ 2 & -1 & 0 \\\\ 1 & 1 & 1 \end{pmatrix}\]那么
\[y_i^\top = h_i^\top W = \begin{pmatrix} 3 & 4 & 7 \end{pmatrix}\]展开看:
\[y_{i,1}=2\cdot 1+(-1)\cdot 2+3\cdot 1=3\] \[y_{i,2}=2\cdot 0+(-1)\cdot(-1)+3\cdot 1=4\] \[y_{i,3}=2\cdot 2+(-1)\cdot 0+3\cdot 1=7\]这里仍然避免给单个坐标安上固定语义。重要的是结构:新的第 1 个分量同时用了旧 state 的三个分量;新的第 2 个分量也同时用了多个分量;新的第 3 个分量也是如此。dense matrix 的每一列都给出一套加权方式,输出向量里的每个坐标都是旧坐标的一次加权和。
这也解释了几类矩阵的差别。如果 $W$ 是 identity matrix,输出会等于输入。若 $W$ 是 diagonal matrix,每个坐标只会被单独缩放。dense matrix 允许不同坐标互相混合,所以它可以把原来的表示换到一组新的坐标方向里。这个“换坐标”不一定是纯旋转;普通神经网络权重还可以拉伸某些方向、压缩某些方向,甚至改变维度。
回到 attention,这个视角也能更准确地放置 $W_Q,W_K,W_V$ 的角色。它们把同一个 hidden state 读成三种不同信号:query、key、value。到
\[Q=HW_Q,\quad K=HW_K,\quad V=HW_V\]这一步为止,每个位置仍然在加工自己的 state。位置 $i$ 生成自己的 $q_i,k_i,v_i$,位置 $j$ 生成自己的 $q_j,k_j,v_j$。真正把位置连起来的是后面的 score table 和 value mixing:$QK^\top$ 让每个 query 和所有 keys 比较,softmax 得到位置之间的读取权重,$AV$ 再用这些权重混合 $V$ 的多行。也就是说,attention 里确实包含线性层;这些线性层负责准备每个 token 的读取材料,cross-token 信息流来自 attention weights 对 value rows 的加权混合。
$W_O$ 又回到线性层的角色。$AV$ 和 multi-head 已经把上下文信息带回当前位置,$W_O$ 接着在同一个 token 的 concat 向量内部做 feature mixing,把多路读取结果整理回 residual stream 使用的宽度。
MLP 里的 $W_1$ 和 $W_2$ 也是类似结构:
\[\mathrm{MLP}(H) = \sigma(HW_1)W_2\]通常 $W_1$ 会把宽度从 $d$ 扩到更大的中间维度 $d_{\mathrm{ff}}$:
\[W_1 \in \mathbb{R}^{d \times d_{\mathrm{ff}}}\]然后非线性函数 $\sigma$ 改变这些中间信号的响应,$W_2$ 再把它们投回模型宽度:
\[W_2 \in \mathbb{R}^{d_{\mathrm{ff}} \times d}\]如果去掉中间的 $\sigma$,两层线性层会合并成一层:
\[HW_1W_2 = H(W_1W_2)\]所以 MLP 的线性部分负责读取、扩展和重组特征;非线性部分让这些特征响应可以被门控和重新塑形。对每个 token 来说,MLP 不从其他 token 读取新信息,它处理的是 attention 和 residual stream 已经写入当前位置 state 里的内容。
输出层也可以用同一个视角看。最后一层某个位置的 state 是
\[h_t^{(L)} \in \mathbb{R}^{d}\]输出矩阵是
\[W_{\mathrm{out}} \in \mathbb{R}^{d \times \lvert \mathcal{V} \rvert}\]logits 写成
\[z_t = (h_t^{(L)})^\top W_{\mathrm{out}}\]这里 $W_{\mathrm{out}}$ 的每一列都对应 vocabulary 里一个 token 的输出方向。某个 logit 高,表示最终 state 在那个输出方向上的线性响应强。这样,预测下一个 token 也可以看成一组并行的线性读取。
所以,线性层改写 token 表示的方式可以概括成:它不改变 token 的行索引,只改变每一行内部的坐标表达。每个输出坐标都是旧 state 的一个加权和;很多输出坐标合在一起,就形成新的表示空间。attention 用 $AV$ 做跨 token 的加权混合,线性层用 $HW$ 做同一 token 内部的 feature mixing。两者都在做矩阵乘法,但混合的对象不同。
接下来还可以从另一个角度看矩阵乘法。刚才我们按输出列理解 $W$,把每一列看成一个读取方向。换一种展开方式,矩阵乘法也可以写成一组 outer products 的和。这个视角会把矩阵乘法、rank、SVD、压缩和 LoRA 连到同一条线上。
outer product 和矩阵乘法里的低秩结构
上一节按输出坐标看线性层:$W$ 的每一列是一条读取方向,$h_i^\top W$ 会把一个 token 的 state 读成多个新坐标。这个视角很适合解释某个输出 feature 怎么来。现在换一个角度看同一个矩阵乘法:整个输出矩阵是怎样由更简单的矩阵成分叠出来的?
矩阵乘法可以有两个观察层级。如果只看结果矩阵里的一个 entry,看到的是 row-column dot product。设
\[A \in \mathbb{R}^{m \times p}, \quad B \in \mathbb{R}^{p \times n}, \quad C=AB\]那么
\[C_{ij} = \sum_{k=1}^{p} A_{ik}B_{kj}\]这就是 $A$ 的第 $i$ 行和 $B$ 的第 $j$ 列做 dot product。dot product 视角回答的是:某一个输出数字是怎么算出来的。
如果看整个输出矩阵,可以把同一个乘法按共同维度 $k$ 重新分组。把 $A$ 按列看,把 $B$ 按行看:
\[A = \begin{pmatrix} a_1 & a_2 & \cdots & a_p \end{pmatrix}\] \[B = \begin{pmatrix} b_1^\top \\\\ b_2^\top \\\\ \vdots \\\\ b_p^\top \end{pmatrix}\]其中 $a_k$ 是 $A$ 的第 $k$ 列,$b_k^\top$ 是 $B$ 的第 $k$ 行。矩阵乘法可以写成
\[AB = \sum_{k=1}^{p} a_k b_k^\top\]每一项 $a_k b_k^\top$ 都是一个 outer product,结果是一整张矩阵。两种视角描述的是同一个结果,因为
\[\left(\sum_{k=1}^{p} a_k b_k^\top\right)_{ij} = \sum_{k=1}^{p} (a_k)_i(b_k)_j = \sum_{k=1}^{p} A_{ik}B_{kj} = (AB)_{ij}\]所以,矩阵乘法在 entry 级别看是 dot product,在 whole-matrix 级别看是 outer products 的和。
outer product 生成的矩阵结构很简单。比如
\[u = \begin{pmatrix} 1 \\\\ 2 \\\\ 3 \end{pmatrix}, \quad v = \begin{pmatrix} 4 \\\\ 5 \end{pmatrix}\]那么
\[uv^\top = \begin{pmatrix} 1 \\\\ 2 \\\\ 3 \end{pmatrix} \begin{pmatrix} 4 & 5 \end{pmatrix} = \begin{pmatrix} 4 & 5 \\\\ 8 & 10 \\\\ 12 & 15 \end{pmatrix}\]这个矩阵看起来有 6 个数,但结构很简单。第 2 行是第 1 行的 2 倍,第 3 行是第 1 行的 3 倍。所有行都沿着同一个方向变化。它只表达了一种独立的行方向。非零 outer product 生成的是 rank-1 matrix。
再看一个完整矩阵乘法。令
\[A = \begin{pmatrix} 1 & 2 \\\\ 3 & 4 \end{pmatrix}, \quad B = \begin{pmatrix} 5 & 6 \\\\ 7 & 8 \end{pmatrix}\]按普通乘法算:
\[AB = \begin{pmatrix} 19 & 22 \\\\ 43 & 50 \end{pmatrix}\]按 outer product 展开,则是
\[AB = \begin{pmatrix} 1 \\\\ 3 \end{pmatrix} \begin{pmatrix} 5 & 6 \end{pmatrix} + \begin{pmatrix} 2 \\\\ 4 \end{pmatrix} \begin{pmatrix} 7 & 8 \end{pmatrix}\]也就是
\[AB = \begin{pmatrix} 5 & 6 \\\\ 15 & 18 \end{pmatrix} + \begin{pmatrix} 14 & 16 \\\\ 28 & 32 \end{pmatrix} = \begin{pmatrix} 19 & 22 \\\\ 43 & 50 \end{pmatrix}\]这个展开方式的意义在于,它把一个矩阵作用拆成了若干个简单方向的叠加。一个 outer product 只能提供一个 rank-1 方向;多个 outer products 加起来,矩阵才逐渐获得更复杂的作用。
回到 LLM,这个视角会变得很有用。一个线性层权重看起来是一个很大的矩阵,但它的作用可以问得更细:表达这个作用需要很多独立方向,还是主要由少数方向支撑?如果少量 outer products 就能近似一个矩阵,说明这个矩阵的有效结构可能比较低秩。
这就是下一节要讲 rank 的原因。上一节关心矩阵怎样读出新的 features;这一节看到矩阵还可以拆成 rank-1 contributions 的叠加。接下来要问的是:这些 contribution 里,到底有多少个方向是真正独立的?
rank 衡量矩阵用了多少独立方向
上一节把矩阵乘法写成 outer products 的和。每个非零 outer product 只能贡献一个 rank-1 方向;多个 outer products 叠加以后,矩阵的作用才可能变复杂。rank 关心的就是这件事:这些贡献加在一起以后,矩阵真正用了多少个互相独立的方向?
先用标准列向量写法看一个矩阵作为线性映射。设
\[A = \begin{pmatrix} a_1 & a_2 & \cdots & a_n \end{pmatrix} \in \mathbb{R}^{m \times n}\]其中 $a_1,\ldots,a_n$ 是 $A$ 的列向量。对于输入
\[x = \begin{pmatrix} x_1 \\\\ x_2 \\\\ \vdots \\\\ x_n \end{pmatrix}\]矩阵乘法可以写成
\[Ax = x_1 a_1 + x_2 a_2 + \cdots + x_n a_n\]所以 $Ax$ 的所有可能结果,都落在 $A$ 的列向量张成的空间里。输入 $x$ 改变的只是这些列向量前面的系数。如果 $A$ 的很多列可以由少数几列线性组合出来,那么它看起来有很多列,实际能到达的输出方向却更少。
rank 就是这个可达输出空间的维度:
\[\operatorname{rank}(A) = \dim\left(\mathrm{span}(a_1,\ldots,a_n)\right)\]如果 $A \in \mathbb{R}^{3 \times 3}$ 的 rank 是 3,它的输出可以覆盖三维空间里的三个独立方向。如果 rank 是 2,所有输出都会被限制在某个平面里。如果 rank 是 1,所有输出都沿着同一条直线缩放。rank 越低,矩阵的作用越受限。
看一个具体矩阵:
\[M = \begin{pmatrix} 1 & 2 & 3 \\\\ 4 & 5 & 6 \\\\ 5 & 7 & 9 \end{pmatrix}\]从行看,第 3 行等于第 1 行加第 2 行:
\[\begin{pmatrix} 5 & 7 & 9 \end{pmatrix} = \begin{pmatrix} 1 & 2 & 3 \end{pmatrix} + \begin{pmatrix} 4 & 5 & 6 \end{pmatrix}\]所以这三行里只有两条独立行方向。对同一个矩阵,也可以从列空间看。把列记作 $c_1,c_2,c_3$,有
\[c_3 = -c_1 + 2c_2\]也就是说,第 3 列没有提供新的独立输出方向。对任意输入
\[x = \begin{pmatrix} x_1 \\\\ x_2 \\\\ x_3 \end{pmatrix}\]输出为
\[Mx = x_1c_1 + x_2c_2 + x_3c_3 = (x_1 - x_3)c_1 + (x_2 + 2x_3)c_2\]无论 $x_3$ 怎么变,最后的输出仍然只是 $c_1$ 和 $c_2$ 的线性组合。$x_3$ 会改变组合系数,但不会创造第三个独立输出方向。因此这个矩阵的 rank 是 2。
这个例子也说明了 rank 和信息压缩的关系。一个 rank-2 的 $3 \times 3$ 矩阵接受三维输入,但输出只使用二维空间。某些输入方向上的变化会被合并到同一个输出变化里;如果两个输入的差落在这个矩阵消掉的方向上,它们经过矩阵后会得到同样的结果。
回到 LLM,rank 给了我们一个比矩阵形状更细的观察方式。矩阵的 shape 告诉我们它有多少行、多少列、多少参数;rank 告诉我们这些参数最后支撑了多少独立作用方向。一个线性层可以很大,但它的有效作用可能集中在较少方向上。一个 attention head 的比较矩阵也可以写成低秩形式,因为 $W_QW_K^\top$ 的 rank 受到 head dimension 限制。
这里要保留一个重要区分:rank 是精确的代数计数。只要某个方向有非零贡献,即使贡献很弱,精确 rank 也会把它算进去。真实模型权重里更常见的现象是方向强弱不均:少数方向很强,很多方向很弱。要描述这种强弱排序,就需要下一节的 SVD。
SVD 把矩阵拆成按强弱排序的方向
rank 给的是独立方向的数量。这个数量很有用,但它还不够细。一个矩阵可以有多个独立方向,其中有些方向很强,有些方向很弱。SVD 做的事情,是把矩阵拆成若干条简单通道。
一条通道读取、缩放、再写回
先看一条通道。取一个输入方向 $v$,一个输出方向 $u$,再取一个非负强度 $\sigma$。矩阵
\[M = \sigma u v^\top\]对输入 $x$ 的作用是:
\[Mx = \sigma u(v^\top x)\]这个式子从右往左读。$v^\top x$ 先读取输入在方向 $v$ 上有多少分量,结果是一个标量。乘上 $\sigma$ 之后,这个标量被放大或缩小。最后乘上 $u$,把这个数写成输出方向上的一个向量。
所以 $uv^\top$ 的意义很具体:$v^\top$ 负责读,$u$ 负责写。外积把“沿 $v$ 读取”和“沿 $u$ 写出”合成了一条矩阵通道。SVD 把一个复杂矩阵拆成很多条这样的通道:
\[A = \sigma_1 u_1 v_1^\top + \sigma_2 u_2 v_2^\top + \cdots + \sigma_r u_r v_r^\top\]这就是
\[A = U\Sigma V^\top\]的展开形式。$V$ 收集所有输入方向,$\Sigma$ 收集所有强度,$U$ 收集所有输出方向。
SVD 通过通道求和重建矩阵
用一个小矩阵看完整过程。令
\[A = \begin{pmatrix} 3 & 3 \\\\ 1 & -1 \end{pmatrix}\]这个矩阵的行为很容易直接读出来。对输入
\[x = \begin{pmatrix} x_1 \\\\ x_2 \end{pmatrix}\]有
\[Ax = \begin{pmatrix} 3x_1 + 3x_2 \\\\ x_1 - x_2 \end{pmatrix}\]第一行读的是两个坐标的和,并放大 3 倍。第二行读的是两个坐标的差。SVD 会把这个矩阵拆成两条通道:
\[v_1 = \frac{1}{\sqrt{2}} \begin{pmatrix} 1 \\\\ 1 \end{pmatrix}, \quad u_1 = \begin{pmatrix} 1 \\\\ 0 \end{pmatrix}, \quad \sigma_1 = 3\sqrt{2}\] \[v_2 = \frac{1}{\sqrt{2}} \begin{pmatrix} 1 \\\\ -1 \end{pmatrix}, \quad u_2 = \begin{pmatrix} 0 \\\\ 1 \end{pmatrix}, \quad \sigma_2 = \sqrt{2}\]第一条通道沿同向方向 $v_1$ 读取输入,然后写到第一个输出坐标,强度是 $3\sqrt{2}$。第二条通道沿反向方向 $v_2$ 读取输入,然后写到第二个输出坐标,强度是 $\sqrt{2}$。
现在把每条通道写成矩阵。第一条:
\[\sigma_1 u_1 v_1^\top = 3\sqrt{2} \begin{pmatrix} 1 \\\\ 0 \end{pmatrix} \frac{1}{\sqrt{2}} \begin{pmatrix} 1 & 1 \end{pmatrix} = \begin{pmatrix} 3 & 3 \\\\ 0 & 0 \end{pmatrix}\]第二条:
\[\sigma_2 u_2 v_2^\top = \sqrt{2} \begin{pmatrix} 0 \\\\ 1 \end{pmatrix} \frac{1}{\sqrt{2}} \begin{pmatrix} 1 & -1 \end{pmatrix} = \begin{pmatrix} 0 & 0 \\\\ 1 & -1 \end{pmatrix}\]两条通道加起来:
\[\begin{pmatrix} 3 & 3 \\\\ 0 & 0 \end{pmatrix} + \begin{pmatrix} 0 & 0 \\\\ 1 & -1 \end{pmatrix} = \begin{pmatrix} 3 & 3 \\\\ 1 & -1 \end{pmatrix}\]这就合回了原矩阵。这里没有神秘步骤:每个 $\sigma_i u_i v_i^\top$ 都是一条可以单独执行的读写通道,矩阵 $A$ 是这些通道贡献的总和。
再用一个固定输入看通道怎么工作。取
\[x = \begin{pmatrix} 2 \\\\ 1 \end{pmatrix}\]原矩阵直接算:
\[Ax = \begin{pmatrix} 9 \\\\ 1 \end{pmatrix}\]第一条通道读取:
\[v_1^\top x = \frac{2+1}{\sqrt{2}} = \frac{3}{\sqrt{2}}\]乘上强度 $3\sqrt{2}$,再沿 $u_1$ 写回:
\[3\sqrt{2}u_1(v_1^\top x) = \begin{pmatrix} 9 \\\\ 0 \end{pmatrix}\]第二条通道读取:
\[v_2^\top x = \frac{2-1}{\sqrt{2}} = \frac{1}{\sqrt{2}}\]乘上强度 $\sqrt{2}$,再沿 $u_2$ 写回:
\[\sqrt{2}u_2(v_2^\top x) = \begin{pmatrix} 0 \\\\ 1 \end{pmatrix}\]两条通道加起来:
\[\begin{pmatrix} 9 \\\\ 0 \end{pmatrix} + \begin{pmatrix} 0 \\\\ 1 \end{pmatrix} = \begin{pmatrix} 9 \\\\ 1 \end{pmatrix}\]这和直接计算 $Ax$ 一样。SVD 的意义在这里就很具体:它把矩阵作用拆成两条正交通道。第一条通道读两个坐标的同向变化,写到第一个输出;第二条通道读两个坐标的差,写到第二个输出。矩阵的最终输出,是这些通道贡献加起来的结果。
放回 LLM 里,一个线性层权重 $W$ 可以用同样方式理解。它从 hidden state 里读取若干输入方向,把这些响应写到新的 feature directions 上。大的 singular values 对应强通道,小的 singular values 对应弱通道。比如 $W_Q$ 会把 hidden state 读成 query space;用 SVD 看这件事,就是看哪些 hidden directions 最强地影响 query directions。MLP 的 up/down projections 也可以这样看:哪些输入模式被强烈展开,哪些输出方向承接了这些变化。
下一节要看的直接后果是:如果前几个 singular values 已经占了主要强度,那么只保留前几条通道,也能保留矩阵的大部分线性作用。
低秩近似为什么能保留主要作用
SVD 的好处在于,它把矩阵拆成了一组按强弱排序的读写通道。前面那个矩阵有两条通道:
\[\begin{pmatrix} 3 & 3 \\\\ 1 & -1 \end{pmatrix} = \begin{pmatrix} 3 & 3 \\\\ 0 & 0 \end{pmatrix} + \begin{pmatrix} 0 & 0 \\\\ 1 & -1 \end{pmatrix}\]第一条通道读两个坐标的和,写到第一个输出,强度是 $3\sqrt{2}$。第二条通道读两个坐标的差,写到第二个输出,强度是 $\sqrt{2}$。两条通道都在原矩阵里,但强弱并不一样。
低秩近似的基本动作,就是保留前面几条强通道,丢掉后面较弱的通道。对这个例子来说,如果只保留第一条通道,就得到:
\[A_1 = \begin{pmatrix} 3 & 3 \\\\ 0 & 0 \end{pmatrix}\]这个 $A_1$ 是 rank-1 的,因为它的输出永远只落在一个方向上,也就是第一个输出坐标所在的方向。它仍然接受二维输入,但能写出的独立输出方向只剩一个。
对同一个输入
\[x = \begin{pmatrix} 2 \\\\ 1 \end{pmatrix}\]完整矩阵给出的结果是:
\[Ax = \begin{pmatrix} 9 \\\\ 1 \end{pmatrix}\]只保留第一条通道时,结果变成:
\[A_1x = \begin{pmatrix} 9 \\\\ 0 \end{pmatrix}\]丢掉第二条通道后,输出的第二个坐标消失了,所以结果从 $(9,1)$ 变成了 $(9,0)$。这就是近似带来的误差。它能保留主要作用,是因为这个输入下主要贡献来自第一条强通道;如果某个任务非常依赖第二条通道,那么这个 rank-1 近似就会明显损失信息。
一般地,SVD 可以写成:
\[A = \sum_{i=1}^{r} \sigma_i u_i v_i^\top\]其中 singular values 从大到小排列。保留前 $k$ 条通道,得到:
\[A_k = \sum_{i=1}^{k} \sigma_i u_i v_i^\top\]这个 $A_k$ 的 rank 最多是 $k$。它的输出只能由 $u_1,\ldots,u_k$ 这几个方向组合出来。低秩的含义可以这样理解:矩阵形状仍然可以很大,但真正允许它写出的独立输出方向变少了。
这件事和压缩直接相关。假设一个权重矩阵从 $d_{\mathrm{in}}$ 维输入映射到 $d_{\mathrm{out}}$ 维输出:
\[A \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{in}}}\]这里用列向量约定写 $y=Ax$。如果按前面 $HW$ 的行向量写法,矩阵形状会转置,核心分解方式不变。
完整矩阵需要存 $d_{\mathrm{out}}d_{\mathrm{in}}$ 个参数。用 $k$ 条通道近似时,可以写成:
\[A_k = U_k \Sigma_k V_k^\top\]其中
\[U_k \in \mathbb{R}^{d_{\mathrm{out}} \times k}, \quad \Sigma_k \in \mathbb{R}^{k \times k}, \quad V_k \in \mathbb{R}^{d_{\mathrm{in}} \times k}\]需要存的参数量大约变成:
\[k(d_{\mathrm{out}} + d_{\mathrm{in}} + 1)\]计算时也可以按通道理解:先用 $V_k^\top$ 把输入读成 $k$ 个系数,再用 $\Sigma_k$ 调整这些系数的强度,最后用 $U_k$ 写回输出空间。原来一次大的线性变换,被改写成“读出少数通道,再写回输出空间”。
放到 LLM 里,这个视角很自然。很多线性层都很大,例如 attention projection、MLP up/down projection、output projection。它们的矩阵形状很大,但如果主要作用集中在少数强通道上,就可以尝试用低秩矩阵近似原来的权重。这样做的含义很具体:用更少的读写通道模拟原来的线性层。
这个判断仍然需要验证。某条弱通道可能只在少数输入上重要,也可能影响某些很关键的预测。低秩近似提供的是一个有用的结构假设:大的矩阵也许可以用少数通道近似。实际压缩模型时,还要看困惑度、下游任务表现,以及是否需要额外微调。
这条线自然引到 LoRA。SVD compression 是把已有权重近似成低秩形式;LoRA 保留原始权重,只训练一个低秩更新量。下一节就看这个更新量为什么也可以理解成少数新的读写通道。
LoRA 为什么选择低秩更新
上一节讲的是权重压缩:已经有一个完整矩阵,然后尝试用更少的通道近似它。LoRA 处理的是微调时的另一个场景:预训练权重已经承载了大量能力,新的任务只需要在原来的线性层旁边学一份更新量。
对一个线性层,继续用列向量写:
\[y = Wx\]full fine-tuning 会直接修改整块权重。训练后的线性层可以写成:
\[y = (W+\Delta W)x = Wx + \Delta W x\]其中 $W$ 是原来的预训练权重,$\Delta W$ 是微调学出来的变化。如果 $W \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{in}}}$,那么 $\Delta W$ 也有同样的形状。对于 LLM 里的大矩阵,这意味着要训练和保存一整块同尺寸更新。
LoRA 的想法是把更新量限制成低秩形式:
\[\Delta W = BA\]其中
\[A \in \mathbb{R}^{r \times d_{\mathrm{in}}}, \quad B \in \mathbb{R}^{d_{\mathrm{out}} \times r}, \quad r \ll \min(d_{\mathrm{in}}, d_{\mathrm{out}})\]于是 forward pass 变成:
\[y = Wx + \frac{\alpha}{r}B(Ax)\]$W$ 保持冻结,训练时只更新 $A$ 和 $B$。系数 $\alpha/r$ 用来控制 LoRA 路径的整体强度,先把它看成一个缩放因子即可。
这个式子和前面的读写通道是同一件事。$A$ 先从输入 $x$ 里读出 $r$ 个数:
\[z = Ax \in \mathbb{R}^{r}\]然后 $B$ 把这 $r$ 个数写回输出空间:
\[\delta y = Bz \in \mathbb{R}^{d_{\mathrm{out}}}\]所以 LoRA 路径可以看成:
\[x \xrightarrow{\ A\ } z \xrightarrow{\ B\ } \delta y\]原始线性层给出 $Wx$,LoRA 路径给出一个额外更新 $\delta y$,最后两者相加(图 4)。
用一个具体形状看会更清楚。假设输入是三维,输出是四维,LoRA rank 取 $r=2$:
\[x \in \mathbb{R}^{3}, \quad A \in \mathbb{R}^{2 \times 3}, \quad B \in \mathbb{R}^{4 \times 2}\]把 $A$ 的两行写成两个读取方向:
\[A = \begin{pmatrix} a_1^\top \\\\ a_2^\top \end{pmatrix}\]那么
\[Ax = \begin{pmatrix} a_1^\top x \\\\ a_2^\top x \end{pmatrix}\]这一步只做两次读取:第一行读出输入沿 $a_1$ 的响应,第二行读出输入沿 $a_2$ 的响应。
再把 $B$ 的两列写成两个写入方向:
\[B = \begin{pmatrix} b_1 & b_2 \end{pmatrix}\]那么 LoRA 更新是:
\[BAx = B \begin{pmatrix} a_1^\top x \\\\ a_2^\top x \end{pmatrix} = b_1(a_1^\top x) + b_2(a_2^\top x)\]这就是两条新的读写通道。第一条通道用 $a_1^\top$ 读输入,再沿 $b_1$ 写到输出空间;第二条通道用 $a_2^\top$ 读输入,再沿 $b_2$ 写到输出空间。rank $r=2$ 的 LoRA 更新最多只能提供两条这样的独立通道。
从矩阵本身看,同一件事可以写成:
\[BA = b_1a_1^\top + b_2a_2^\top\]每一项 $b_ia_i^\top$ 都是 rank-1 的外积。$BA$ 是 $r$ 个 rank-1 更新的和,所以
\[\mathrm{rank}(BA) \le r\]这句话解释了 LoRA 的限制,也解释了它的效率。它把 $\Delta W$ 的自由度压到少数新的读写通道里。如果微调任务主要需要调整少数方向,这个约束就很有效;如果任务需要很多彼此独立的变化,$r$ 太小就会限制表达能力。
参数量也从这里来。完整更新 $\Delta W$ 需要
\[d_{\mathrm{out}}d_{\mathrm{in}}\]个参数。LoRA 只训练 $A$ 和 $B$,参数量是:
\[r d_{\mathrm{in}} + d_{\mathrm{out}} r = r(d_{\mathrm{in}} + d_{\mathrm{out}})\]如果 $d_{\mathrm{in}}=d_{\mathrm{out}}=4096$,$r=8$,完整更新需要大约一千六百七十万参数,而 LoRA 只需要:
\[8(4096+4096)=65536\]个参数。数量级差很多,因为 LoRA 学的是少数通道,完整 fine-tuning 学的是整块更新矩阵。
容易混淆的一点是:rank $r$ 只约束更新量 $\Delta W$。训练后实际使用的是
\[W + \frac{\alpha}{r}BA\]原来的 $W$ 仍然保留预训练阶段学到的高维线性作用。LoRA 加上的只是额外的低秩变化。因此 LoRA 的假设可以更准确地说成:针对某个新任务,需要新增的变化可以用少数通道表达。它约束的是变化量,原有能力由冻结的 $W$ 继续承载。
这也解释了 LoRA 和 SVD compression 的差别。SVD compression 从完整权重 $W$ 出发,尝试用少数通道模仿 $W$ 本身。LoRA 从冻结的 $W$ 出发,只训练额外的 $\Delta W$。前者是在压缩原有能力,后者是在原有能力旁边加一条轻量更新路径。
训练完成后,LoRA 更新还可以合并回原矩阵:
\[W_{\mathrm{merged}} = W+\frac{\alpha}{r}BA\]合并之后,推理时仍然是一层普通的线性层。也就是说,LoRA 在训练阶段把更新拆成两段小矩阵来学;到了部署时,这份更新可以重新并回原来的权重矩阵里。
放回 Transformer,LoRA 最常见地加在 attention 的投影矩阵上,例如 $W_Q,W_K,W_V,W_O$,也可以加在 MLP 的 projection 上。原因和前文一致:这些位置本来就是线性层,LoRA 可以给它们各自增加少数可训练读写通道。对 $W_Q$ 来说,低秩更新会改变 hidden state 生成 query 的方式;对 $W_V$ 来说,它会改变被读取的信息内容;对 $W_O$ 来说,它会改变多个 head 的输出如何写回 residual stream。
这样看,LoRA 可以看作前面几件事的合流:矩阵乘法作为线性变换,外积作为 rank-1 读写通道,rank 作为独立通道数量,SVD 提供了“少数强通道可能很重要”的视角。LoRA 把这些线性代数结构放进训练过程里,让模型在冻结大部分权重的同时,只学习一组小的、低秩的更新通道。
范数、归一化和稳定性
前面讲向量、矩阵、attention、SVD 和 LoRA 时,重点一直放在方向:一个 hidden state 沿哪些方向有响应,一个矩阵从哪些方向读取信息,再写到哪些输出方向。但在真实的 Transformer 里,只讲方向还不够。向量有长度,矩阵有放大倍数,logits 有数值尺度,梯度更新也有步长。尺度一旦漂得太远,同样的方向会产生完全不同的计算效果。
尺度会改变 dot product 和 logits
对一个 hidden state
\[h = \begin{pmatrix} h_1 \\\\ h_2 \\\\ \vdots \\\\ h_d \end{pmatrix}\]最常用的长度是 $L_2$ norm:
\[\lVert h\rVert_2 = \sqrt{h_1^2+h_2^2+\cdots+h_d^2}\]长度会直接影响 LLM 里的很多计算。前面讲 attention score 时已经见过:
\[q^\top k = \lVert q\rVert \lVert k\rVert \cos\theta\]这个式子说明,dot product 同时看方向和长度。两个向量方向一样,如果长度更大,score 也会更大。输出 logits 也类似:
\[z_j = h^\top w_j\]这里 $w_j$ 可以看成 vocabulary 里第 $j$ 个 token 的输出方向。最终 state $h$ 的长度变大,或者 $w_j$ 的长度变大,都会影响 logit 的大小。logits 进入 softmax 后,尺度会改变概率分布的尖锐程度:分数差距很大时,softmax 更容易把概率集中到少数 token 上;分数差距很小时,分布会更平。
用一个小例子看。取
\[h = \begin{pmatrix} 3 \\\\ 4 \\\\ 0 \end{pmatrix}, \quad c = \begin{pmatrix} 6 \\\\ 8 \\\\ 0 \end{pmatrix}\]这两个向量方向相同,但长度分别是 $5$ 和 $10$。如果输出方向是
\[w = \begin{pmatrix} 0 \\\\ 1 \\\\ 0 \end{pmatrix}\]那么
\[w^\top h = 4,\quad w^\top c = 8\]方向没有变,读取到的分数翻倍了。这就是尺度的意义:同一个方向上的信号强弱,会影响 dot product、logits、attention 权重和后续更新。
这时 LayerNorm 和 RMSNorm 的名字也变得更具体。它们里面的 Norm 指 normalization,也就是归一化;它和前面 vector norm 的联系在于,归一化需要先估计一个向量或一组激活的尺度,再用这个尺度重新调整输入。
LayerNorm 和 RMSNorm 稳定 token state
先把 Transformer 里的状态形状写完整一点。训练时通常有 batch 维度:
\[H \in \mathbb{R}^{B \times n \times d}\]$B$ 是 batch size,$n$ 是序列长度,$d$ 是 hidden width。某个样本、某个位置上的 hidden state 是:
\[h_{b,i} \in \mathbb{R}^{d}\]Normalization 的关键问题是:统计量沿哪个维度算?不同归一化方法的差别,主要就在这里。
BatchNorm 的典型做法是对同一个 feature 维度,在一批样本上统计均值和方差。为了简化,先把需要参与统计的 token 或样本编号成 $s=1,\ldots,m$。对第 $j$ 个 feature,BatchNorm 会算:
\[\mu_j = \frac{1}{m} \sum_{s=1}^{m}x_{s,j}\] \[\sigma_j^2 = \frac{1}{m} \sum_{s=1}^{m} (x_{s,j}-\mu_j)^2\]然后对每个样本的第 $j$ 个 feature 做标准化:
\[\hat{x}_{s,j} = \gamma_j \frac{x_{s,j}-\mu_j} {\sqrt{\sigma_j^2+\epsilon}} + \beta_j\]这套方法在 CNN 里很自然。一个 channel 往往对应某类局部视觉特征,很多图片、很多空间位置上的同一个 channel 可以放在一起估计统计量。batch 够大时,这些统计量也比较稳定。
Transformer 的语言建模场景更别扭。一个 batch 里会有不同句子、不同长度、不同位置、不同上下文。第 $i$ 个位置的 hidden state 可能在处理句首、代码缩进、长距离指代、标点,也可能在处理 padding 附近的 token。把一批样本里同一个 hidden feature 的值拉到一起算均值和方差,会让一个 token 的归一化结果依赖同 batch 里的其他序列和其他位置。
这种依赖对自回归模型尤其不自然。模型生成同一个 prompt 时,输出最好只由这个 prompt 和模型参数决定。BatchNorm 在训练时使用当前 batch 的统计量,推理时通常使用 moving average。训练和推理的统计来源不同;batch size、序列长度、padding 方式变化时,统计量也会变。语言模型常常要支持 batch size 为 1、动态长度、KV cache 增量生成,这些都让 batch-level 统计很难成为一个稳定接口。
于是 LayerNorm 把统计范围收回到单个 token 自己的 feature 向量。对第 $b$ 个样本、第 $i$ 个位置的 hidden state
\[h_{b,i} = \begin{pmatrix} h_{b,i,1} \\\\ h_{b,i,2} \\\\ \vdots \\\\ h_{b,i,d} \end{pmatrix}\]LayerNorm 先在它自己的 $d$ 个 feature 上计算均值:
\[\mu_{b,i} = \frac{1}{d} \sum_{j=1}^{d}h_{b,i,j}\]再计算方差:
\[v_{b,i} = \frac{1}{d} \sum_{j=1}^{d} (h_{b,i,j}-\mu_{b,i})^2\]这样可以得到标准化后的 token state:
\[\hat{h}_{b,i} = \frac{h_{b,i}-\mu_{b,i}} {\sqrt{v_{b,i}+\epsilon}}\]这一步已经把当前 token 的 hidden state 拉到稳定的统计尺度上。随后,LayerNorm 会接一个可学习的逐 feature 仿射变换:
\[y_{b,i} = \boldsymbol{\gamma}\odot\hat{h}_{b,i} + \boldsymbol{\beta}\]这里 $\boldsymbol{\gamma},\boldsymbol{\beta}\in\mathbb{R}^d$。它们的长度就是 hidden size $d$,每个 feature 各有一个 scale 和一个 shift。应用到整段状态
\[H \in \mathbb{R}^{B \times n \times d}\]时,同一组 $\boldsymbol{\gamma}$ 和 $\boldsymbol{\beta}$ 会 broadcast 到每个 batch、每个位置上。它们控制 feature 维度,并且被所有 token 和样本共享。一个全局标量只能统一放大或平移整条向量;长度为 $d$ 的向量可以让每个 feature 有自己的缩放和偏移。
为了看清参数是怎么共享的,先拿掉 batch 维度,只看一条序列。假设这一层有 3 个 token:
\[H = \begin{pmatrix} h_1^\top \\\\ h_2^\top \\\\ h_3^\top \end{pmatrix} \in \mathbb{R}^{3 \times d}\]每一行都是一个 $d$ 维 hidden state:
\[h_1,h_2,h_3 \in \mathbb{R}^{d}\]LayerNorm 会分别对 $h_1$、$h_2$、$h_3$ 计算自己的均值和方差,得到三条标准化后的向量:
\[\hat{h}_1,\hat{h}_2,\hat{h}_3 \in \mathbb{R}^{d}\]然后同一组 $\boldsymbol{\gamma}$ 和 $\boldsymbol{\beta}$ 作用到三条向量上:
\[y_1=\boldsymbol{\gamma}\odot\hat{h}_1+\boldsymbol{\beta}\] \[y_2=\boldsymbol{\gamma}\odot\hat{h}_2+\boldsymbol{\beta}\] \[y_3=\boldsymbol{\gamma}\odot\hat{h}_3+\boldsymbol{\beta}\]每个 token 的 hidden state 内容不同,归一化统计量也各自计算;缩放和偏移参数共享同一组。直觉上,$\gamma_j$ 表示“第 $j$ 个 feature 在这一层进入下一步计算前整体该放大多少”,这个规则对所有 token 都适用。
这里的 $\odot$ 是逐坐标相乘。写成矩阵语言时,它最多相当于一个对角缩放:
\[\boldsymbol{\gamma}\odot\hat{h} = \mathrm{diag}(\boldsymbol{\gamma})\hat{h}\]所以这里学的是每个 feature 自己的缩放系数,普通实现里存的是长度为 $d$ 的向量。这一步只做逐 feature 的缩放和平移;不同 feature 之间的混合仍然交给线性层,比如 $W_Q,W_K,W_V,W_O$ 和 MLP projection。
标准化后的向量已经有了稳定尺度,但这种稳定本身也会抹平一部分 feature 强弱。下一层未必希望每个 feature 都以同样强度进入计算。某些 feature 需要更强,某些 feature 需要更弱,某些 feature 还可能需要一个稳定的偏移。$\boldsymbol{\gamma}$ 学习归一化后的 feature 强弱,$\boldsymbol{\beta}$ 学习归一化后的 feature 基线。也就是说,先把漂移的输入拉回稳定坐标系,再在这个坐标系里学习合适的尺度和偏移。
用一个很小的例子看。假设某个 token 标准化后的向量是
\[\hat{h} = \begin{pmatrix} 1 \\\\ -1 \\\\ 0 \end{pmatrix}\]如果某一层更依赖第一个 feature,希望压低第二个 feature,可以学习
\[\boldsymbol{\gamma} = \begin{pmatrix} 2 \\\\ 0.5 \\\\ 1 \end{pmatrix}\]如果再学习一个偏移
\[\boldsymbol{\beta} = \begin{pmatrix} 0.1 \\\\ 0 \\\\ -0.2 \end{pmatrix}\]那么输出变成:
\[\boldsymbol{\gamma}\odot\hat{h} + \boldsymbol{\beta} = \begin{pmatrix} 2.1 \\\\ -0.5 \\\\ -0.2 \end{pmatrix}\]归一化负责让输入尺度稳定;$\boldsymbol{\gamma}$ 和 $\boldsymbol{\beta}$ 负责把稳定后的向量调整成下一层更容易使用的分布。LayerNorm 的统计量只来自当前 token 自己的 feature 向量。因此它在训练和推理时使用同一套计算方式,也天然适配变长序列和自回归生成。
这和 BatchNorm 形成了很清楚的分工:BatchNorm 关心一批样本在同一个 feature 上的分布;LayerNorm 关心当前 token 的整条 hidden state 尺度。Transformer 的每个 token 都有自己的上下文状态,模型更需要后者作为稳定接口。
RMSNorm 沿着同一个尺度控制目标继续简化。标准 LayerNorm 做三件事:减去均值、除以标准差、再用 $\boldsymbol{\gamma}$ 和 $\boldsymbol{\beta}$ 做逐 feature 调整。很多 decoder-only LLM 更重视进入子层前的整体尺度控制,于是可以保留“除以一个尺度”这一步,并保留逐 feature 的 $\boldsymbol{\gamma}$。RMSNorm 就是这个更轻的版本。
对某个 token state $h\in\mathbb{R}^d$,RMSNorm 先算 root mean square:
\[\mathrm{RMS}(h) = \sqrt{ \frac{1}{d} \sum_{j=1}^{d}h_j^2 + \epsilon }\]然后归一化:
\[\mathrm{RMSNorm}(h) = \boldsymbol{\gamma} \odot \frac{h}{\mathrm{RMS}(h)}\]看一个三维例子:
\[h = \begin{pmatrix} 3 \\\\ 4 \\\\ 0 \end{pmatrix}\]它的 RMS 是:
\[\mathrm{RMS}(h) = \sqrt{\frac{3^2+4^2+0^2}{3}} = \frac{5}{\sqrt{3}}\]忽略 $\boldsymbol{\gamma}$ 时,归一化后是:
\[\frac{h}{\mathrm{RMS}(h)} = \begin{pmatrix} \frac{3\sqrt{3}}{5} \\\\ \frac{4\sqrt{3}}{5} \\\\ 0 \end{pmatrix}\]这个新向量的 RMS 约等于 1。再把原向量整体放大 10 倍:
\[10h = \begin{pmatrix} 30 \\\\ 40 \\\\ 0 \end{pmatrix}\]它的 RMS 也会放大 10 倍:
\[\mathrm{RMS}(10h) = \frac{50}{\sqrt{3}}\]所以
\[\frac{10h}{\mathrm{RMS}(10h)} = \frac{h}{\mathrm{RMS}(h)}\]这就是 RMSNorm 的核心效果:一个 token state 整体变大或变小,进入子层前会被拉回相近尺度。LayerNorm 还会先减去均值,RMSNorm 直接控制均方根尺度。很多现代 decoder-only 模型使用 RMSNorm,因为它保留了主要的尺度控制,计算更简单。
有了归一化的对象和统计范围,pre-norm 的位置也更容易理解。Transformer block 反复在 residual stream 上加更新:
\[H \leftarrow H + U_{\mathrm{attn}}\]然后再做:
\[H \leftarrow H + U_{\mathrm{mlp}}\]这些更新会一层层累积。residual stream 的好处是信息可以保留下来,后面的层可以继续在同一个状态空间里读写;代价是尺度也会被一路带着走。如果某几层的更新偏大,后续层看到的 hidden states 也会变大;如果某些方向连续被放大,attention logits、MLP 激活和最终 logits 都可能跟着变得很尖。
pre-norm 的写法是:
\[X = \mathrm{Norm}(H)\] \[H' = H + \mathrm{Attention}(X)\]再做 MLP:
\[R = \mathrm{Norm}(H')\] \[H_{\mathrm{next}} = H' + \mathrm{MLP}(R)\]这里的关键是:attention 和 MLP 读到的是归一化后的输入,residual stream 自己仍然通过加法直接往后传。子层看到的输入尺度更稳定,residual path 也给前向信息和反向梯度保留了一条更直接的通路。
post-norm 会把 Norm 放在子层更新之后:
\[H' = \mathrm{Norm}(H+\mathrm{Attention}(H))\]这时 attention 读到的是未经归一化的 $H$,尺度漂移会先进入 $Q,K,V$ 投影和 attention logits,更新之后再被 Norm 拉回去。pre-norm 把尺度控制提前到子层入口,所以深层 decoder-only 模型更常采用这种结构。这个解释停留在直觉层面,但已经能说明 Norm 的位置会改变训练时的尺度管理方式。
matrix norm 和 clipping 控制拉伸与更新长度
归一化控制的是进入子层的向量尺度。接下来还要看另一件事:向量进入线性层之后,矩阵会怎样改变它的长度。
对一个向量,norm 回答的是“这个向量有多长”。对一个矩阵,最直接的问题可以换成:它会把输入向量拉长多少?如果输入是 $x$,输出是 $Wx$,那么长度变化可以写成:
\[\frac{\lVert Wx\rVert}{\lVert x\rVert}\]这个比值依赖输入方向。看一个三维例子:
\[W = \begin{pmatrix} 4 & 0 & 0 \\\\ 0 & 1 & 0 \\\\ 0 & 0 & 0.25 \end{pmatrix}\]先取三个单位方向:
\[e_1 = \begin{pmatrix} 1 \\\\ 0 \\\\ 0 \end{pmatrix}, \quad e_2 = \begin{pmatrix} 0 \\\\ 1 \\\\ 0 \end{pmatrix}, \quad e_3 = \begin{pmatrix} 0 \\\\ 0 \\\\ 1 \end{pmatrix}\]矩阵作用之后:
\[We_1 = \begin{pmatrix} 4 \\\\ 0 \\\\ 0 \end{pmatrix}, \quad We_2 = \begin{pmatrix} 0 \\\\ 1 \\\\ 0 \end{pmatrix}, \quad We_3 = \begin{pmatrix} 0 \\\\ 0 \\\\ 0.25 \end{pmatrix}\]所以沿第一条轴的输入会被放大 $4$ 倍,沿第二条轴尺度不变,沿第三条轴会缩到 $0.25$ 倍。同一个矩阵,对不同方向的放大率不同。矩阵范数想抓住的就是这种放大行为。
常用的 spectral norm 关注最坏情况:在所有单位长度输入里,哪一个方向会被矩阵拉得最长。写成公式就是:
\[\lVert W\rVert_2 = \max_{\lVert x\rVert=1} \lVert Wx\rVert\]在这个例子里,最大的放大率是 $4$,所以
\[\lVert W\rVert_2 = 4\]这和 SVD 接上了。SVD 把矩阵写成一组读写通道:
\[Wx = \sum_i \sigma_i u_i(v_i^\top x)\]每个 singular value $\sigma_i$ 都表示一条通道的放大强度。最大的 singular value $\sigma_1$ 正好等于 spectral norm:
\[\lVert W\rVert_2 = \sigma_1\]所以 singular values 不只是低秩近似里的排序工具,也是在描述矩阵对不同输入方向的尺度作用。$\sigma_1$ 说明最强方向会被放大多少;后面的 $\sigma_i$ 说明其他正交通道的放大强度(图 5)。
这和训练稳定性有关。一个线性层如果在某些方向上放大很强,后面的激活、attention logits 或梯度都可能被推到很大的尺度。很多层连续作用时,这种放大还会累积。相反,如果重要方向被压得太小,信号也可能在深层里变弱。Transformer 依靠 residual connection、normalization、初始化、学习率、优化器和有时的 gradient clipping 共同管理这些尺度。
gradient clipping 也是同一个线性代数问题。梯度可以看成一个很长的向量 $g$。参数更新大致是:
\[\theta \leftarrow \theta - \eta g\]如果 $\lVert g\rVert$ 突然很大,一步更新可能走得太远。global norm clipping 会把梯度缩放到某个阈值 $\tau$ 内:
\[g_{\mathrm{clipped}} = g\cdot \min\left(1,\frac{\tau}{\lVert g\rVert}\right)\]这样做保留了梯度方向,同时限制了更新长度。
这一节的线索可以收回到尺度管理上。vector norm 回答“向量或更新有多长”;matrix norm 回答“矩阵最多会把输入拉长多少”;BatchNorm、LayerNorm 和 RMSNorm 的差别来自统计维度;pre-norm 把尺度控制放在子层入口;singular values 描述矩阵沿不同通道的放大强度;gradient clipping 控制训练更新的长度。这些机制看起来分散,但都在处理同一件事:深层网络里,信息不仅要有正确的方向,还要有可控的大小。
回到整体:线性代数在 LLM 里到底做了什么
回到文章开头的问题,线性代数在 LLM 里是一套描述模型内部计算的语言。token id 先被放进向量空间,每个位置都有一条 $d$ 维 hidden state;之后的每一层,都在不断读取、改写、混合和重新缩放这些状态。
第一层作用是表示。向量让离散 token 变成可以连续计算的对象。一个 hidden state 可以被很多方向读取,$u^\top h$ 读出的是这个 state 沿某个方向的响应。embedding、residual stream、output logits 都建立在这个基础上:模型需要一个可以被加法更新、被线性层读取、被后续层继续加工的状态空间。
第二层作用是变换。普通线性层在每个 token 内部工作:$hW$ 或 $Wx$ 会把当前 state 重新写到一组新的坐标方向里。$W_Q,W_K,W_V$ 把同一份 hidden state 读成 query、key、value;MLP 的 projection 把 state 展开、加工、再写回;output projection 把最终 state 读成 vocabulary 上的 logits。很多矩阵乘法都发生在单个 token 内部,它们负责改变表示方式和 feature 组合。
第三层作用是交互。attention 里的 $QK^\top$ 和 $AV$ 才真正让 token 之间通信。$QK^\top$ 产生位置之间的匹配分数,softmax 把分数变成读取权重,$AV$ 按权重把其他位置的 value 混回当前位置。multi-head attention 则让模型用多组读取规则并行收集上下文,再通过 $W_O$ 汇回 residual stream。
第四层作用是尺度控制。向量的 norm 会影响 dot product、attention logits 和 output logits;矩阵的 norm 和 singular values 描述线性层会怎样放大或压缩不同方向。LayerNorm、RMSNorm、pre-norm 和 gradient clipping 都在处理同一个现实问题:深层网络里的信号需要有方向,也需要有合适的大小。
最后是结构。outer product 说明矩阵乘法可以拆成 rank-1 contributions;rank 说明一个矩阵实际用了多少独立方向;SVD 进一步把这些方向按强弱排序。低秩近似、SVD compression 和 LoRA 都来自这个观察:很多有用的线性作用可能集中在少数通道里。LoRA 训练少数新的读写通道,SVD compression 则尝试用少数已有强通道近似完整权重。
所以,“LLM 里到处都是矩阵乘法”这句话当然成立,但它只是入口。更具体地看,矩阵乘法有时是在同一个 token 内部改写表示,有时是在 attention 里产生跨 token 的路由,有时是在把多个 heads 的读取结果融合回状态空间,有时又暴露出低秩和尺度结构。线性代数真正提供的,是一套可以追踪这些动作的坐标系统:信息在哪里表示,沿什么方向被读取,经过什么矩阵被改写,在哪些位置之间流动,又以什么尺度继续传下去。
顺着这个框架再看 attention、MLP、normalization、LoRA 和压缩,它们会落到同一组可追踪的问题里:这个操作在读什么,写到哪里,改变了哪些方向,控制了什么尺度,又保留或丢掉了多少结构。线性代数的价值就在这里:它把看起来分散的模块,整理成一套关于表示、变换、交互、尺度和结构的计算语言。