obert
Introduction
直接看结果:
贡献:
- 调研了lottery-ticket, movement pruning,magnitude and second-order pruning.
- 介绍了一种通用的二阶剪枝方法,称为最优BERT外科医生(oBERT),支持非结构化和块剪枝,是第一种既高精度又可扩展到BERT模型维数的二阶方法
- 二阶剪枝方法需要逆黑森的近似,这对于LLM参数计数的存储和计算是昂贵的→未来如何近似逆海森矩阵?
The Optimal BERT Surgeon (oBERT)
Generalized Second-Order Block Pruning
令$W_M=M\odot W^$其中$W^\in R^d$是一个密集模型的权重,$d$是全部权重,$M\in \{0,1\}^d$表示掩码,即剪枝,于是使用泰勒展开式得到:
考虑到$W^$优化良好,于是假定$\nabla L(W^)\approx0$,通过修剪权值子集所引起的损失的变化可以表示为:
→→→→如果不近似呢?如何推导?
其中:
在$W^*$处近似海森矩阵的方法是通过一个衰减的经验fisher信息矩阵:
$m$是用于近似黑森的梯度外积的数量
推导:
对于一个将输入向量$in\in n_{in}$映射到输出向量$o\in n_0$的网络:
与训练集对应的均方误差定义为($P$是样本数,$t^{[k]}$是期望输出,$o^{[k]}$是实际输出):
关于$W$的一阶导数是($\frac{\partial E}{\partial W}=\frac{\partial E}{\partial o} \frac{\partial o}{\partial W}$):
海森矩阵是:
考虑一个完全训练到$W*$的局部最小误差的网络,可以忽略$t^{[k]}-o^{[k]}$:
如果输出网络只有一个输出,我们可以将导数的n维数据向量$X^{[k]}$定义为:
于是海森矩阵可以写成:
如果网络是多元输出,则$X\in n×n_o$:
于是海森矩阵可以写成:
以上表明,$H$是与梯度变量$X$相关的样本协方差矩阵。对于单个输出情况,可以通过依次添加连续的“分量”计算完整的海森矩阵:
然后可以得到在$W^*$处近似海森矩阵的方法是通过一个衰减的经验fisher信息矩阵。
回到剪枝问题,识别一个给定形状的权重Q块,通过零掩蔽去除将导致最小的损失增加。这将导致以下约束优化问题:
一次剪枝一组权重Q,对于这个组中的权重,如果对其进行剪枝,则必须保证增量和原始值相同,如果不对其剪枝,则不用管它。
由拉格朗日乘数法得到权重更新:
Q权重的重要性:
高效实现
由于$\hat F^{-1}(W)$太难计算,文中使用了近似的方法。
修剪最优的权重集
在实践中,评估每组Q的显著性得分$\rho_Q$,并修剪得分最低的$\frac{s×d}{|Q|}$组权重,其中$s\in (0,1]$表示稀疏度,$d$是全部权重。
经验逆fisher矩阵计算
采用Woodbury/Sherman-Morrison(WSM) inversion formula:
得到:
实现过程
$N_B=\frac{d}{B}$表示总的块数,第一步计算:
第二步计算($\in R^{N_B}$):
第三步计算:
代码介绍
更新每组权重得分$\rho_Q$:
scores[i] = (
(self._params[i].data.reshape(-1) ** 2).to(self._devices[i])
/ (2.0 * finv.diag() + self._eps)
).reshape(self._params[i].shape)
更新$W^*$
obs_updates[i] = (
self._finvs[i]
.mul(
(param.data * (mask_diffs[i] == -1))
.reshape(-1)
.to(self._devices[i])
/ (self._finvs[i].diag() + self._eps)
)
.reshape(param.data.shape)
)
计算逆经验$Fisher$矩阵
def add_grad(self, g: Tensor):
"""
Updates empirical Fisher inverse with a new gradient
:param g: a collected gradient
"""
# if 'd / B' is not integer, pad with zeros for batch calculations
if g.numel() < self.num_blocks * self.B:
g = torch.cat(
[g, torch.zeros(self.num_blocks * self.B - g.numel(), device=g.device)]
)
# prepare grad for batch calculations
g = g.reshape(self.num_blocks, self.B)
# batched f_inv x g: (batch, B, B) x (batch, B) -> (batch, B)
finv_g = torch.einsum("bij,bj->bi", self.f_inv, g)
# scalar denominator for each batch: (batch)
alpha = (self.m + torch.einsum("bi,bi->b", g, finv_g)).sqrt().unsqueeze(1)
finv_g /= alpha
# update f_inv with new outer product: (batch, B) x (batch, B) -> (batch, B, B)
self.f_inv.baddbmm_(finv_g.unsqueeze(2), finv_g.unsqueeze(1), alpha=-1)