SPS方法:用逆强化学习缓解大模型强化学习中的概率挤压效应
1. 项目概述与问题根源
如果你最近在折腾大语言模型的强化学习微调,尤其是想让模型在数学推理、代码生成这类需要“开脑洞”的任务上表现更好,那你很可能已经遇到了一个令人头疼的瓶颈:模型学“乖”了,但同时也变“笨”了。具体来说,经过强化学习训练后,模型在单次采样(Pass@1)上的成功率确实上去了,但当你要求它给出多个答案(比如Pass@128)时,你会发现这些答案的多样性很差,模型翻来覆去就是那几种解法,很难探索出新的、同样正确的推理路径。这感觉就像训练一个学生,他通过反复刷题把一种标准答案练得滚瓜烂熟,但遇到需要多角度思考的开放性问题时,却缺乏变通能力。
这个问题的根源,在学术界被称为“概率挤压效应”。听起来有点玄乎,但原理其实很直观。想象一下模型的输出是一个概率分布,每个可能的答案(或推理步骤)都有一个概率值。标准的强化学习,特别是基于对比的优化目标(比如GRPO),其更新机制可以粗略理解为“奖励好的,惩罚坏的”。问题就出在这个“惩罚坏的”环节。当一个答案的概率本身已经很低时,惩罚它的梯度更新,并不会如我们所愿地把这部分“概率质量”均匀地分配给其他潜在的好答案。相反,由于神经网络中Softmax函数的归一化特性,这部分被移除的概率质量,会不成比例地被当前已经占据主导地位的“贪婪答案”吸收。结果就是,模型的输出分布变得越来越“尖峰”,多样性被不断挤压,探索能力自然就受到了限制。
我最初意识到这个问题,是在尝试用GRPO微调一个数学推理模型时。训练曲线显示Pass@1稳步提升,但Pass@128却早早进入了平台期,无论怎么增加训练步数或数据量,都难有突破。这让我开始怀疑,是不是强化学习本身的学习动态就存在某种内在限制?后来读到一些关于学习动态分析的文献,才豁然开朗:这不仅仅是“探索-利用”权衡没做好,而是优化算法本身在概率分布层面的一种系统性偏差。
因此,我们需要的不是简单地加大探索噪声,而是从机制上干预这种“概率挤压”的过程,引导被挤压的概率质量流向那些尚未被充分探索、但可能蕴藏着正确答案的区域。这就是SPS方法的出发点:与其对抗挤压,不如引导挤压。
2. SPS核心思路:用逆强化学习重塑概率分布
2.1 为什么是逆强化学习?
要解决概率挤压,一个直接的思路是引入熵正则化之类的技术,强行给分布“增肥”。但这有点像给一个挑食的孩子硬塞他不爱吃的菜,可能有效,但不够优雅,也容易破坏模型已经学到的有价值模式。我们更希望的是,能有一种方法,可以“示范”给模型看:除了你当前最偏爱的那个答案,还有其他一些看起来也不错的选择,你应该也给它们分一点注意力。
逆强化学习恰恰擅长这个。传统的IRL是从专家演示中反推奖励函数。而我们这里的“专家”从哪来?一个巧妙的思路是:模型自己在强化学习阶段生成的轨迹,就是最好的“演示”来源。这些轨迹里,既有被高奖励选中的“好答案”,也有大量被当前策略忽视的“潜在好答案”。IRL的目标,就是让模型的策略分布,去贴近这个由自身轨迹构成的经验分布。
这样做的好处是双重的:
- 无额外监督:我们不需要任何外部标注的专家数据,完全自给自足。
- 针对性重塑:IRL的损失函数(我们采用前向KL散度)会驱动策略去覆盖那些在经验分布中出现过的轨迹。如果经验分布本身是多样的(我们通过采样策略来保证),那么策略分布就会被拉向一个更平坦、更分散的状态,从而直接对抗概率挤压。
2.2 SPS训练框架:RL与IRL的交替循环
SPS不是一个孤立的算法,而是一个训练框架。它的核心是一个交替循环,如下图所示(概念上):
第一阶段:标准RL探索 这个阶段和普通的强化学习微调没有区别。我们使用例如GRPO这样的算法,在训练数据上进行策略优化。模型会为每个问题生成多个推理轨迹,并根据验证器(比如数学答案检查器)获得二元奖励(正确/错误)。这个阶段的目标是充分利用奖励信号,提升策略在高质量轨迹上的概率。
关键操作:轨迹收集与采样 RL阶段结束后,我们会保存模型生成的所有轨迹(Rollouts)。这里有一个重要的技巧:我们不会使用全部轨迹,而是从中均匀采样一个子集,作为后续IRL阶段的“演示数据”。采样大小是一个超参数,我们的实验发现,在计算效率和多样性之间权衡,采样3条轨迹是一个不错的起点。这一步确保了演示集具有一定的多样性,避免了IRL阶段只模仿少数几个高奖励轨迹。
第二阶段:IRL概率重塑 这是SPS的创新核心。我们将上一步采样得到的轨迹视为“专家演示”,然后通过最小化策略分布与这些演示经验分布之间的KL散度,来更新模型参数。损失函数如下:
L_IRL = - E[KL(π_rollout(y‘|x) || π(y’|x))]
这里,π_rollout 是经验分布(即采样轨迹的分布),π 是当前待优化的策略。使用前向KL散度(P||Q)意味着让策略 π 尽可能覆盖经验分布 π_rollout 支持的所有轨迹。如果经验分布包含了多样化的正确解法,那么策略就会被鼓励给这些解法分配更高的概率,从而将概率质量从过度集中的“尖峰”重新分配到更广阔的“高原”上。
迭代与循环 完成一次IRL更新后,我们得到一个分布更平坦、探索意愿更强的策略。然后,我们将这个策略作为新的起点,重新投入下一轮的RL训练。RL训练会再次利用奖励信号进行聚焦优化,可能又会产生一定的概率集中。但紧接着的IRL阶段会再次将其“摊平”。通过多次这样的RL-IRL循环,我们引导模型在“深度优化已知好解”和“广度探索潜在新解”之间动态平衡,从而逐步扩大模型能够稳定输出的正确推理路径的集合。
实操心得: 这个交替的频率需要仔细调试。IRL阶段太频繁,可能会过度干扰RL学到的奖励信号;IRL阶段太少,又不足以缓解挤压效应。在我们的实验中,每进行一个完整的RL训练周期(例如,在数据集上训练一定步数)后,插入一个短暂的IRL阶段(如4个训练步),效果比较稳定。
3. 关键技术细节与实现要点
3.1 低似然轨迹强调
在分析概率挤压效应时,我们发现一个关键现象:挤压最严重的时候,往往发生在对那些模型本身赋予极低概率的负样本进行惩罚时。这些样本虽然被模型认为是“烂答案”,但其中可能混杂着一些只是当前策略尚未理解的、非典型的正确解法。如果IRL阶段只是均匀地从所有轨迹中采样,这些低似然轨迹很可能因为采样不到而被忽略。
为此,我们提出了 低似然轨迹强调 策略。具体来说,在从RL阶段收集的轨迹池中采样时,我们不是完全均匀采样,而是有意提高那些在当前策略下条件概率较低的轨迹被选中的权重。这样做的目的是主动“打捞”那些被模型忽视的潜在好轨迹,并在IRL阶段将其作为演示,强行告诉模型:“这些路径虽然你现在觉得不靠谱,但它们也是可行的,你应该重视它们。”
实现方式:
- 对于RL阶段生成的每条轨迹
y_i,计算其在当前策略π_old下的对数概率log π_old(y_i | x)。 - 根据对数概率的负值(或倒数)来构建采样权重。概率越低的轨迹,权重越高。
- 按照此权重从轨迹池中采样出用于IRL的演示批次。
注意事项: L2TE策略需要谨慎使用。如果过度强调极低概率的轨迹(其中很多可能就是真正的错误答案),可能会引入噪声,甚至让模型学到错误的模式。一个稳妥的做法是设置一个概率阈值,只对那些概率高于某个极小值(表明不是完全随机垃圾)但显著低于高奖励轨迹的样本进行强调。同时,为了保证训练稳定性,每个IRL批次中仍需混合一定比例的高奖励(正样本)轨迹。
3.2 训练超参数与稳定性
SPS框架引入了新的超参数,合理的设置对成功至关重要:
- IRL学习率:IRL阶段的学习率通常需要设置得比RL阶段小1-2个数量级。例如,RL学习率为5e-7,IRL学习率可设为5e-9到5e-10。这是因为IRL的目标是温和地重塑分布,而非剧烈地改变策略。过大的学习率可能导致训练不稳定或灾难性遗忘。
- IRL训练步数:每次切换到IRL阶段时,训练的步数不宜过多。我们的经验是4-10步足以产生有效的分布调整。过多的IRL步数可能导致模型过度拟合当前批次的演示轨迹,反而损害泛化能力。
- 演示采样大小:如前所述,这是平衡多样性与计算成本的关键。对于大多数实验,我们从每个问题生成的8条轨迹中采样3条作为演示。你可以根据任务难度和计算资源进行调整。任务越复杂,可能需要更多的演示样本来覆盖解空间。
- 循环周期:即每隔多少RL训练步(或多少个训练样本)执行一次IRL阶段。这需要根据数据集大小和模型收敛速度来定。一个常用的策略是每完成一个对训练集的完整遍历(epoch)后,进行一次IRL阶段。
稳定性技巧:
- 梯度裁剪:在IRL阶段同样应用梯度裁剪,防止个别轨迹的梯度造成过大更新。
- 混合演示:确保每个IRL批次中既包含通过L2TE策略采样的低似然轨迹,也包含随机采样或高奖励的轨迹,以避免分布过度偏斜。
- 验证集监控:不仅要监控Pass@1,更要紧密监控Pass@k(k=32, 128等)在验证集上的表现。Pass@k的提升是SPS是否起效的直接标志。如果Pass@k下降,可能需要调低IRL的强度或学习率。
3.3 与基线方法的对比实现
为了将SPS集成到现有训练流程中,我们以流行的GRPO算法为基础进行修改。以下是简化的伪代码逻辑,展示了如何在一个训练循环中嵌入SPS:
与DAPO、GSPO等基线相比,SPS的额外开销主要在于:
- 存储开销:需要缓存RL阶段产生的轨迹。
- 计算开销:额外的IRL前向-后向传播。但由于IRL阶段步数少、学习率低,这部分开销相对于整个RL训练周期来说是较小的。
4. 实验效果分析与解读
我们在多个奥林匹克数学竞赛数据集上验证了SPS的有效性,包括AIME、BRUMO和HMMT。基座模型选用Qwen2.5-Math的1.5B和7B版本。对比的基线包括标准的GRPO以及同期其他先进的RL方法。
4.1 核心性能指标:Pass@k的提升
实验的核心结果是令人振奋的。下表展示了在3K训练数据规模下,Qwen2.5-Math-1.5B模型上的部分结果对比:
| 方法 | 模型大小 | BRUMO (Pass@128) | AIME-25 (Pass@128) | HMMT-FEB (Pass@128) |
|---|---|---|---|---|
| Base Model | 1.5B | 43.33 | 46.67 | 23.33 |
| + GRPO | 1.5B | 53.33 | 43.33 | 20.00 |
| + SPS (Ours) | 1.5B | 63.33 | 46.67 | 26.67 |
可以看到,SPS在Pass@128指标上相比原始GRPO有显著提升。例如在BRUMO上,从53.33提升到了63.33,绝对增益达到10个百分点。这直接证明了SPS有效增强了模型的探索能力:模型现在能从更丰富的推理路径中找到正确答案,而不是仅仅依赖一两种最熟悉的模式。
更重要的是,SPS在提升Pass@k的同时,保持了Pass@1(单样本成功率)不下降甚至略有提升。这说明我们的方法不是以牺牲单样本精度为代价来换取多样性,而是实现了两者的协同提升。概率质量被重新分配到了更多正确的区域,而不是简单地均匀散开。
4.2 不同模型规模与数据规模的敏感性分析
一个有趣的发现是,概率挤压效应和SPS的有效性,与模型规模密切相关。
- 小模型(1.5B):基座模型本身能力有限,输出分布可能已经比较“尖锐”(倾向于少数模式)。标准的GRPO训练会加剧这种挤压,导致Pass@128甚至可能低于基座模型。此时,SPS的干预效果尤为明显,它能显著缓解挤压,将Pass@128拉高。
- 大模型(7B):基座模型本身知识更丰富,分布可能更平滑。GRPO有时能利用其内部知识实现一定的探索增益。但SPS依然能在其基础上带来进一步的提升,尤其是在数据量受限(3K)的情况下,帮助模型更充分地挖掘有限数据中的多样性。
关于数据规模,我们在3K、5K、10K三种数据量上进行了实验。结果表明,SPS在小数据场景下优势更明显。当数据只有3K时,SPS相比GRPO的增益最大。这是因为数据越少,模型越容易过拟合到有限的几种模式上,概率挤压效应越强,SPS的“分布重塑”作用也就越关键。
4.3 采样大小的影响
我们专门对IRL阶段演示轨迹的采样大小进行了消融实验。如下图所示,随着采样大小从1增加到5,模型在多个数据集上的Pass@128性能呈现单调上升趋势。这很好理解:更多的演示样本为IRL提供了更丰富的分布信息,有助于引导概率质量向更广阔的空间扩散。
实践建议: 虽然增大采样大小有益,但也会增加计算和存储成本。根据我们的经验,对于大多数任务,采样大小设为3是一个性价比很高的选择。如果你的任务解空间极其复杂,或者计算资源充足,可以尝试增加到5。
5. 常见问题与实战排坑指南
在实际复现和应用SPS的过程中,你可能会遇到以下几个典型问题:
5.1 训练不稳定或性能下降
- 症状:引入IRL阶段后,训练损失剧烈波动,或者Pass@1大幅下降。
- 可能原因与排查:
- IRL学习率过高:这是最常见的原因。IRL旨在微调分布,学习率必须远低于RL阶段。解决方案:将IRL学习率设置为RL学习率的1/100到1/1000,例如RL用5e-7,IRL用5e-9。
- IRL步数过多:IRL阶段训练太久,导致模型“忘记”了RL阶段学到的奖励信号。解决方案:严格控制IRL步数在个位数(如4步),并监控验证集上Pass@1的变化。
- 演示轨迹质量太差:如果L2TE策略过于激进,采样了太多真正错误的低概率轨迹作为“演示”,会误导模型。解决方案:调整L2TE的采样权重,避免采样概率极低(如低于1e-10)的轨迹;或者在演示批次中确保至少包含一定比例已知的高奖励轨迹。
- 奖励稀疏性问题:在推理任务中,奖励是二元的(对/错),本身就非常稀疏。SPS并不能直接创造奖励信号,它只是重新分配概率。如果数据集中很多问题本身就很难,模型生成的正确轨迹极少,那么IRL可用的“好演示”也少,效果会打折扣。解决方案:考虑先使用SFT或更丰富的奖励信号进行预热,或者结合课程学习,从简单问题开始训练。
5.2 Pass@k没有提升,甚至下降
- 症状:训练后Pass@1正常,但Pass@128没有变化或变差。
- 可能原因与排查:
- 概率挤压不严重:对于某些本身分布就较平坦的模型或简单任务,概率挤压可能不是主要瓶颈。此时SPS的增益有限。诊断方法:检查基座模型和RL训练后模型的输出分布熵,或者直接观察生成文本的多样性。
- IRL阶段未起作用:学习率过低、步数过少,或者演示采样方式有问题,导致IRL的梯度更新可以忽略不计。解决方案:检查IRL阶段的损失值是否在下降;尝试增大IRL学习率(在稳定前提下)或采样大小。
- 评估方式问题:Pass@k的评估需要足够的采样数(如128),且评估时的采样温度需要与训练时或实际应用场景匹配。解决方案:确保评估设置合理,温度不宜过低(建议0.7-1.0),采样数足够大以反映多样性。
5.3 计算与存储开销过大
- 症状:训练速度明显变慢,GPU内存占用增加。
- 可能原因与排查:
- 轨迹存储:存储所有RL轨迹会占用大量内存。解决方案:使用动态缓存或磁盘缓存,或者只存储最近几个批次的轨迹。在采样演示后及时清空。
- IRL计算:IRL损失计算需要遍历演示批次中的所有token,对于长序列任务开销较大。解决方案:可以考虑对演示轨迹进行截断,或者使用更高效的KL散度计算实现。
- 循环频率过高:每训练一个批次就进行IRL,开销巨大。解决方案:合理设置IRL的触发间隔,例如每1000个RL步或每个epoch结束时进行一次。
5.4 超参数调优策略
SPS引入了新的超参,手动调优费时费力。建议采用以下策略:
- 先固定RL部分:使用一个稳定的RL配置(如GRPO)作为基线。
- 调IRL学习率:这是最重要的参数。从一个非常小的值开始(如RL学习率的1e-3),根据验证集Pass@k的变化进行微调。如果Pass@k上升但Pass@1下降,说明学习率可能偏大;如果两者都没变化,可以适当调大。
- 再调采样大小和L2TE强度:在稳定学习率下,调整演示采样大小和L2TE的权重。观察验证集多样性和准确率的平衡。
- 最后确定循环周期:根据训练集大小和模型收敛速度,调整IRL阶段的触发频率。
我个人在多个数学推理和代码生成任务上应用SPS的经验是,它像是一个“分布平滑器”和“探索助推器”。对于那些标准RL训练后陷入局部最优、输出变得单调的任务,SPS往往能带来惊喜。它的价值不仅在于提升了某个指标,更在于为我们提供了一种新的视角:将强化学习视为一个动态的概率分布演化过程,并通过逆强化学习对其进行有目的的干预。这种思路,或许能启发我们解决更多大模型训练中关于探索、多样性与稳定性平衡的难题。