【论文解读】《Training Large Language Models to Reason in a Continuous Latent Space》

news/2025/2/24 13:42:34

论文链接

1. 背景与动机

  • 语言空间与推理的矛盾
    目前大多数大语言模型(LLMs)在解决复杂问题时采用链式思维(Chain-of-Thought, CoT)方法,即利用自然语言逐步推导出答案。然而,论文指出:

    • 自然语言主要为文本连贯性服务,很多生成的词令(tokens)在推理上并非关键。
    • 一些关键推理步骤需要复杂规划,但用语言表达往往使模型过早做出确定性选择,丧失灵活性。
    • 从神经影像学的研究来看,人脑在进行推理任务时,其语言网络并不活跃,说明语言其实是为交流而优化,而非专门用于推理。
      因此,作者提出:为何不让模型在一个“无限制的隐空间”中进行推理,再在必要时将结果转换为语言?
  • 连续隐向量作为推理状态
    论文引入了一个新的范式——Coconut(Chain of Continuous Thought)。其核心思想在于:

    • 利用模型最后一层的隐藏状态(hidden state)作为当前的“连续思维”(continuous thought),代表模型的内部推理状态。
    • 不通过语言头将其解码为文字,而是直接将这一隐向量反馈给模型,作为下一个时间步的输入嵌入。
      这样可以让模型在没有语言约束的情况下自由推理,同时仍然可以端到端地利用梯度下降进行训练citeturn0file0。

2. 方法与架构设计

在这里插入图片描述

2.1 基本模型与模式切换
  • 标准语言模型的表示
    给定一个输入序列 $ x = (x_1, x_2, \dots, x_T) $,模型通过嵌入
    E t = [ e ( x 1 ) , e ( x 2 ) , … , e ( x t ) ] E_t = [e(x_1), e(x_2), \dots, e(x_t)] Et=[e(x1),e(x2),,e(xt)]
    得到隐藏状态 $ H_t $,最终通过 softmax 预测下一个 token(即
    M ( x t + 1 ∣ x ≤ t ) = s o f t m a x ( W h t ) M(x_{t+1}|x_{\le t}) = \mathrm{softmax}(W h_t) M(xt+1xt)=softmax(Wht)
    )。

  • 从语言模式到隐模式的转换
    Coconut 的核心改动在于:

    • 语言模式(Language Mode):与传统方法一致,模型生成词令序列。
    • 隐模式(Latent Mode):在特定区间(由特殊 token 标记,如 <bot><eot>)内,模型不再使用词嵌入,而是直接使用前一步的隐藏状态作为下一个输入。这一过程即“连续思维”:
      • 假设位置 $ i $ 处为 <bot>,位置 $ j $ 为 <eot>,那么在 $ i < t < j $ 的区域,输入为 $ h_{t-1} $ 而非 $ e(x_{t-1}) $。
    • 当隐模式结束后,模型恢复使用常规词嵌入继续生成。
2.2 多阶段训练策略
  • 训练目标与梯度传递
    由于连续思维完全可微,论文采用标准的负对数似然损失(negative log-likelihood)进行训练,不过会对问题描述和隐思维部分进行掩码处理,确保损失只计算在剩余的语言输出上。

  • 逐步替换语言推理
    受到 Deng 等(2024)的启发,作者设计了一个多阶段训练课程:

    • 初始阶段:使用完整的语言推理链(CoT)的数据训练模型。
    • 后续阶段:逐步将语言推理步骤替换为连续隐思维。这里引入超参数 ( c ),表示每一步语言推理被替换为 ( c ) 个连续思维。
    • 如果原始推理链不足 ( k ) 步,则将全部推理步骤替换。每换一次阶段,都重置优化器状态以便更好地适应新的训练目标。
    • 此外,在连续思维的开始和结束处分别插入 与 标记。
2.3 推理过程

在这里插入图片描述

  • 推理时的模式切换
    在推理阶段,与训练类似:

    • 模型在处理完问题后(即问题部分用语言模式处理完毕)插入 token,随后进入隐模式,直接使用隐藏状态进行推理。
    • 对于何时结束隐模式,论文提出两种策略:
      1. 训练一个二分类器让模型自主决定何时结束隐推理。
      2. 固定隐推理的步数,即用固定长度的连续思维。
    • 实验中,为了简单起见,两种方法表现相近,因此作者采用了固定步长的方案。
  • 多次前向传播计算
    在训练中,如果当前阶段有 ( n ) 个隐思维,则需要进行 ( n+1 ) 次前向传播来依次生成每个隐向量,最后一次前向传播用于计算剩余文本的损失。这种多次前向传播虽然可以借助 KV 缓存加速,但由于依赖前一步计算,仍然存在并行性挑战。


3. 实验设置与比较

3.1 数据集与任务

论文在三个数据集上评估模型性能,分别侧重不同的推理能力:

  • 数学推理(GSM8k)

    • 包含小学水平的数学题,题目多样且贴近实际。
    • 训练时使用 Deng 等(2023)生成的合成数据集。
  • 逻辑推理(ProntoQA)

    • 题目利用虚构概念构造,要求模型根据给定条件判断某个陈述是否正确。
    • 由于题目结构较简单,要求模型做出直观的下步预测。
  • 规划密集型逻辑推理(ProsQA)

    • 为解决 ProntoQA 中分支较少的问题,作者设计了一个新的数据集 ProsQA,其推理条件构造为随机生成的有向无环图(DAG),要求模型在较为复杂的图结构中搜索正确推理链。
3.2 基线与变种

论文与多种基线方法进行比较,包括:

  • CoT(Chain-of-Thought)
    完整生成推理链后再给出答案。

  • No-CoT
    模型直接生成答案,不包含任何中间推理步骤。

  • iCoT
    采用内部化链式推理的策略,在训练过程中逐步移除推理链中的前几步(Deng et al., 2024)。

  • Pause Token
    在问题与答案之间插入特殊 tokens,赋予模型额外计算能力(Goyal et al., 2023)。

另外,还探讨了Coconut的几种变体:

  • w/o curriculum:直接使用仅包含问题和答案的最后阶段数据训练,而不使用多阶段训练。
  • w/o thought:虽使用多阶段训练但不使用任何连续隐思维,相当于仅移除语言推理步骤。
  • pause as thought:用 tokens 代替连续隐思维,采用相同的多阶段训练策略。
3.3 实验结果
  • 总体表现(参见 Table 1)

    • 在 GSM8k 上,标准 CoT 的准确率为 42.9%(生成 25 个 token),而 Coconut 达到 34.1%(生成仅 8.2 个 token),说明在生成效率上有明显优势。
    • 在逻辑推理任务 ProntoQA 中,Coconut 与 iCoT 均达到了 99.8% 的高准确率,但生成 token 数量显著减少(9.0 vs. 3.0~92.5 token,不同基线有所不同)。
    • 在规划要求更高的 ProsQA 上,Coconut 的准确率达到 97.0%,明显优于传统 CoT(77.5%)且生成 token 数也较少(14.2)。
  • 超参数 ( c ) 的影响
    实验表明,在 GSM8k 上,当每步隐思维的数量 ( c ) 从 0 增加到 2 时,模型性能呈稳步提升(见 Figure 3),说明“链式”连续思维能在隐空间中积累更多有效信息。

  • 推理效率与时钟时间
    除了准确率外,论文还比较了不同方法在推理过程中新生成 token 数量和平均推理时间,Coconut 在保持高准确率的同时大幅减少了生成 token 数,从而加快了推理速度(参见附录 B)。


4. 隐空间推理的深入分析

论文不仅在实验上展示了 Coconut 的优势,还对隐推理过程进行了详细的剖析与解释:

4.1 推理过程的隐搜索树解释
  • 多候选路径编码
    由于连续隐向量可以同时编码多个可能的下步推理,作者将其解释为一种隐式的广度优先搜索(BFS):

    • 在隐模式中,模型并没有立即确定唯一的下步选择,而是保留多个可能性,并在后续逐步淘汰不正确的路径。
    • 这种机制使得模型在遇到复杂规划任务时更为稳健,能够在面对多个分支时延迟决策。
  • 隐式价值函数
    当模型从隐空间切换回语言模式时,可以观察到预测分布中各候选项的概率。作者将这种概率分布视为一种隐式的“价值函数”,用于评估每个候选路径(例如在图结构中的“子节点”)通向正确答案的潜力(参见 Figure 7 和 Figure 8)。

4.2 隐推理与语言推理的对比
  • 延迟决策与规划能力
    在传统 CoT 中,每一步生成都会“锁定”一个具体的文本描述,容易导致过早决策;而在隐空间中,模型可以延迟决策,利用后续信息逐步修正路径,从而在规划密集型任务(如 ProsQA)中表现更优。

  • 节点高度与评价准确度
    论文还提出了一个分析方法:

    • 定义搜索树中节点的“高度”为该节点到叶子节点的最短距离。
    • 分析表明,对于高度较低的节点(即后续探索空间有限),模型能够更准确地分辨正确与错误的选项。而对于高度较高的节点,由于潜在分支较多,模型的区分能力会下降(见 Figure 9)。
4.3 模型平行探索的变化
  • 从宽广探索到聚焦收敛
    分析显示,在第一隐思维阶段,模型在候选路径上具有较高的多样性(即并行探索),而在第二阶段后,多数候选分布迅速收敛到少数高概率路径。这种变化表明模型在初期保持探索性,随后逐步聚焦到最有希望的解答路径。

5. 结论与未来方向

  • 主要贡献

    • 提出了 Coconut 这一全新的在连续隐空间中进行推理的方法,突破了传统 CoT 依赖自然语言表达的局限。
    • 实验结果表明,尤其在规划密集型任务中,Coconut 能够提高推理准确率,同时大幅减少生成的 token 数,从而提升推理效率。
    • 通过对隐搜索树的分析,展示了模型如何在隐空间中延迟决策、并行探索并最终收敛到正确解答。
  • 未来工作

    • 如何进一步优化多阶段训练过程、提高并行计算效率;
    • 探索预训练阶段就引入连续隐思维,从而使模型能在更广泛的推理任务上泛化;
    • 结合语言与隐空间推理的优势,开发更加高效且鲁棒的推理系统。

总结

这篇论文系统地阐述了一种新的大语言模型推理方法——Coconut,其核心在于让模型在一个连续的、无限制的隐空间中进行推理,通过多阶段训练逐步将传统语言推理替换为连续隐向量。实验结果和细致的分析表明,这种方法在逻辑、数学和规划密集型任务上均能展现出较传统方法更高的效率和准确率,同时为理解大模型内部推理机制提供了新的视角。


http://www.niftyadmin.cn/n/5864403.html

相关文章

力扣——搜索二维矩阵

题目链接&#xff1a; 链接 题目描述&#xff1a; 思路&#xff1a; 可以发现&#xff0c;如果把每一行拼起来&#xff0c;就是一个递增的数组&#xff0c;可以在这个递增的数组上使用二分法找到target如果拼起来的某个元素索引是i&#xff0c;那它在二维矩阵里面的索引是【…

【C++】list 链表的使用+模拟实现

目录 文章目录 前言 一、list的简介 二、list的使用方法 三、list的模拟实现 1.基本框架&#xff1a; 2.迭代器实现 3.常用接口实现 四、完整代码 总结 前言 本文主要介绍C【STL】容器中的 list&#xff0c;包括接口说明和模拟实现。其中讲解了迭代器功能上的分类&am…

哈希表入门到精通:从原理到 Python 实现全解析

系列文章目录 01-从零开始掌握Python数据结构&#xff1a;提升代码效率的必备技能&#xff01; 02-算法复杂度全解析&#xff1a;时间与空间复杂度优化秘籍 03-线性数据结构解密&#xff1a;数组的定义、操作与实际应用 04-深入浅出链表&#xff1a;Python实现与应用全面解析 …

如何在望获实时 Linux 京博航友善 NanoPC-T6 上部署 Docker

在数字化浪潮席卷各行业的当下&#xff0c;开发者们对于高效、稳定开发环境的追求从未停歇。望获实时 Linux 与京博航友善 NanoPC-T6 开发板的组合&#xff0c;为开发者们提供了一个强大的平台。本文将详细介绍如何在这套平台上部署 Docker 环境&#xff0c;助力开发者们快速构…

登录-07.JWT令牌-登录后下发令牌

一.思路 我们首先完成令牌生成。 在响应数据这一块 该响应数据是一个标准的Result结构&#xff0c;其中"data"的值就是一个JWT令牌。因此我们只需要将生成的JWT令牌封装在Result当中然后返回给前端即可。 备注是给前端看的&#xff0c;不用管。以后我们做校验时&…

便携式动平衡仪Qt应用层详细设计方案(基于Qt Widgets)

便携式动平衡仪Qt应用层详细设计方案&#xff08;基于Qt Widgets&#xff09; 版本&#xff1a;1.0 日期&#xff1a;2023年10月 一、系统概述 1.1 功能需求 开机流程&#xff1a;长按电源键启动&#xff0c;全屏显示商标动画&#xff08;快闪3~4次&#xff09;。主界面&…

NavVis VLX三维扫描:高层建筑数字化的革新力量【沪敖3D】

在三维激光扫描领域&#xff0c;楼梯结构因其复杂的空间形态和连续垂直移动的实际需求&#xff0c;一直是技术难点之一。利用NavVis VLX穿戴式移动扫描系统成功完成一栋34层建筑的高效扫描&#xff0c;其中楼梯部分的数据一遍成形且无任何分层或形变。本文将深入分析该项目的技…

python读取sqlite温度数据,并画出折线图

需求&#xff1a; 在Windows下请用python画出折线图&#xff0c;x轴是时间&#xff0c;y轴是温度temperature 和体感温度feels_like_temperature 。可以选择县市近1小时&#xff0c;近1天&#xff0c;近1个月的。sqlite文件weather_data.db当前目录下&#xff0c;建表结构如下…