上篇提到的In-Run Data Shapley不能直接应用于Adam优化器,因为Adam 的参数更新不是梯度的线性函数,本片论文对In-Run Data Shapley进行了Adam优化器适配。
Adam优化器#
Adam(Adaptive Moment Estimation)结合了两种经典优化思想:
- Momentum(动量):利用梯度的一阶矩(均值)加速收敛,抑制震荡
- RMSProp(自适应学习率):利用梯度的二阶矩(方差)对每个参数自适应调整学习率
算法流程#
输入:学习率 η(默认 0.001),衰减率 β1=0.9,β2=0.999,数值稳定项 ϵ=10−8
初始化:m0=0,v0=0,t=0
每步更新(对训练样本 zt 计算梯度 gt=∇ℓ(wt−1,zt)):
-
更新步数:t←t+1
-
更新一阶矩(梯度均值):
mt=β1mt−1+(1−β1)gt
-
更新二阶矩(梯度方差):
vt=β2vt−1+(1−β2)gt2
-
偏差修正:
m^t=1−β1tmt,v^t=1−β2tvt
-
参数更新:
wt=wt−1−η⋅v^t+ϵm^t
各组件作用#
一阶矩 mt(动量)
mt 是梯度的指数移动平均,记录了梯度的”方向趋势”。相比 SGD 直接用当前梯度 gt 更新,动量可以:
- 在一致方向上加速(累积历史梯度)
- 在震荡方向上抑制(正负梯度相互抵消)
二阶矩 vt(自适应缩放)
vt 是梯度平方的指数移动平均,记录了每个参数梯度的”历史波动幅度”。更新时除以 v^t 实现自适应:
- 梯度波动大的参数(v^t 大)→ 学习率被缩小,更新更保守
- 梯度波动小的参数(v^t 小)→ 学习率被放大,更新更激进
这使得 Adam 对不同参数能自动调节步长,特别适合稀疏梯度或参数尺度差异大的场景。
偏差修正
由于 m0=v0=0,训练初期 mt 和 vt 偏向零。除以 (1−βt) 进行修正:
- t 小时,βt≈1,修正系数 1−βt1 很大,补偿零初始化的偏差
- t 大时,βt→0,修正系数趋于 1,修正效果消失
ϵ 的作用
ϵ(通常 10−8)加在分母中防止除零,保证数值稳定。
闭式Adam Shapley#
第一步不变,对局部效用函数做一阶Taylor展开:
U(t)(S)=ℓ(w~t+1(S),zval)−ℓ(wt,zval)=U(1)(t)(S)∇ℓ(wt,zval)⋅(w~t+1(S)−wt)+高阶项
第二步:代入Adam更新规则
已知 w~t+1(S)=wt−ηt∑z∈Svt(z)+ϵm^t(z),对于子集 S 和加入样本 z 后的边际贡献:
U(1)(t)(S∪z)−U(1)(t)(S)=−ηt∇ℓ(wt,zval)⋅v^t(z)+ϵm^t(z)
这一步需要假设Adam的动量状态 mt 和方差状态 vt是固定值(常量)
可以看出一阶近似下,边际贡献与子集 S 无关,因此:
ϕz(U(1)(t))=−ηt∇ℓ(wt,zval)⋅vt+ϵmt
最后利用Shapley的可加性,全局Shapley值为各步之和:ϕz(U)≈∑t=0T−1ϕz(U(1)(t))。
计算优化:Linearized Ghost Approximation#
思路:将 Adam 的非线性更新方向线性化,使其可表示为梯度的线性组合
将乘积右边分为m^t(z)和v^t(z)+ϵ1,对于第一项展开可得:
m^t=1−β1tβ1mt−1+1−β1t1−β1gt=Cm1mt−1(z)+Cm2gt(z)
对于第二项,做一阶泰勒展开:
首先,将 v^t 分解为历史部分和当前梯度扰动:
v^t=1−β2tβ2vt−1+(1−β2)gt2≈v^t−1+Cv∇ℓ(wt,z)2
其中 Cv=1−β2t1−β2。令 δ=Cv∇ℓ(wt,z)2 为扰动量,则 v^t≈v^t−1+δ。
定义函数 f(x)=x+ϵ1,在展开点 x0=v^t−1 处:
f(x0)=v^t−1+ϵ1=At(z)1
其中 At(z)=v^t−1(z)+ϵ 是基于历史状态的预条件项。
对 f(x) 求导:
f′(x)=−2(x+ϵ)2x1
在 x0 处近似(假设 x≈x+ϵ 用于导数量级估计):
f′(x0)≈−2At(z)31
由一阶 Taylor 展开 f(x0+δ)≈f(x0)+f′(x0)⋅δ,代入 δ=Cv∇ℓ(wt,z)2:
v^t+ϵ1≈At(z)1−2At(z)3Cv∇ℓ(wt,z)2
最后相乘,去掉 O(gt2) 和 O(gt3) 的高阶项得到:
ϕz≈History Term−ηt∇ℓval⋅AtCm1mt−1+Linear Gradient Term−ηt∇ℓval⋅(AtCm2⊙gt(z))
对于History Term还是要计算验证梯度∇ℓval,但是只需计算一次
Linear Gradient Term可以使用Ghost Dot-Product 高效计算。