1 Batch Norm

1.1 ICS:协变量偏移

Covariate Shift(协变量偏移) 是指训练集和测试集输入分布不一致,导致模型泛化能力差。

作者将这个概念推广到了神经网络内部,提出了 Internal Covariate Shift (ICS)

在深度网络中,激活函数会改变各层数据的分布。随着网络层数加深,这种分布的变化会被不断累积和放大。 也就是说 ICS 指的是层与层之间数据分布的不稳定

1.2 饱和激活函数梯度消失

饱和激活函数(如 Sigmoid 和 Tanh):当数据分布发生偏移时,很多神经元的输出会落入激活函数的 饱和区 所以梯度会衰减消失

1.3 BN:解决梯度消失

  • 将每层输入按照特征channel拉回均值 0、方差 1 的标准分布
  • 然后,加入缩放和平移变量 γ 和 β:保证每一次数据经过归一化后还保留原有学习来的特征,同时又能完成归一化操作,加速训练。 这两个参数是可学习的参数。

1.3 作用

  • 允许较大 lr
  • 减弱对初始化的依赖性
  • 让数值更稳定
  • 轻微正则化作用(相当于加 noise,类似 Dropout)

1.4 bs 太大/太小会怎样

  1. 太大
    • OOM
    • 需要跑更多 epoch(更新次数减少,为了达到同样的精度,大 Batch 通常需要跑更多的轮次(Epoch)或者配合更大的学习率,甚至有时最终泛化效果还不如小 Batch。)
    • 直接固定了梯度方向(趋向于全量数据了),很难更新,会直接落入局部最优/鞍点
      • 小 bs 带有一定随机噪声,有几率从局部最优跳出
  2. 太小
    • 算出来的均值和方差具有随机性,不能反映真实分布

    不方便用大 bs 咋办

    1. LN
    2. Group Norm(把 channel 分组,组内求 mean、std)

2 Layer Norm

2.1 公式

和 BN 一样的公式,不过是在 seq 维度进行。

2.2 手撕

 
import torch
import torch.nn as nn
 
class MyLayerNorm(nn.Module):
    def __init__(self, hidden_dim, eps=1e-5):
        super().__init__()
        # gamma 和 beta 是每个特征维度一个,初始化为 1 和 0
        self.gamma = nn.Parameter(torch.ones(hidden_dim))
        self.beta = nn.Parameter(torch.zeros(hidden_dim))
        self.eps = eps
 
    def forward(self, x):
        # x shape: [Batch, Seq_len, Hidden_dim]
        # 在最后一个维度计算均值和方差
        mean = x.mean(dim=-1, keepdim=True)
        # unbiased=False 表示使用总体方差(分母为 N),这是标准做法
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        
        # 归一化并线性变换
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out
        
def test_ln():
	b = 2
	s = 4
	h = 8
	x = torch.randn(b,s,h)
	ln = MyLayerNorm(h)
	output = MyLayerNorm(x)
	print(x.shape)
	print(output.shape)
	
if __name__=="__main__":
	test_ln()
 
 
class DxdLayerNorm(nn.Module):
	def __init__(self,hidden_dim,eps = 1e-5):
		super.__init__()
		self.gamma = nn.Parameter(torch.ones(hidden_dim)) 
		self.beta = 
		self.eps = eps
	
	def forward(self,x):
		mean = x.mean(dim = -1,keepdim = True)
		var = x.var(dim = -1,keepdim = True)
		out = (x-mean)/torch.sqrt(var+self.eps)
		out = out*self.gamma+self.beta
		return out

3 RMSNorm

3.1 公式

与layerNorm相比,RMS Norm的主要区别在于去掉了减去均值的部分

3.2 好处

  • 不用减均值,效率提升
  • 实现了与 LayerNorm 相当的性能
  • 减少了对均值的依赖,适用于不同的输入分布

3.3 手撕

 
import torch
import torch.nn as nn
class DXDRMSNorm(nn.module):
	def __init__(self,hidden_dim,eps = 1e-5):
		super.__init__()
		self.gamma = nn.Parameter(torch.ones(hidden_dim))
		self.eps = eps
	def forward(self,x):
		RMS = torch.sqrt(torch.mean(x.pow(2),dim = -1,keepdim = True)+eps)
		output = x/RMS
		output = output*self.gamma
		return output
 
def testRMS():
	b = 4
	s = 8
	h = 2
	x = torch.randn(b,s,h)
	rms_norm = DXDRMSNorm(h)
	output = DXDRMSNorm(x)
	
	print(x.shape)
	print(output.shape)
	print("params:",list(rms_norm.parameters()))
 
if __name__ == "__main__":
	testRMS()	

4 PostNorm 与 PreNorm

4.1 公式

4.2 优劣分析

  • from 苏神
    • Pretraining 中,Pre Norm和Post Norm都能做到大致相同的结果,但是Post Norm的Finetune 效果明显更好
    • Pre Norm更容易训练,因为Post Norm要达到自己的最优效果,不能用跟Pre Norm一样的训练配置(比如Pre Norm可以不加 Warmup 但 Post Norm 通常要加)****
    • 一个 L 层的 Pre Norm 模型,其实际等效层数不如 L 层的 Post Norm 模型,而层数少了导致效果变差了。(Pre Norm结构无形地增加了模型的宽度而降低了模型的深度

5 MSE

5.1 公式

5.2 手撕

 
def mse(y_true,y_pred):
	sqared_err = (y_true-y_pred)**2
	return np.mean(squared_err)
 
y_true = np.array([2.0,4.0,5.0])
y_pred = np.array([3.0,4.4,5.5])
 
print(mse(y_true,y_pred))
 

6 CE(交叉熵)

6.1 公式

6.2 手撕

import numpy as np
 
def ce(y_true_onehot,y_pred):#in: s*v
	#softmax
	exps = np.exp(y_pred-np.max(y_pred,axis = 1,keepdims = True))
	softmax_out = exps/np.sum(exps,axis = 1,keepdims = True)
	
	eps = 1e-7
	clipped = np.clip(softmax_out,eps,1-eps)
	
	ce= -np.sum(y_true_onehot*np.log(clipped),axis = 1)
	
	return ce
 
y_pred = np.array([[10, 2, 1], [1, 5, 2]], dtype=float)
y_true_onehot = np.array([[1, 0, 0], [0, 1, 0]])
 
print(ce(y_pred, y_true_onehot))
 

7 手撕 GAE & PPO Loss

7.1 动作价值函数 Q(s, a) 和状态价值函数 V(s)

  • Q(s, a):在状态 s 执行动作 a 后,所能获得的期望累积奖励
  • V(s):在状态 s,遵循当前策略所能获得的期望累积奖励。可以理解为在这个状态下,平均能得多少分。

7.2 优势函数 A(s, a)

优势函数是GAE的核心。它的定义非常简单

A(s, a) = Q(s, a) - V(s)

Q:当前状态下做动作 a 的真实价值(根本不能知道)

V:当前状态的平均预期价值(底线)

所以要用一个代替品,根据贝尔曼方程,一个动作的价值等于:它带来的即时奖励 + 之后状态的价值

	$$Q(s_t, a_t) = \mathbb{E}[r_t + \gamma V(s_{t+1})]$$

	$$A(s, a) \approx \underbrace{r_t + \gamma V(s_{t+1})}_{\text{实际表现的比预期好吗?}} - \underbrace{V(s_t)}_{\text{原本的底线预期}}$$

A(s, a) = Q(s, a) - V(s)所表达的含义是:在状态s下,执行动作a比遵循当前策略的平均行为要好多少

  • A(s, a) > 0:这个动作比平均动作好,应该被鼓励。
  • A(s, a) < 0:这个动作比平均动作差,应该被避免。

如果我们能准确地知道每个状态动作对的优势值 A(s, a),策略优化就变得非常简单:更多地选择优势为正的动作,避免优势为负的动作。

问题是,在真实环境中,我们无法直接知道 Q(s, a) 和 V(s) 的真实值,只能通过采样(与环境交互)来估计它们。怎么估呢?有两个常用但各有缺陷的方法:

  • 方法A (看全程结果): 从当前动作开始,一直算到游戏结束的总奖励(蒙特卡洛)。优点:无偏(理论上准)。缺点:方差巨大(结果受后面随机性影响太大,不稳定)。
  • 方法B (只看下一步): 用当前奖励 + 对下一个状态价值的估计 - 当前状态价值(一步TD误差)。优点:方差小(只受一步随机性影响)。缺点:有偏(依赖的估计本身可能不准,且忽略了更远的收益)。

7.3 GAE

  • 原理和公式 把看 n 步的估计结果,加权混合 GAE 计算出的优势值 是从当前时刻 到序列结束的所有 的加权和:

在手撕代码时,为了避免 的复杂度,我们通常使用这个等价的递归式(从后往前算):

  • 为什么公式里是 连在一起?

    因为 物理折扣(未来的钱不如现在的值钱),而 信度折扣(越远的步数,我们对当前动作的“功劳认定”就越不确定)。

 
import torch
 
def compute_gae(rewards,values,next_values,masks,gamma=0.99,lab=0.95):
	""" rewards: [T, B] 
	values: [T, B] (当前状态的价值) 
	next_values: [T, B] (下一状态的价值) 
	masks: [T, B] (done 信号,1 为未结束,0 为结束) 
	"""
	advantages = torch.zeros_like(rewards)
	last_gae_lam = 0
	
	for t in reversed(range(len(rewards))):
		delta = rewards[t] + gamma * next_values[t] * masks[t] - values[t]
		
		advantages[t] = last_gae_lam = delta + gamma * lam * masks[t] * last_gae_lam
		
	return advantages, advantages + values
	

7.4 PPO Loss

def ppo_loss(old_log_probs, new_log_probs, advantages, eps=0.2):
    """
    old_log_probs: 采样时旧策略的 log_prob [N]
    new_log_probs: 当前更新策略的 log_prob [N]
    advantages: 计算好的优势函数 [N]
    """
    # 1. 计算概率比率 ratio = exp(new_log_prob - old_log_prob)
    ratio = torch.exp(new_log_probs - old_log_probs)
    
    # 2. 计算两部分损失
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    
    # 3. 取最小值并加负号(因为是要最大化奖励,而优化器是做梯度下降)
    # PPO 还会加上 Entropy Loss 和 Value Loss,这里是核心的 Policy Loss
    loss = -torch.min(surr1, surr2).mean()
    
    return loss