寻找复数网络的并行扫描 —— 从 Mamba 的训练哲学到复数序列模型的高效训练

2026-06-14

在一遍遍阅读 Mamba 论文之后,我第一次清晰地意识到:Mamba 训练成功的核心不在于它的选择机制,也不在于它的门控架构。而在于一个更底层、也更数学的事实——它找到了一种将整条序列一次喂入模型、并行算出所有时间步的输出、然后统一反向传播更新参数的方法

这个能力让它和 Transformer 站在一起,和 RNN 分道扬镳。

而我正在探索的复数神经网络,如果不能在这一点上有所突破,就永远会陷入 RNN 的命运:理论上优美,实践上低效。这篇文章是我对这个问题的当前思考的记录。

核心论点:复数线性递推的数学结合律与实数情况完全一致,因此Mamba的并行关联扫描算法在复数域仍然成立;主要障碍不在数学,而在硬件生态——缺乏针对复数的高效融合CUDA算子。

Scope: 本文讨论从 Mamba 训练方式中提炼的"并行扫描"方法论能否迁移到复数神经网络。覆盖数学论证和工程挑战。不讨论特定复数网络架构(如复数 RNN、复数 LSTM 的具体设计)的细节,也不提供完整实现。

Prerequisites: 假定你了解 Mamba/SSM 的基本架构方向,理解 PyTorch 自动求导的基本机制,并熟悉复数的基础运算性质。

一个问题,两种命运

我花了很长时间才真正理解为什么 RNN 和 Mamba 在数学上都是递推,训练效率却有天壤之别。

RNN 的训练困局可以用一句话概括:每前向传播一步,反向传播就必须退一步。对于长度为 $L$ 的序列,这意味着 $O(L)$ 的串行深度。梯度要沿着时间轴倒着走回序列起点,每一个隐藏状态 $h_t$ 都必须被保存下来供反向传播使用。结果是 $O(L \cdot d^2)$ 的内存和无法并行的计算图。

Mamba 的训练方式完全不同。它走完 $L$ 步前向传播,算出总损失,然后统一做一次反向传播。乍看起来这好像和 RNN 没什么区别——不都是走完再更新吗?但关键在于 Mamba 不走那 $L$ 步串行循环。它用并行关联扫描在一棵深度 $O(\log L)$ 的树里一次性算出所有 $h_1, \dots, h_L$,然后用重计算技术在反向传播时现场重新生成这些状态,避免将它们写入全局显存。

我读到这个机制的时候,心里冒出一个问题:复数版本呢?

复数神经网络的价值在理论上已经被反复论证——复数表示的欠阻尼动态系统有更丰富的表达力、复数递推对旋转和振荡建模更自然、复数梯度流具有与实数不同的特性。但所有这一切,都被训练效率的阴影笼罩着。如果复数 RNN 的训练还是只能 $O(L)$ 串行,它在实际应用中的天花板就是一层天花板。

但等一下。如果 Mamba 的并行扫描依赖的是递推方程的线性性质和由此带来的结合律,而不是递推系数的实数性,那么复数递推方程满足同样的性质。这意味着——至少在数学上——并行扫描对于复数网络是直接可用的。

并行扫描的数学本质

Mamba 能并行的根本原因不在硬件优化,不在 CUDA 编程,而在一个极其朴素的数学性质:结合律

我们来写出 Mamba 的递推方程:

$$h_t = \bar{A}t h{t-1} + \bar{B}_t x_t$$

这个方程里,$\bar{A}_t$ 是一个 $N \times N$ 矩阵,$\bar{B}_t x_t$ 是一个 $N$ 维向量。我们可以把"从 $t-1$ 到 $t$ 的一步转移"定义为一个操作符 $O_t = (\bar{A}_t, \bar{B}t x_t)$。这个操作符作用于 $h{t-1}$ 的方式是:

$$O_t(h_{t-1}) = \bar{A}t h{t-1} + \bar{B}_t x_t$$

关键来了:相邻两步可以合并。将 $h_t = O_t(h_{t-1})$ 代入 $h_{t+1} = O_{t+1}(h_t)$,得到:

$$h_{t+1} = (\bar{A}{t+1} \bar{A}t) h{t-1} + (\bar{A}{t+1} \bar{B}t x_t + \bar{B}{t+1} x_{t+1})$$

注意结果保持了完全相同的数学形式——一个矩阵乘前一个状态,再加一个向量。因此我们可以定义合并操作 $\otimes$:

$$O_{t+1} \otimes O_t = (\bar{A}{t+1} \bar{A}t, ; \bar{A}{t+1} \bar{B}t x_t + \bar{B}{t+1} x{t+1})$$

这个 $\otimes$ 满足结合律:$(O_3 \otimes O_2) \otimes O_1 = O_3 \otimes (O_2 \otimes O_1)$。这意味着我们可以用树状归约一次性合并所有时间步的操作符,然后用 $O(\log L)$ 深度的并行计算出所有 $h_t$。

这个推导自始至终没有用到任何实数特有的性质。

如果 $\bar{A}_t$ 和 $\bar{B}_t$ 是复数矩阵,$h_t$ 是复数向量,$x_t$ 可能是复数输入,那么:

  • 复数矩阵乘法 $\bar{A}_{t+1} \bar{A}_t$ 满足结合律
  • 复数向量加法 $\bar{A}{t+1} \bar{B}t x_t + \bar{B}{t+1} x{t+1}$ 结合律也成立
  • 整个 $\otimes$ 操作符的推导一行都不需要改

我反复确认了几遍,结论是一样的:在数学层面,复数线性递推的并行扫描条件完全满足。

一个重要细节:选择性机制的复数版本

Mamba 的选择性机制让 $\bar{A}_t$ 和 $\bar{B}_t$ 成为输入 $x_t$ 的函数。在实数域,这通过线性层 $s_B(x_t) = \text{Linear}_N(x_t)$ 等来实现。

对于复数网络,我目前认为有两种方案:

方案一:复数值线性层。将标准线性层的权重和偏置替换为复数版本,直接输出复数值的 $\bar{B}_t$ 和 $\bar{C}_t$。这样做最直接,但需要实现复数版本的线性层及其反向传播。

方案二:实数映射保持。参数保持为实数,通过将实数参数重新解释为复数的实部和虚部来构造复数参数,或者用 $2N$ 维的实数向量表示 $N$ 维复数向量。这种方式不需要修改现有框架的复数支持,但理论上丧失了复数表示的一些结构优势。

我倾向于方案一,因为复数线性层的实现本身并不复杂(只是一个复数矩阵乘法),而且它保持了复数网络的全部表达力。但方案二的一个实际优点是:它可以直接复用现有的实数优化器和 CUDA 库。

这里的 $\Delta_t$ 离散化步长处理更微妙。Mamba 中 $\Delta_t$ 是一个正实数(通过 $\text{softplus}$ 确保正性),经过 ZOH 离散化:

$$\bar{A}_t = \exp(\Delta_t A)$$

如果 $A$ 是复数矩阵,$\exp(\Delta_t A)$ 在数学上定义良好(矩阵指数),但在底层实现上需要复数矩阵指数运算,这比实数版本贵不少。这不是一个原则性障碍——复数矩阵指数已经被广泛研究和实现——但它意味着在实际训练中不能直接照搬 Mamba 的 CUDA 核。

硬件落地的挑战

数学上没有问题,落地就完全不是一回事了。我目前能看到三个层面的障碍,按困难程度排序。

第一层:CUDA 内核支持。 Mamba 的并行扫描之所以快,是因为它有一个定制的融合 CUDA 核(fused kernel),将离散化、并行扫描、输出投影合并为一个操作,全部在 SRAM 中完成,避免 $(B, L, D, N)$ 的中间张量写入 HBM。

对于复数版本,我需要的基础操作(复数矩阵乘法、复数向量加法)本身在 CUDA 中都是支持的——cuBLAS 的 cgemmzgemm 从第一代 CUDA 就存在。问题在于 融合。Mamba 的核不是简单的 Scan,而是将多个复数域操作串联成一个融合操作。现有的 cuBLAS 函数做不到这种自定义的融合,我需要自己写 CUDA 核,以及对应的反向传播核和重计算逻辑。

我不确定这样做的开发成本有多高。但有一个折中方案:退回到使用 PyTorch 的 associative_scan 或类似的高层接口。torch.vmaptorch.scan 可能提供一部分支持,但效率一定比不上定制核。对于原型验证(proof of concept),这已经够了。

第二层:离散化的复数处理。 Mamba 使用 ZOH(零阶保持)离散化:

$$\bar{A}_t = \exp(\Delta_t A), \quad \bar{B}_t = (\Delta_t A)^{-1}(\exp(\Delta_t A) - I) \cdot \Delta_t B_t$$

当 $A$ 是复数矩阵时,$\exp(\Delta_t A)$ 的计算代价高于实数版本。一个常用的简化是使用欧拉离散化 $\bar{A}_t = I + \Delta_t A$,这在 $|\Delta_t A| \ll 1$ 时效果不错。对于复数网络,这个近似是否成立取决于特征值 $\lambda(A)$ 的模长和 $\Delta_t$ 的取值范围。

另一种思路是使用双线性变换(Tustin’s method),它在复数域同样有效,且数值稳定性优于欧拉法。

第三层:复数的反向传播。 PyTorch 对复数自动求导的支持已经相当成熟(从 1.9 版本开始),支持 torch.complextorch.complex128 的梯度计算。但我没有测试过它对复杂计算图(如并行扫描的重计算模式)的处理效率。自定义反向传播函数(torch.autograd.Function)可能是必要的。

总结一下硬件层面的判断:prototype 可行,性能优化的门槛不低,但不存在原则性障碍。

值得借鉴的现有工作

在做这个调研之前,我以为复数序列模型的高效训练是一个冷门方向。查了一圈发现情况比想象中好。

最直接相关的是一系列 **复数状态空间模型 Complex-valued SSM ** 的工作。Mamba-3 和最新的一些 SSM 变体已经显式地使用了复数值状态。Ran-Milo & Cohen (NeurIPS 2024) 更从理论上严格证明了复数参数化对 SSM 的表达优势:复数 SSM 在中等维度下就能表达实数 SSM 的所有映射,但反之实数 SSM 需要维度达到 $O(t)$($t$ 为时间长度);即使维度够高,实数 SSM 的参数值也可能需要指数级大才能表达某些振荡映射,使得它们无法在实际中学习。这为复数 SSM 的研究提供了坚实的理论基础。Mamba 论文本身提到:对于连续模态(音频、视频),复数状态比实数表现更好;对于离散模态(文本、DNA),实数反而更优。这说明复数 SSM 在特定场景下是有明确需求的。

这些工作中的 CUDA 实现主要集中在解决复数矩阵指数和复数并行扫描的数值稳定性问题。我没有找到直接对标 Mamba 融合核的复数版本,但有一些零散的 CUDA 部件可以借用。

另一个方向更间接但同样重要:复数 RNN 的训练加速。虽然传统复数 RNN 没有 Mamba 的并行扫描能力,但最近的工作尝试用梯度裁剪、正交初始化、复数值动量等方法缓解复数 RNN 的训练困难。这些工作虽然不改变串行训练的本质,但提供了一些在无法彻底并行化时的实用改进。比如对复数 RNN 使用基于 Wirtinger 微积分的梯度下降,其收敛性质比简单的实数梯度拆分要好。

还有一个值得关注的线索是复数 FFN 和复数卷积的高效实现。在一些复数图像分类的工作中,“复数计算太慢"的问题通过 cuFFT(快速傅里叶变换)的复用得到了部分缓解——复数卷积用 FFT 计算天然比实数卷积更高效,因为 FFT 本身就是复数运算。虽然这和序列模型的并行扫描不是同一个东西,但它说明复数计算在某些场景下反而有硬件优势。也许复数 SSM 在某些配置下也能用 FFT 加速卷积模式,结合并行扫描获得双重效率提升。

Where this breaks

这个论证目前有几个明显的薄弱环节。我在写的时候就意识到这些问题,写下来供未来自己验证或推翻。

  • 复数矩阵指数是昂贵的。 ZOH 离散化中的 $\exp(\Delta_t A)$ 对于复数矩阵的计算代价远高于实数,特别是在状态维度 $N$ 较大且 $\bar{A}_t$ 每步不同(选择机制)的情况下。欧拉近似或双线性变换可能是必要的简化,它们的近似误差需要被量化。

  • 选择性机制和静态结合的兼容性未验证。 我上面的推导假设了给定输入序列后,$\bar{A}_t$ 和 $\bar{B}_t$ 被确定下来,然后并行扫描对这些确定的(可能是复数的)系数进行合并。但选择性机制本身(从 $x_t$ 求 $\bar{A}_t, \bar{B}_t$)是否在复数域同样稳定?复数线性层的训练收敛性质与实数不同,这可能影响整个流水线。

  • 硬件效率的定量估算缺失。 我说"prototype 可行,性能优化有门槛"但给不出数字。一个完整的复数 SSM 层和实数 SSM 层的吞吐量对比,以及其中离散化步骤的开销占比,需要实际跑实验才能知道。

  • SRAM 容量限制可能会更紧。 复数占用双倍于实数的存储(float32 vs complex),在 SRAM 这个稀缺资源上,这是一个紧约束。Mamba 的融合核成功的关键之一是状态维度 $N$ 和特征维度 $D$ 的乘积能放进 SRAM。复数版本用 complex 会让占用翻倍,这意味着更小的 $N$ 或 $D$,或者更小的块大小。

Open questions

写完之后留下的问题比写之前更多。列在这里,希望六个月后的我能回答其中的一些。

Q1: 复数并行扫描的 CUDA 融合核开发成本到底有多大?能否基于现有的开源 SSM 代码库(如 state-spaces 仓库)修改,还是需要从零编写?

Q2: 对于复数状态 $h_t$ 而言,离散化误差和数值稳定性与实数版本相比如何?特别是当 $\bar{A}_t$ 的特征值靠近单位圆时。

Q3: 在复数域,Mamba 的选择机制是否仍然有效?或者说,额外的复数自由度是否会带来选择性的一种"自然涌现”?

Q4: 是否有某些任务(如音频合成、振荡动力学建模、量子系统模拟)中,复数 SSM 的 O(log L) 并行训练能力能够带来现实可测量的收益?

[[Q]] 六个月后回看:复数并行扫描的数学可行性我已经确认了,但你实际动手验证了吗?第一个基线实验跑出了什么结果?

References

  1. Gu, Dao, “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”, 2023.
  2. Gu, Goel, Ré, “Efficiently Modeling Long Sequences with Structured State Spaces (S4)”, ICLR 2022.
  3. Gu et al., “HiPPO: Recurrent Memory with Optimal Polynomial Projections”, NeurIPS 2020.
  4. Blelloch, “Prefix Sums and Their Applications”, 1990.
  5. Gu et al., “Modeling Sequences with Structured State Spaces”, Stanford PhD Dissertation, 2023.
  6. Arjovsky, Shah, Bengio, “Unitary Evolution Recurrent Neural Networks”, ICML 2016.
  7. Wisdom et al., “Full-Capacity Unitary Recurrent Neural Networks”, NeurIPS 2016.
  8. Trabelsi et al., “Deep Complex Networks”, ICLR 2018.
  9. Ran-Milo, Cohen et al., “Provable Benefits of Complex Parameterizations for Structured State Space Models”, NeurIPS 2024.
parallel-scancomplex-valuedmambatraining-efficiencyrecurrenceassociative-scan

Two ways to diffuse text: DFlash block diffusion vs DiffusionGemma