3倍加速LLM推理——推测解码(Speculative Decoding)
推测解码(Speculative Decoding,也有人译为“投机采样”,个人认为意译为“拒绝采样”更好一些)是Google[1]和DeepMind[2]在2022年同时发现的大模型推理加速方法。它可以在生成效果无损的前提下,获得3倍以上的加速比。GPT-4泄密报告也提到OpenAI线上模型推理使用了它。
基本思想:小模型打草稿,大模型纠正
- 打草稿:Speculative Decoding引入了一个与原本LLM结构相同但参数量小得多的近似模型。推理过程中,由小模型先用自回归的方法串行生成\(\gamma\)个token;
- 纠正:再将这\(\gamma\)个token一起送入大模型中进行推理(验证),对于小模型输出的某个token \(x_{i}\),若大模型预测的概率\(p(x_i)\)大于小模型预测的概率\(q(x_i)\),则接受\(x_i\)作为当前位置的输出;否则,大模型就以\(1-\frac{p(x_i)}{q(x_i)}\)的概率概率拒绝\(x_i\)作为当前位置的输出,并在\(p(x)>q(x)\)的区域重新采样\(x_j\)作为纠正后的输出。
加速原理:
在Speculative Decoding中,串行输出token的任务主要由小模型承担了,大模型的“纠正”过程由于可以一次并行处理\(\gamma\)个token,前向推理的次数大大减少。(与自回归的LLM可以高效并行训练的原理相同)
无偏估计:为什么说Speculative Decoding的加速过程是“无损”的
论文[1]在附录A1.1中证明了通过从\(p(x)\)和\(q(x)\)进行Speculative Decoding所得到的token的分布与仅从\(p(x)\)进行采样所得到的token的分布是相同的。有趣的是,这一结论对于任意分布的\(p(x)\)和\(q(x)\)均成立。
前提:
Speculative Decoding的过程以以下方法进行:
若\(p(x')>q(x')\),则模型以1的概率接受\(x'\)作为最终解码结果;若\(p(x')<q(x')\),则模型以\(\frac{p(x')}{q(x')}\)的概率接受\(x'\)作为最终解码结果;即小模型输出的\(x'\)被接受的概率为\(\min(1,\frac{p(x')}{q(x')})\)。
若小模型输出被拒绝,则以归一化后的\(p(x)-q(x)\)作为概率分布(记作\(p'(x)\)),在\(p(x)>q(x)\)的区域重新采样得到作为纠正后的输出。有\(p^{\prime}(x)=\operatorname{norm}(\max (0, p(x)-q(x)))=\frac{p(x)-\min (q(x), p(x))}{\sum_{x^{\prime}}\left(p\left(x^{\prime}\right)-\min \left(q\left(x^{\prime}\right), p\left(x^{\prime}\right)\right)\right)}=\frac{p(x)-\min (q(x), p(x))}{1-\sum_{x^{\prime}}\min \left(q\left(x^{\prime}\right), p\left(x^{\prime}\right)\right)}\)
证明:
\(P\left(x=x^{\prime}\right)=P\left(\right.\) guess accepted, \(\left.x=x^{\prime}\right)+P\left(\right.\) guess rejected, \(\left.x=x^{\prime}\right)\)
其中,
\(P\left(\right.\) guess accepted, \(\left.x=x^{\prime}\right)=q\left(x^{\prime}\right) \min \left(1, \frac{p\left(x^{\prime}\right)}{q\left(x^{\prime}\right)}\right)=\min \left(q\left(x^{\prime}\right), p\left(x^{\prime}\right)\right)\)
小模型的输出被接受的概率\(\beta=E_{x \sim q(x)} \begin{cases}1 & q(x) \leq p(x) \\ \frac{p(x)}{q(x)} & q(x)>p(x)\end{cases}=E_{x \sim q(x)} \min \left(1, \frac{p(x)}{q(x)}\right)=\sum_{x'} \min (p(x'), q(x'))\),代入\(p'(x)\)的表达式,有\(p'(x)=\frac{p(x)-\min (q(x), p(x))}{1-\beta}\)
则\(P\left(\right.\) guess rejected, \(\left.x=x^{\prime}\right)=(1-\beta) p^{\prime}\left(x^{\prime}\right)=p\left(x^{\prime}\right)-\min \left(q\left(x^{\prime}\right), p\left(x^{\prime}\right)\right)\)
故:\(P\left(x=x^{\prime}\right)=p(x')\)
证毕,As desired!