扩散模型之(九)从 Flow Matching 到 Rectified Flow

张开发
2026/5/4 8:34:40 15 分钟阅读
扩散模型之(九)从 Flow Matching 到 Rectified Flow
图 1.校正流模型学习到的采样路径比标准流匹配模型更直,从而实现更快的模拟。Notes两个模型都经过训练,以从相同的目标分布中生成样本。校正流更直的路径允许以更少的步骤进行精确的数值积分,从而降低模拟的计算成本并减少延迟.0. 引言基于流的生成模型已成为一类强大的模型能够生成图像、视频等复杂数据的高质量样本。这类模型借助神经网络通过执行一系列可逆变换将随机噪声转化为复杂数据同时支持新样本生成和似然估计。流模型的成功在一定程度上得益于流匹配技术的提出 —— 该技术让模型训练无需进行计算成本高昂的模拟还能适配任意噪声分布。但在实际大规模部署流模型时存在一个关键障碍生成高质量样本需要反复运行参数量动辄数十亿的大型神经网络。这不仅会产生极高的计算成本还会带来严重的延迟问题部分场景下生成单个样本甚至需要数分钟。因此提出能减少神经网络调用次数、实现模型加速的方法已成为亟待解决的需求。从流模型中采样的高成本核心诱因是所学流的几何特性。高维数据的分析难度较大但幸运的是可以通过低维空间的可视化直观理解流的诸多重要几何特性。事实上训练大规模模型的算法完全可直接用于在简易的数据分布上训练二维流复现诸多具有实际研究价值的现象。从流模型中采样的过程本质是模拟抽象粒子的运动轨迹粒子从随机噪声状态向真实数据状态移动神经网络会被反复调用以确定粒子在每个时间点的速度。当这些轨迹的曲率较大时要实现精准模拟就需要通过昂贵的神经网络执行大量小步计算。如下图 2 所示一个为生成笑脸分布样本训练的流模型其生成的轨迹存在明显曲率。本文的核心研究内容正是该曲率问题、其引发的后果以及对应的缓解方法。本文将探讨流模型生成的轨迹为何会呈现这种几何形态、为何难以高效模拟同时介绍一种名为矫正流的简易方法 —— 它能拉直流模型的轨迹实现更快的采样。图 2.流匹配模型生成弯曲的采样轨迹此图展示随机噪声源分布随时间向笑脸目标分布的转变1.背景知识在深入分析流匹配训练的模型为何会生成弯曲轨迹以及矫正流如何解决该问题之前本文先介绍基于流的生成模型和流匹配的基础理论。1.1 基于流的生成模型生成建模的核心目标是从存在经验观测、但真实分布未知的复杂数据分布如自然图像分布中抽取样本。具体来说给定目标分布的有限样本集我们的任务是训练一个模型使其能生成目标分布的新样本。流模型的核心原理是定义一种连续变换搭建易采样的简单源概率分布如多元高斯分布与复杂数据分布之间的桥梁。我们定义连续的概率分布序列称为概率路径它能实现源分布​到数据分布的平滑插值.图 3.该路径由时间变量索引其中对应源分布对应目标分布。从源分布​中抽取样本并随时间执行变换即可生成符合数据分布的样本图中还可以看到单个样本从源分布移动到目标分布时的轨迹流是从到的时间索引映射用于定义点在时间维度的运动轨迹。将流作用于源分布的样本可将其从源分布传输至目标分布得到。基于流的模型的训练目标是学习一个流使得对于任意时间经流变换后的点其分布均与概率路径中对应时间的分布一致即。若能成功建模该流就能从简单的源分布​中采样再通过变换得到真实数据分布的逼真近似样本。图 4.单个样本的轨迹如何从源分布移动到目标分布一个看似反直觉的设计是基于流的生成模型并非直接建模流而是建模能 “生成” 流的时变速度场。通过该速度场求解常微分方程ODE即可还原出流这一过程称为模拟。从时间的初始点出发可根据速度场通过下述常微分方程追踪点的时间轨迹该含速度场的常微分方程的解就是流本身。目前有多种数值方法可模拟这类常微分方程其核心是通过一系列离散步骤逼近连续轨迹其中最简单的是欧拉法在每个时间步沿速度场方向执行小步线性计算公式为图 5.随时间变化的速度场进行欧拉积分欧拉法近似沿速度场方向采样16个离散时间步1.2 flow matching掌握基于流的生成模型的基础后我们来介绍流匹配技术。本文仅针对与矫正流相关的核心概念做高层概述更详尽的介绍可参考文末资料。流匹配的研发初衷是让速度场的学习摆脱高成本的模拟过程即无需通过欧拉积分或其他方法求解常微分方程。借助流匹配我们只需求解一个简单的回归损失就能完成速度场的训练。流匹配的流程可分为两个核心步骤定义用于实现源分布和目标分布插值的概率路径通过回归训练生成该概率路径的速度场。1.2.1 步骤 1定义概率路径本文重点介绍一种特定的概率路径 ——线性路径它通过源分布和目标分布的简单线性插值定义在本文的所有示例中源分布​均为标准高斯分布目标分布为代表笑脸的复杂二维分布。但在实际应用中流匹配对概率路径和源分布的选择具有更高的灵活性。图 6.源点和 目标点之间的线性插值产生时刻处的插值样本1.2.2 步骤 2速度场回归流匹配的第二步是通过优化简单的回归目标将真实速度场与由神经网络参数化的近似速度场进行匹配回归目标为但这里存在一个问题我们无法直接获取真实速度场。在实际场景中真实速度场主导着两个联合分布的高维分布间的变换难以直接构造。那么该如何优化这一目标幸运的是我们可通过将速度场基于目标分布的单个样本进行条件约束构建一个相关但更简单的目标由此得到条件速度场图 7.特定目标点的条件速度场可以由一束从源分布点指向目标点的射线表示基于该条件向量场我们可构建名为条件流匹配的回归目标将线性概率路径对应的条件速度场代入上式可得到形式极简洁的训练目标值得注意的是条件流匹配和流匹配的目标函数具有相同的梯度即这意味着我们只需优化易处理的条件流匹配目标就能解决流匹配的核心问题。训练过程中只需从源分布和目标分布中抽取样本对通过插值得到​再训练速度场使其能预测线性速度即可。图 8.条件流匹配训练一个学习到的速度场使其匹配条件速度。由于学习到的速度场并不以目标点为条件它必须仅使用当前位置来推断可能的目标位置这导致真实速度和预测速度之间存在一定的误差而流匹配会将这种误差最小化需要重点强调的是我们将基于目标点​的条件速度与仅包含当前点信息的所学速度场进行匹配。若所学向量场也基于​ 做条件约束该问题会变得极为简单 —— 模型只需预测的某个缩放版本即可。因此模型需要仅通过时间的位置信息​判断出样本可能的目标终点.2. 问题分析掌握流模型和流匹配的基础后我们来分析其实际应用中的特殊问题。前文已指出流匹配训练的流模型会生成弯曲的轨迹见图 2将源分布和目标分布叠加后能更直观地看到这种曲率的极端性见图 9。图 9.流匹配生成弯曲的采样轨迹当我们将轨迹与目标分布叠加时这种曲率更明显2.1 曲率加载弯曲轨迹可视化细心的读者可能会产生疑问我们基于线性路径训练速度场使其匹配线性轨迹为何模型最终却学习到了弯曲的轨迹而这一问题又为何需要重视后者的答案更为直观曲率会严重影响采样速度。2.2 曲率是速度的天敌从流模型中抽取新样本时需要利用训练好的速度场执行数值积分。欧拉法等数值积分算法的核心是沿速度场方向执行有限步计算​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​本质是对 “真实轨迹” 做局部线性近似。近似的精准度取决于轨迹的曲率大小以及在不偏离真实轨迹、不降低样本质量的前提下可执行的步长。图10.高曲率(左)和低曲率(右)函数的欧拉方法近似值比较。黑色显示的是真实值,橙色显示的是欧拉近似值。高曲率轨迹需要很多步才能准确模拟核心结论是曲率是速度的天敌。曲率过大的轨迹难以通过少量步骤实现精准模拟这意味着需要反复调用代表向量场的大型神经网络来逼近轨迹最终导致高延迟和高计算成本。但模型最初为何会学习到弯曲的轨迹答案与源随机变量和目标随机变量的联合分布方式有关这一概念被称为耦合。2.3 什么是耦合通过流匹配训练速度场时需要从源分布和目标分布中抽取样本对。在流匹配的介绍中我们略过了样本对的具体抽取方式而这一设计选择至关重要 —— 它被称为耦合会对所学流的几何形态产生显著影响也是轨迹产生曲率的核心原因。耦合是源随机变量和目标随机变量间的联合分布决定了训练所用样本对的分布规律。耦合的核心要求是其边缘分布需分别为源分布和目标分布.本文研究的是最简单的耦合形式 ——独立耦合见图 11独立地从源分布和目标分布中抽取和联合分布满足这种方式能让训练中的样本对构造变得极为简便也是源分布和目标分布样本间无已知关联结构时的自然选择。如前文所述独立耦合是轨迹产生曲率的核心原因。从下图 11 中可看到连接独立抽取的源点和目标点的线段存在大量交叉这些交叉会在路径中产生分支而所学速度场无法解析这些分支最终导致轨迹弯曲。图 11.将随机源点连接到随机目标点的独立耦合可视化效果独立耦合的一种替代方案是最优传输耦合见下图 12它以最小化源分布到目标分布的整体传输成本为原则连接源点和目标点生成的路径交叉更少能得到更平直的轨迹。但最优传输耦合的计算难度更高尤其在高维空间中因此在实际应用中并不常用。图12.通过Sinkhorn算法计算的最优传输耦合OTC的可视化.与独立耦合不同最优传输耦合可最小化传输成本,从而使路径的纠缠和交交叉更少2.4 不合时宜的路径交叉所学流模型无法精准建模独立耦合产生的交叉路径这一局限性最终表现为轨迹的弯曲。具体来说若样本对和形成的两条路径在时间于点处相交或近似相交则所学速度场需要在同一位置和同一时间匹配两个不同的速度和——而这是无法实现的因为所学速度场仅为当前位置和时间的函数。图13.两组样本对和在点发生交叉速度场无法准确预测两个相互冲突的速度--它所能做的最好的就是条件期望(绿色箭头)在交叉点处所学速度场无法精准预测两个期望速度最终会输出二者的平均值。这一规律具有普遍性当大量路径在小邻域内相交时所学速度场会通过计算经过该点的速度的条件期望对相互冲突的速度做平均处理。Notes由于这些交叉点的平均速度会随空间位置变化最终就形成了弯曲的轨迹。因此即便我们训练流模型去匹配线性速度最终得到的仍是弯曲的轨迹。3. Rectified flow我们已分析了曲率大的轨迹难以通过少量步骤模拟的原因也理解了使用独立耦合时流模型学习到弯曲轨迹的机理。现在可以提出核心问题如何让模型学习到更平直的轨迹矫正流正是该问题的解决方案结合前文的背景知识来看这一方案其实简洁且直观。3.1 算法原理矫正流的核心是用模型自身生成的耦合替代传统流匹配训练中简单的独立耦合从而拉直流的轨迹。具体流程为首先通过独立耦合利用流匹配训练一个基础模型接着从源分布中抽取将其输入所学流模型得到由此生成新的样本对再以该新耦合为基础重新训练一个新的流模型。重复上述过程多次就能逐步拉直流模型的轨迹。完整的算法流程如下Reflow过程。3.2 reflow过程输入源分布、目标分布、迭代次数输出矫正速度场从独立耦合中抽取样本对令执行循环基于耦合的样本对训练速度场从源分布抽取将其输入速度场生成新样本对更新耦合循环结束输出矫正速度场​。算法 1 说明Reflow 过程通过基于前序模型生成的耦合反复训练实现轨迹的逐步拉直。3.3 算法有效性原理从训练好的流模型中采样的过程本质是求解如下形式的常微分方程​​​​​​​ ​​​​​​​ ​​​​​​​该方程定义的是确定性流在温和的正则性条件下其轨迹具有唯一性—— 这一特性是理解矫正流工作原理的关键。确定性流的轨迹唯一性意味着两条不同的轨迹无法在同一时空点处相交。若相交则两条轨迹在所有时间点都必须重合这与 “轨迹不同” 的前提矛盾。因此确定性流禁止轨迹的交叉、分支和融合而通过积分该流得到的耦合也继承了这种确定性。将源分布的样本输入所学流模型生成新的样本对时得到的耦合必然满足轨迹无交叉。基于该耦合重新训练模型能从根本上消除导致轨迹弯曲的交叉点处的速度冲突问题最终实现轨迹的拉直。图14.流模型所引发的耦合产生的路径纠缠较少常规的独立耦合产生的路径频繁交叉。如果通过学习到的流模型让源点流动来生成诱导耦合将得到一个交叉点显著减少的耦合对4. 模型对比我们可对比标准流匹配模型与矫正流模型学习到的轨迹见下图 15矫正流模型的轨迹平直度显著提升仅需更少的步骤就能实现精准模拟。图15.流模型所引发的耦合产生的路径纠缠较少从一种一般的独立耦合其中路径频繁交叉。如果通过学习到的流模型让源点流动来生成诱导耦合将得到一个交叉点显著减少的耦合对轨迹曲率的差异直接决定了采样所需的步骤数。通过对比欧拉法在不同积分步数下对 “真实轨迹”通过大量步数模拟得到的逼近效果见下图16能清晰看到这一影响矫正流模型即便仅用极少的步数也能生成精准的近似轨迹而流匹配模型的弯曲轨迹会导致其模拟结果与真实轨迹出现显著偏差。图 16.与标准流匹配模型相比矫正流模型学习到的采样路径更直,从而能够实现更快的模拟。两个模型都经过训练以从相同的目标分布中生成样本。矫正流更直的路径允许以更少的步骤进行精确的数值积分,从而降低模拟的计算成本并减少延迟.最后对比二者学习到的向量场见下图 17图中可见矫正流模型的向量场随时间的变化更稳定这也意味着其轨迹的曲率更低。图17.流匹配模型的曲率可以通过其快速变化的向量场看出来。相比之下矫正流模型会学习到随时间更稳定一致的向量场这表明其轨迹更平直二者均属于基于流的生成模型体系核心目标都是学习从简单源分布到复杂目标分布的连续变换以生成高质量样本且均围绕速度场建模和常微分方程ODE模拟采样展开矫正流更是以流匹配为基础的优化改进方法二者的底层数学框架如概率路径、速度场回归高度关联。以下是具体的相同点与核心不同点4.1 主要相同点核心建模对象一致均不直接建模流而是学习时变速度场通过求解 ODE还原流的轨迹最终实现从源分布如标准高斯到目标分布的样本传输。基础训练逻辑相通都以概率路径插值为基础如线性路径并通过回归损失训练速度场让模型学习匹配真实的速度方向核心均为拟合样本从源到目标的运动规律。采样核心方式相同采样时均通过数值积分算法如欧拉法对训练好的速度场做 ODE 模拟通过离散步骤逼近连续轨迹将源分布的随机噪声转化为目标分布的样本。核心目标一致最终目的都是生成符合复杂目标分布的高质量样本同时兼顾模型的可训练性避免传统流模型的高计算成本问题。4.2 核心不同点对比维度Flow Matching流匹配Rectified Flow矫正流模型定位基础训练方法是矫正流的前置基础流匹配的优化改进方法基于流匹配迭代升级训练所用耦合核心使用独立耦合​​​​​​​源 / 目标样本对独立抽取易实现但会导致路径交叉以模型自诱导耦合替代独立耦合通过前序模型生成无交叉的样本对迭代更新耦合方式训练流程单阶段训练基于固定的独立耦合完成一次速度场训练流程简单多阶段迭代训练Reflow 过程先通过独立耦合训练基础模型再用模型生成的新耦合反复重训直至轨迹拉直学习到的轨迹特性轨迹存在严重曲率因独立耦合导致路径交叉速度场在交叉点被迫平均冲突速度轨迹高度平直自诱导耦合保证轨迹无交叉消除了速度冲突从根本上解决曲率问题采样效率采样效率低弯曲轨迹需大量小步数值积分反复调用大型神经网络计算成本高、延迟大采样效率极高平直轨迹仅需极少步数即可精准模拟大幅减少神经网络调用次数降低成本和延迟速度场特性学习到的速度场随时间变化不稳定同一时空点易存在冲突的速度信号学习到的速度场随时间变化更稳定速度方向一致无冲突信号计算复杂度训练阶段训练复杂度低仅需单阶段回归训练无需多次重训训练复杂度略高需K次迭代训练K为超参数但单次重训的成本与流匹配一致核心解决的问题解决了传统流模型训练需高成本 ODE 模拟的问题让速度场学习通过简单回归实现解决了流匹配因独立耦合导致的轨迹曲率高、采样效率低的核心缺陷是对其工程化落地的优化4.3 核心关联与本质差异总结关联矫正流是流匹配的进阶版本完全继承了流匹配的速度场建模、概率路径、条件流匹配损失等核心技术其第一阶段训练就是标准的流匹配训练。本质差异二者的核心区别在于对 “耦合” 的选择和使用方式—— 流匹配为了训练简便采用独立耦合牺牲了采样效率矫正流则通过迭代生成模型自诱导的无交叉耦合从根本上消除了轨迹曲率的诱因实现了 “训练稍增复杂度采样效率大幅提升” 的优化解决了流匹配大规模部署的核心障碍高延迟、高计算成本。简单来说流匹配是 “把速度场学会”而矫正流是 “把速度场学优让采样更快”。参考资料https://alechelbling.com/blog/rectified-flow/术语解释矫正流Rectified Flows一种优化的流生成模型通过迭代更新耦合方式拉直轨迹实现采样加速流匹配Flow Matching, FM生成模型的训练技术通过回归损失学习速度场无需高成本的常微分方程模拟条件流匹配Conditional Flow Matching, CFM流匹配的改进形式通过对目标样本做条件约束构建易优化的训练目标耦合Coupling源分布与目标分布样本对的联合分布方式是影响流轨迹几何形态的核心因素概率路径Probability Path实现源分布到目标分布平滑插值的连续概率分布序列。

更多文章