更好的模型,更快的训练:用于单细胞基础模型的 Sigmoid Attention
Better Models, Faster Training: Sigmoid Attention for single-cell Foundation Models
训练稳定的生物 foundation model 需要重新思考 attention 机制:我们发现,将 sigmoid attention 作为 softmax attention 的直接替代方案,能够带来以下效果:a) 学到更好的 representation:在六个多样化的单细胞数据集上,sigmoid 实现了高 25% 的细胞类型分离度、更好的细胞类型内聚性指标,以及更低的 validation loss;b) 训练更快,采用 sigmoid attention 的模型相比对应的 softmax 模型训练速度最高提升 10%;c) 训练更稳定,因为它消除了 softmax attention 中固有的不稳定来源。
我们证明,sigmoid attention 的导数具有全局有界性(leq 0.25),而 softmax 不具备这一性质;同时,sigmoid attention 具有对角 Jacobian 结构,不同于 softmax 的密集耦合,这两点共同有助于缓解训练不稳定。在对 160M 参数的双向 attention 模型进行压力测试时,模型在不使用 gradient clipping、8K-token 序列上训练;softmax 出现严重发散,梯度爆炸达四个数量级,而 sigmoid 保持稳定。
最后,我们实现并开源了 TritonSigmoid,这是一个高效 GPU kernel,在 H100 GPU 上达到 515 TFLOPS,性能超过 FlashAttention-2 和 FlashSigmoid,并原生支持 padding,这对生物序列至关重要。我们的结果表明,对于生物 foundation model,sigmoid attention 既有理论依据,也在实证上优于 softmax attention。代码见 https://github.com/MSDLLCpapers/triton-sigmoid