Gradient backpropagation through RMSNorm
Forward pass
Given an input vector \(\boldsymbol{x} = (x_1, x_2, ..., x_n)\), RMSNorm computes the output \(\boldsymbol{y} = (y_1, y_2, ..., y_n)\) as follows:
\(y_i = \frac{x_i}{RMS(\boldsymbol{x})} * w_i\)
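where \(w_i\) is the \(i\)-th entry of the learnable scale vector \(\boldsymbol{w}\), \(\epsilon\) is a small constant added for numerical stability, and \(RMS(\boldsymbol{x}) = \sqrt{\frac{1}{n} * \sum_{j=1}^{n}(x_j^2) + \epsilon}\)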
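To make differentiation easier, we split the computation of \(y_i\) into three steps:
- Compute the mean of the squared entries of \(\boldsymbol{x}\): \(a = \frac{1}{n} * \sum_{j=1}^{n}(x_j^2)\)
- Compute the RMS value: \(rms = \sqrt{a + \epsilon}\)
- Normalize the input and scale it: \(y_i = (x_i / rms) * w_i\)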
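The three steps above can be sketched in NumPy as follows. This is a minimal illustration for a single 1-D vector, not a reference implementation: the function name `rmsnorm_forward` and the default `eps=1e-6` are assumptions made here, not something the derivation prescribes.

```python
import numpy as np

def rmsnorm_forward(x, w, eps=1e-6):
    """Forward pass of RMSNorm for a single 1-D vector x.

    Hypothetical helper for illustration; eps=1e-6 is an assumed default.
    """
    a = np.mean(x * x)        # step 1: mean of the squared entries
    rms = np.sqrt(a + eps)    # step 2: RMS value
    y = (x / rms) * w         # step 3: normalize, then scale elementwise
    return y, rms

x = np.array([3.0, 4.0])
w = np.ones(2)
y, rms = rmsnorm_forward(x, w, eps=0.0)
```

With \(\boldsymbol{w}\) set to all ones and \(\epsilon = 0\), the output has a mean square of 1, which is a quick sanity check on the normalization.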
Backward pass
- \(\frac{\partial a}{\partial x_i} = (2x_i)/n\)
- \(\frac{\partial rms}{\partial a} = \frac{1}{2\sqrt{a + \epsilon}} = \frac{1}{2\ rms}\)
- \(\frac{\partial rms}{\partial x_i} = \frac{\partial a}{\partial x_i} * \frac{\partial rms}{\partial a} = \frac{x_i}{n\ *\ rms}\)
Next, compute \(\frac{\partial y_j}{\partial x_i}\), treating the two cases separately:
case 1: \(j=i\)
\(\frac{\partial y_i}{\partial x_i} = w_i\frac{rms - x_i * \frac{x_i}{n\ * \ rms}}{rms^2} = w_i(\frac{1}{rms} - \frac{x_i^2}{n\ *\ rms^3})\)
case 2: \(j \ne i\)
\(\frac{\partial y_j}{\partial x_i} = -\frac{x_j}{rms^2} * \frac{x_i}{n\ *\ rms} * w_j = -\frac{w_j x_i x_j}{n\ *\ rms^3}\)
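As a cross-check, the two cases can be folded into one expression using the Kronecker delta \(\delta_{ij}\) (equal to 1 when \(i = j\) and 0 otherwise; this notation is introduced here for convenience only):
\(\frac{\partial y_j}{\partial x_i} = \frac{w_j}{rms}\left(\delta_{ij} - \frac{x_i x_j}{n\ *\ rms^2}\right)\)
Setting \(j = i\) recovers case 1, and \(j \ne i\) recovers case 2.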
Since every output \(y_j\) depends on \(x_i\) through \(rms\), the chain rule sums over all of them. The sum over \(j \ne i\) is then rewritten as the full sum over \(j\) minus the \(j = i\) term:
\(\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial x_i} + \sum_{j\ne i}\frac{\partial L}{\partial y_j}\frac{\partial y_j}{\partial x_i} \\= \frac{\partial L}{\partial y_i}w_i(\frac{1}{rms} - \frac{x_i^2}{n\ *\ rms^3}) - \sum_{j\ne i}\frac{\partial L}{\partial y_j}\frac{w_j x_i x_j}{n\ *\ rms^3}\\= \frac{\partial L}{\partial y_i}w_i(\frac{1}{rms} - \frac{x_i^2}{n\ *\ rms^3}) - \frac{x_i}{n\ *\ rms^3}(\sum_{j}\frac{\partial L}{\partial y_j}w_j x_j - \frac{\partial L}{\partial y_i}w_i x_i)\)
The second term and the last term cancel exactly, which gives:
\(\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}w_i\frac{1}{rms} - \frac{x_i}{n\ *\ rms^3}\sum_{j}\frac{\partial L}{\partial y_j}w_j x_j\\=\frac{1}{rms}(\frac{\partial L}{\partial y_i}w_i - \frac{x_i}{n\ *\ rms^2}\sum_{j}\frac{\partial L}{\partial y_j}w_j x_j)\)
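In vector form, writing \(d\boldsymbol{x} = \frac{\partial L}{\partial \boldsymbol{x}}\) and \(d\boldsymbol{y} = \frac{\partial L}{\partial \boldsymbol{y}}\):
\(d\boldsymbol x = \frac{1}{rms}\left(d\boldsymbol{y} * \boldsymbol{w} - \frac{(d\boldsymbol{y} * \boldsymbol w) \cdot \boldsymbol{x}}{n\, rms^2}\,\boldsymbol{x}\right)\)
where \(*\) between vectors denotes elementwise multiplication and \(\cdot\) denotes the dot product. The dot product \((d\boldsymbol{y} * \boldsymbol{w}) \cdot \boldsymbol{x}\) is a scalar, so the second term is simply a rescaled copy of \(\boldsymbol{x}\).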
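A derivation like this is easy to get wrong by a sign, so a numerical check is worthwhile. The sketch below implements the vector formula in NumPy and compares it against central finite differences. The helper names (`rmsnorm_backward_x`, `numerical_grad_x`), the default `eps=1e-6`, and the step size `h` are all assumptions chosen for illustration. Only the gradient with respect to \(\boldsymbol{x}\) is covered, matching the derivation above.

```python
import numpy as np

def rmsnorm_forward(x, w, eps=1e-6):
    # Same forward pass as in the derivation: y = x / rms * w
    rms = np.sqrt(np.mean(x * x) + eps)
    return (x / rms) * w, rms

def rmsnorm_backward_x(dy, x, w, eps=1e-6):
    """Gradient of the loss w.r.t. x, given dy = dL/dy (hypothetical helper)."""
    n = x.shape[0]
    rms = np.sqrt(np.mean(x * x) + eps)
    g = dy * w                                   # elementwise product dy * w
    # dx = (1/rms) * (g - ((g . x) / (n * rms^2)) * x)
    return (g - x * np.dot(g, x) / (n * rms ** 2)) / rms

def numerical_grad_x(dy, x, w, eps=1e-6, h=1e-6):
    """Central finite differences of the linear surrogate loss L = dy . y."""
    grad = np.zeros_like(x)
    for i in range(x.shape[0]):
        xp, xm = x.copy(), x.copy()
        xp[i] += h
        xm[i] -= h
        yp, _ = rmsnorm_forward(xp, w, eps)
        ym, _ = rmsnorm_forward(xm, w, eps)
        grad[i] = np.dot(dy, yp - ym) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x, w, dy = rng.normal(size=8), rng.normal(size=8), rng.normal(size=8)
dx = rmsnorm_backward_x(dy, x, w)
dx_num = numerical_grad_x(dy, x, w)
print(np.max(np.abs(dx - dx_num)))  # should be tiny if the formula is right
```

One further property worth checking: when \(\epsilon = 0\), the output does not change if \(\boldsymbol{x}\) is rescaled, so \(d\boldsymbol{x}\) must be orthogonal to \(\boldsymbol{x}\). The formula satisfies this because \(\boldsymbol{x} \cdot \boldsymbol{x} = n\, rms^2\) in that case.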