oBert


obert

github.com

Introduction

直接看结果:

贡献:

  • 调研了lottery-ticket, movement pruning,magnitude and second-order pruning.
  • 介绍了一种通用的二阶剪枝方法,称为最优BERT外科医生(oBERT),支持非结构化和块剪枝,是第一种既高精度又可扩展到BERT模型维数的二阶方法
  • 二阶剪枝方法需要逆黑森的近似,这对于LLM参数计数的存储和计算是昂贵的→未来如何近似逆海森矩阵?

The Optimal BERT Surgeon (oBERT)

Generalized Second-Order Block Pruning

令$W_M=M\odot W^$其中$W^\in R^d$是一个密集模型的权重,$d$是全部权重,$M\in \{0,1\}^d$表示掩码,即剪枝,于是使用泰勒展开式得到:

考虑到$W^$优化良好,于是假定$\nabla L(W^)\approx0$,通过修剪权值子集所引起的损失的变化可以表示为:

→→→→如果不近似呢?如何推导?

其中:

在$W^*$处近似海森矩阵的方法是通过一个衰减的经验fisher信息矩阵:

$m$是用于近似黑森的梯度外积的数量

推导:

对于一个将输入向量$in\in n_{in}$映射到输出向量$o\in n_0$的网络:

与训练集对应的均方误差定义为($P$是样本数,$t^{[k]}$是期望输出,$o^{[k]}$是实际输出):

关于$W$的一阶导数是($\frac{\partial E}{\partial W}=\frac{\partial E}{\partial o} \frac{\partial o}{\partial W}$):

海森矩阵是:

考虑一个完全训练到$W*$的局部最小误差的网络,可以忽略$t^{[k]}-o^{[k]}$:

如果输出网络只有一个输出,我们可以将导数的n维数据向量$X^{[k]}$定义为:

于是海森矩阵可以写成:

如果网络是多元输出,则$X\in n×n_o$:

于是海森矩阵可以写成:

以上表明,$H$是与梯度变量$X$相关的样本协方差矩阵。对于单个输出情况,可以通过依次添加连续的“分量”计算完整的海森矩阵:

然后可以得到在$W^*$处近似海森矩阵的方法是通过一个衰减的经验fisher信息矩阵。

回到剪枝问题,识别一个给定形状的权重Q块,通过零掩蔽去除将导致最小的损失增加。这将导致以下约束优化问题:

一次剪枝一组权重Q,对于这个组中的权重,如果对其进行剪枝,则必须保证增量和原始值相同,如果不对其剪枝,则不用管它。

由拉格朗日乘数法得到权重更新:

Q权重的重要性:

高效实现

由于$\hat F^{-1}(W)$太难计算,文中使用了近似的方法。

修剪最优的权重集

在实践中,评估每组Q的显著性得分$\rho_Q$,并修剪得分最低的$\frac{s×d}{|Q|}$组权重,其中$s\in (0,1]$表示稀疏度,$d$是全部权重。

经验逆fisher矩阵计算

采用Woodbury/Sherman-Morrison(WSM) inversion formula:

得到:

实现过程

$N_B=\frac{d}{B}$表示总的块数,第一步计算:

第二步计算($\in R^{N_B}$):

第三步计算:

代码介绍

更新每组权重得分$\rho_Q$:

scores[i] = (
    (self._params[i].data.reshape(-1) ** 2).to(self._devices[i])
    / (2.0 * finv.diag() + self._eps)
).reshape(self._params[i].shape)

更新$W^*$

obs_updates[i] = (
    self._finvs[i]
    .mul(
        (param.data * (mask_diffs[i] == -1))
        .reshape(-1)
        .to(self._devices[i])
        / (self._finvs[i].diag() + self._eps)
    )
    .reshape(param.data.shape)
)

计算逆经验$Fisher$矩阵

def add_grad(self, g: Tensor):
    """
    Updates empirical Fisher inverse with a new gradient
    :param g: a collected gradient
    """
    # if 'd / B' is not integer, pad with zeros for batch calculations
    if g.numel() < self.num_blocks * self.B:
        g = torch.cat(
            [g, torch.zeros(self.num_blocks * self.B - g.numel(), device=g.device)]
        )

    # prepare grad for batch calculations
    g = g.reshape(self.num_blocks, self.B)

    # batched f_inv x g: (batch, B, B) x (batch, B) -> (batch, B)
    finv_g = torch.einsum("bij,bj->bi", self.f_inv, g)

    # scalar denominator for each batch: (batch)
    alpha = (self.m + torch.einsum("bi,bi->b", g, finv_g)).sqrt().unsqueeze(1)
    finv_g /= alpha

    # update f_inv with new outer product: (batch, B) x (batch, B) -> (batch, B, B)
    self.f_inv.baddbmm_(finv_g.unsqueeze(2), finv_g.unsqueeze(1), alpha=-1)

文章作者: ghtll
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 ghtll !