这项由默克公司(Merck & Co., Inc.)剑桥研究团队完成的研究,以预印本形式发布于2026年4月29日,论文编号为arXiv:2604.27124,有兴趣深入了解的读者可以通过该编号在arXiv平台查询完整论文。
(资料图片仅供参考)
研究的核心问题听起来非常专业,但背后的逻辑其实很直观:当一个AI模型在"阅读"一个细胞的基因信息时,它是怎么决定哪些基因之间的关系更重要的?这个"决定方式",也就是所谓的"注意力机制",是整个模型能否准确理解细胞身份的关键。研究团队发现,长期以来被默认使用的那种注意力方式——就像一个人在考场上只能把注意力分给有限几道题——其实并不适合生物数据。他们改用了另一种方式——允许同时全力关注很多事情——结果发现,不仅模型理解细胞类型的能力提升了25%,训练速度也加快了近10%,而且还避免了训练过程中的灾难性崩溃。
这项研究的意义不只在于一个技术指标的提升。单细胞RNA测序数据,简单说就是通过检测每个细胞里哪些基因在工作来理解细胞的"身份",是现代生物医学研究的核心工具之一。基于这类数据训练的AI模型,可以用来自动识别细胞类型、预测药物对细胞的影响、研究疾病发展机制,甚至辅助个性化医疗。研究团队还开源了一个专门为生物数据设计的高效计算内核,让这种更好的注意力方式真正能在实际研究中跑起来。
一、当AI学会"阅读"细胞:背景与挑战
要理解这项研究解决了什么问题,先要知道这类AI模型是怎么工作的。
细胞可以被理解为一个极其复杂的工厂,而基因就是这个工厂里不同车间的开关。一个细胞在某个时刻"表达"了哪些基因、表达量是多少,就像是这个工厂在某个时间节点的运转状态快照。科学家通过单细胞RNA测序技术,能够捕捉到数十万甚至数百万个细胞各自的这种快照。
这类AI基础模型的工作方式,是把一个细胞的基因表达情况当作一段"文字"来读——每个基因是一个"词",整个细胞的基因组合是一个"句子"。模型通过学习海量细胞数据,理解不同基因之间的协同关系,从而学会区分不同类型的细胞。这类模型与ChatGPT这类语言模型在底层结构上高度相似,都依赖一种叫做"自注意力机制"的核心技术。
问题在于,生物数据和语言数据有一个根本性的不同:生物细胞的"句子长度"差异极大。不同细胞表达的基因数量从几百到一万七千多个不等,就像有的人说话只说三个字,有的人一口气说一万七千字。研究团队统计了他们所使用的CellxGene数据集(一个包含1.3亿多个细胞的庞大数据库),发现如果把上下文窗口设置为2048个基因,那么有43%的细胞信息会被截断,就像读一篇文章只读前一半就把后面全部扔掉。而如果想要覆盖96.6%的细胞,窗口至少需要达到8192个基因。
此外,生物序列还有一个特性:每个细胞必须单独处理,不能像文本那样把多篇短文拼凑成一篇长文来凑够固定长度。这就导致批处理时大量的"空白填充",计算资源被严重浪费。
这两个挑战——长序列和大量空白填充——正是这项研究着力解决的核心问题。
二、注意力机制的"竞争性"困境:为什么传统方式在生物数据上表现不佳
现在来理解"注意力机制"到底是怎么回事,以及为什么传统方式在生物数据上会出问题。
传统的注意力机制叫做"softmax注意力"。它的工作原理可以用一个课堂场景来理解:假设一个老师在讲课,课堂上有100个学生(对应100个基因),老师一次只能把100%的注意力分配给这100个学生。如果老师把60%的注意力给了第一排的同学,那剩下的40%就必须分给其他99个人。这是一个"零和游戏"——关注了这个,就必然减少对那个的关注。数学上说,softmax会把所有的注意力分数归一化到一个概率分布上,所有分数加起来必须等于1。
这在语言处理中通常没什么问题,但在基因调控的世界里,情况恰恰相反。一个基因往往同时受到多个转录因子(可以理解为基因调控网络中的"开关管理员")的协同调控,这些调控关系是并行的、独立的,而不是互相竞争的。用竞争性的注意力来模拟这种并行协作关系,就像用一把独木桥来承载一条多车道高速公路的流量,天然不匹配。
更麻烦的是,当序列变长、基因数量达到几千甚至上万时,softmax注意力会出现一种叫做"注意力熵坍缩"的现象——通俗说就是,模型的注意力会越来越极端地集中在少数几个基因上,其他基因几乎被完全忽视。这种极端集中会导致梯度(训练过程中用来调整模型参数的信号)急剧膨胀,最终引发训练崩溃。在之前的单细胞基础模型研究中,这种训练失败的情况相当普遍,造成了大量的计算资源浪费。
研究团队提出的替代方案是"sigmoid注意力"。它的工作原理完全不同:每个基因之间的注意力分数是独立计算的,不需要与其他基因竞争。还是用课堂比喻来说,这就好比老师可以同时全力关注每一个学生,给每个学生的关注度都可以独立达到100%,互不干扰。数学上,sigmoid函数把每对基因之间的关联分数独立映射到0和1之间,而不进行跨基因的归一化。
这种独立性带来了两个直接好处:一是能更真实地模拟基因的并行协同调控关系;二是梯度的传播更加稳定,因为sigmoid函数的导数(可以理解为"信号放大倍数")永远不会超过0.25,而softmax的信号放大倍数会随着注意力分数的增大呈指数级膨胀。研究团队对此做了严格的数学推导,证明sigmoid注意力的雅可比矩阵(一种衡量函数对输入变化敏感程度的数学工具)是对角结构的,而softmax的雅可比矩阵是密集耦合的——前者相当于每条电路独立运行,互不干扰;后者相当于所有电路共用一根总线,一处过载全盘崩溃。
三、让理论落地:专为生物数据设计的高效计算内核
理论上sigmoid注意力更好,但如果实际跑起来慢得像乌龟,那再好的理论也没用。这就引出了研究的第二个核心贡献:一个叫做TritonSigmoid的高效GPU计算内核。
要理解为什么需要专门开发这个内核,先要理解现有工具的局限性。目前最流行的高效注意力计算工具叫做FlashAttention,它专门为softmax注意力优化设计,无法直接用于sigmoid注意力。虽然之前有研究者开发了一个叫FlashSigmoid的工具,但它有两个致命缺陷:第一,不支持序列填充,也就是说同一批次里所有序列必须等长,这在生物数据中几乎不可能实现;第二,不兼容最新的NVIDIA GPU架构。用普通PyTorch(一个深度学习框架)直接实现sigmoid注意力虽然支持填充,但速度极慢,在H100 GPU上只能跑到41 TFLOPS的前向计算速度,相当于只用到了硬件理论算力的一小部分。
研究团队用Triton(一种GPU编程语言,可以理解为专门给GPU写高效程序的工具)从头设计了TritonSigmoid,核心创新点有几处。
第一个创新是"稀疏块计算":对于完全由填充组成的空白块,内核会直接跳过,完全不做任何计算。这就好比考试阅卷时,看到整张答题纸是空白的,直接给0分跳过,而不是逐字逐句去检查有没有答案。这使得在有25%填充的情况下,计算效率损失仅为9.3%。
第二个创新是"融合运算":传统方式需要把注意力矩阵先写到内存里,再读出来进行后续计算;TritonSigmoid将整个注意力计算流程融合成一个连续操作,避免了反复读写内存的开销。对于sigmoid函数本身,内核使用了一个基于tanh(双曲正切)的硬件加速近似公式,利用现代GPU内置的快速tanh运算单元。
第三个创新是"反向传播分解":在训练神经网络时,除了前向计算(模型做预测),还需要反向传播(计算如何调整模型参数)。研究团队将反向传播拆分成两个独立的内核:一个专门计算查询矩阵(Q)的梯度,另一个专门计算键矩阵(K)和值矩阵(V)的梯度。这种分解消除了并行计算中的"原子操作冲突",让不同计算单元能更高效地协同工作。此外,反向传播时需要用到前向传播中间结果,但这些中间结果不需要保存在内存里——内核会在反向传播时重新计算一遍,以此换取内存效率。
最终的性能数据非常突出。在NVIDIA H100 GPU上,TritonSigmoid在16384个token长度、128维度的配置下,前向传播达到了515.6 TFLOPS,后向传播达到373.5 TFLOPS。相比之下,FlashSigmoid是439.7/341.6 TFLOPS,FlashAttention-2是360.6/312.5 TFLOPS,普通PyTorch实现只有92.8/204.8 TFLOPS。换算成相对速度,TritonSigmoid比FlashAttention-2快43%,比FlashSigmoid快17%,比普通PyTorch快5.6倍。
在有25%填充的真实生物数据场景下,TritonSigmoid的优势更加明显,前向传播相比普通PyTorch提速了14.58倍,相比FlashAttention-2也快了29%。由于TritonSigmoid用Triton实现,它天然支持JIT(即时编译)到不同GPU架构,未来面对新一代GPU也不需要重新手写代码,具有很好的前向兼容性。
四、实验验证:sigmoid注意力训练出的模型真的更好吗
计算效率解决了,接下来的核心问题是:用sigmoid注意力训练出来的模型,在理解生物细胞方面真的更好吗?
研究团队训练了四个1.6亿参数规模的模型,分别是:2K上下文窗口+softmax注意力、2K上下文窗口+sigmoid注意力、4K上下文窗口+softmax注意力、4K上下文窗口+sigmoid注意力。所有模型都在相同的CellxGene数据集上训练,使用完全相同的超参数(除了注意力机制本身),都训练到完全收敛,训练过程中都使用了梯度裁剪(一种防止训练不稳定的标准技术)。
评估则在六个完全没有参与训练的独立数据集上进行,覆盖了大脑、血液、结肠、肺和心脏等不同组织,涵盖胚胎期到老年期的不同发育阶段,包含健康、发育和疾病等不同生物学背景。这种多样化的评估设计是为了考察模型的泛化能力,而不仅仅是在训练数据上表现好。
评估指标包含四个维度。第一是验证损失,用来衡量模型预测被遮盖基因的准确程度,类似于"填空题得分",分数越低越好。第二是scIB生物学保守性指标,这是单细胞生物学领域的标准评估框架,包含细胞类型轮廓系数(同类细胞在表征空间中是否聚集在一起)、Leiden聚类的NMI和ARI(无监督聚类结果与真实细胞类型标签的吻合程度),所有这些指标越高越好。第三是UMAP可视化,一种把高维空间投影到二维平面的技术,用于直观观察不同细胞类型的分布格局。第四是最大均值差异(MMD),用来量化不同细胞类型在表征空间中的分离程度,可以理解为"不同细胞类型之间的距离",距离越大越好。
结果显示了两条一致的规律。规律一:在所有六个数据集、两种上下文长度下,sigmoid注意力的验证损失均低于softmax注意力。规律二:4K上下文窗口的表现系统性地优于2K上下文窗口,两种注意力机制均如此,这验证了更长的上下文窗口对捕捉基因关系的重要性。
在生物学保守性指标上,sigmoid在全部六个数据集上都取得了更好的细胞类型凝聚度(轮廓系数),在六个数据集中的四个上取得了更高的综合生物学保守性得分。在心脏流出道(Heart OFT)数据集上进行的MMD分析最为亮眼:研究团队计算了8种细胞类型之间所有28对两两比较的MMD值,sigmoid注意力在全部28对比较中均取得了更高的MMD,平均提升幅度达到25%。这意味着,sigmoid模型学到的细胞表征空间中,不同细胞类型之间的"距离"更大、更容易区分,这对下游的细胞类型分类任务非常有利。
为什么sigmoid能在预测损失相近的情况下学到更好的表征?研究团队的解释是:softmax的竞争性归一化迫使模型在关注一个基因时减弱对其他基因的关注,这种约束可能使模型倾向于关注那些最有预测性的少数基因,而忽略了定义细胞身份所需的复杂基因共表达模式。sigmoid的独立注意力机制允许模型同时、充分地关注多个相关基因,从而更好地捕捉细胞类型的多维度特征,即使这些特征对预测被遮盖基因本身的贡献相对有限。
五、极端压力测试:训练崩溃时,sigmoid能否力挽狂澜
除了正常训练条件下的性能比较,研究团队还设计了一个"极限压力测试",专门用来暴露softmax注意力的稳定性缺陷。
测试条件故意设置得非常苛刻:8192个token的长上下文窗口,完全去掉梯度裁剪保护,其他所有条件与正常训练完全相同,唯一的变量是注意力机制。这就好比把两辆汽车放上赛道,但把安全带和防滚架都拆掉,看谁能跑完全程。
softmax模型的表现是这样的:前40000步训练一切正常,损失从约10稳步降到约3,看起来学习进展顺利。但从第40000步开始,情况急转直下。到第55600步,训练彻底崩溃:损失从3急剧飙升回10以上,全局梯度范数从约100膨胀到1.6×10?,足足增长了四个数量级(也就是膨胀了10000倍)。与此同时,第0层注意力的最大分数从约20暴增到2.3亿。一旦崩溃,训练就再也无法恢复,陷入永久性的发散。
sigmoid模型则从头到尾平稳运行完全部80000步。损失单调下降,从约10稳步降到约3,没有任何异常波动。在第55600步——也就是softmax模型崩溃的那个时间点——sigmoid模型完全没有任何异常,继续稳定训练。全程梯度范数保持在10到100之间,还呈现下降趋势。注意力分数也始终保持在1到5的范围内,与softmax的2.3亿相比,稳定程度天壤之别。
这个结果与前面的理论分析完全吻合:sigmoid的导数上界为0.25,无论输入多大,信号放大倍数永远不超过这个值;而softmax的局部Lipschitz常数(可以理解为最大信号放大倍数)会随注意力分数的增大呈指数级增长,当序列很长、注意力分数很大时,这个放大倍数会达到天文数字,最终引发训练崩溃。
六、训练速度:sigmoid到底快了多少
除了质量更好、更稳定,sigmoid注意力还能让训练更快完成。
研究团队测量了四种模型规模(1.6亿、4亿、6亿、14亿参数)在三种上下文长度(2K、4K、8K)下的端到端训练速度,使用16块H100 GPU进行测量,然后将吞吐量数据换算成处理1.316亿个细胞(完整训练数据)所需的GPU小时数。
在4K上下文长度下,sigmoid比softmax快:4亿参数模型节省5.1%(1739 vs 1832 GPU小时),6亿参数模型节省3.0%(2336 vs 2408 GPU小时),14亿参数模型节省4.0%(4180 vs 4349 GPU小时)。对于14亿参数模型,速度优势随上下文长度递增:2K时快2.1%,4K时快4.0%,8K时快7.5%,仅在8K上下文这一个配置上就节省了645 GPU小时。
对于1.6亿参数模型,研究团队实际完整跑完了训练。在2K上下文下,sigmoid用了653 GPU小时,softmax用了717 GPU小时,sigmoid快了9%;在4K上下文下分别是896和934 GPU小时,sigmoid快了4%。
速度优势随上下文长度增大而增大是符合预期的:注意力计算的复杂度与序列长度的平方成正比,序列越长注意力在总计算量中占的比例就越大,因此sigmoid注意力在计算上的内在简洁性(不需要做跨token的归一化)带来的收益也就越显著。端到端的加速幅度小于内核级别的TFLOPS提升,是因为总训练时间还包括数据加载、非注意力层计算和GPU间通信等其他开销。
归根结底,这项研究给出了一个清晰的答案:在单细胞生物学基础模型的训练中,把softmax注意力换成sigmoid注意力,不是一个"理论上可能更好"的选项,而是一个在多个维度上经过实验验证的务实选择。表征质量更好了,细胞类型识别能力更强了,训练更快了,在极端条件下也不会崩溃。研究团队同时解决了让这种替换真正可行的工程障碍,开发出了支持生物数据特有的长序列和填充需求的高效内核,并将其开源。
生物信息学和AI医疗这两个领域正在以极快的速度融合,越来越多的药物研发、疾病诊断和精准医疗方案将依赖于这类能理解细胞语言的AI模型。如何让这些模型在有限的计算预算内训练得更好、更稳定,是一个具有直接现实影响的工程问题。这项研究提供的解答——一个改动看似微小,但影响深远的注意力机制替换——值得这个领域认真对待。对这项研究感兴趣的读者,可以通过arXiv编号2604.27124查阅完整论文,研究团队的代码也已在GitHub公开。
Q&A
Q1:sigmoid注意力和softmax注意力的核心区别是什么?
A:softmax注意力是"竞争式"的,关注一个基因就必须减少对其他基因的关注,所有注意力分数加起来必须等于1。sigmoid注意力是"独立式"的,每对基因之间的关联分数独立计算,互不影响,可以同时对多个基因保持高度关注。这种独立性更符合基因调控中多个转录因子并行调控同一目标基因的生物学现实,也使得训练过程中的梯度更加稳定。
Q2:TritonSigmoid内核为什么比现有工具快?
A:TritonSigmoid主要通过三个机制实现加速:对完全是空白填充的计算块直接跳过(稀疏块计算)、将注意力计算融合成单一操作避免反复读写内存(融合运算),以及将反向传播拆分成两个独立内核消除并行冲突。此外,sigmoid函数本身比softmax计算更简单,不需要跨token的归一化操作,在硬件层面天然更高效。在H100 GPU上,其峰值前向计算速度达到515 TFLOPS,比FlashAttention-2快43%。
Q3:单细胞基础模型训练为什么容易出现梯度爆炸崩溃?
A:单细胞数据的序列往往很长(可达数千甚至上万个基因),在长序列场景下,softmax注意力的分数会变得非常大,而其局部Lipschitz常数(可理解为信号放大倍数)会随注意力分数的增大呈指数级增长。当这种放大效应在多层网络中叠加传播时,梯度会急剧膨胀,最终超出数值表示范围,导致训练不可逆地崩溃。sigmoid注意力的导数永远不超过0.25,从根源上消除了这种指数级放大的可能性。
内容搜集整理于网络,不代表本站同意文章中的说法或者描述。文中陈述文字和内容未经本站证实,其全部或者部分内容、文字的真实性、完整性、及时性本站不做任何保证或者承诺,并且本站对内容资料不承担任何法律责任,请读者自行甄别。如因文章内容、版权和其他问题侵犯了您的合法权益请联系邮箱:5 146 761 13 @qq.com 进行删除处理,谢谢合作!