随着大模型在应用层面的发展,支撑大模型的底层架构模型Transformer所存在的各种问题也逐渐显现,业内一直都在寻找Transformer框架的替代方法。有在原Transformer架构基础上微调改良的,也有采用传统RNN模型的思想的架构,还有以CNN为基础模型的架构,更有将Transformer和其他RNN、CNN架构结合的混合架构模型。无论模型如何变化,目的都是为了更高效地完成任务。
目前的大模型的基础架构改良和重设计,都是在三大基础架构之上进行的革新,即以FNN为主的稠密空间架构基础、以CNN为主的稀疏空间架构基础、以RNN为主的时间序列架构基础。当然一般在设计模型的时候会还会混合加入一些其他架构,比如Encoding-Decoding、Attention、GatedUnit等,Transformer就是在FNN的基础之上加入了Encoding-Decoding、Attention,以及PositionalEncoding等这些元素共同组成的大模型基础架构。
以下是针对目前比较热门的、可能会在未来替代Transformer的基础模型的盘点分析,从中可以发现对于大模型的基础架构模型的设计,主要还是以FNN、CNN、RNN的思想为锚点,加入一些类注意力机制、门控机制、状态空间机制等设计trick继续扩展的。
一、RetNet(Retentive Network)
RetNet是微软研究院提出的一种新型自回归基础架构。RetNet在某种程度上借鉴了Transformer的思想,但它并非直接基于Transformer,而是提出了一种新的机制和架构,在新的架构中引入了一种名为多尺度保留(Multi-ScaleRetention,MSR)的机制来替代Transformer中的多头注意力机制。
论文通过大量实验对比了RetNet和Transformer及其变体。实验结果表明,RetNet在scaling曲线和上下文学习方面一直表现出色。此外,RetNet的推理成本与序列长度无关。对于7B模型和8k序列长度,RetNet的解码速度是带键值缓存的Transformer的8.4倍,内存利用率提高了70%。
在训练过程中,RetNet还能够比标准Transformer节省25-50%的内存,实现了7倍的加速,并在高度优化的FlashAttention方面具有优势。此外,RetNet的推理延迟对批大小不敏感,因此实现了巨大的吞吐量。
RetNet由L个相同的保留层堆叠而成,每个保留层包含一个多尺度保留模块(MSR)和一个前馈神经网络(FNN)模块。这种结构使得RetNet在训练可并行、推理成本低和良好的性能等方面具有优势。
MSR模块的主要功能就是保留机制。首先,对于输入序列的每个元素,MSR模块会计算它与序列中其他所有元素的关系。然后,再将这些关系以一种衰减的方式保留下来,这就是所谓的“保留”机制。
此外,MSR模块还引入了位置相关的指数衰减项来取代softmax,简化了计算,同时使前步的信息以衰减的形式保留下来。并且,保持机制使用多尺度的衰减率,增加了模型的表达能力。
RetNet被设计为大型语言模型的基础架构,RetNet的主要优势在于它能够同时实现训练并行化、低成本推理和良好的性能。RetNet提出了一种名为"retention"的机制来替代传统的"attention"机制。这种机制支持三种计算范式,即并行、循环和分块循环。具体来说,并行表示允许训练并行化,循环表示使得推理成本低,而分块循环表示有助于有效地进行长序列建模。
RetNet的复杂性随着模型大小的放大而降低。根据经验观察到,当模型尺寸大于2B时,RetNet的性能往往优于Transformer。
Transformer和RetNet的推理成本,模型尺寸为6.7B。RetNet在内存消耗、吞吐量和延迟方面都优于Transformer。托管RetNet需要更少的GPU内存来托管RetNet。RetNet的额外内存消耗几乎可以忽略不计(即约3%),而模型权值占97%。吞吐量如下图b所示,Transformer的吞吐量随着解码长度的增加而下降。相比之下,RetNet在解码过程中具有更高的长度不变吞吐量。延迟延迟是部署中的一个重要指标,它极大地影响了用户体验。
实验结果表明,随着批量尺寸的增大,Transformer的延迟就越大。此外,随着输入时间的延长,Transformer的延迟增长得更快。为了使延迟可接受,就必须限制批处理大小,这损害了Transformer的整体推理吞吐量。相比之下,RetNet的解码延迟优于Transformer,并且在不同的批处理大小和输入长度上几乎保持不变。
在训练阶段,RetNet的多尺度保留(Multi-ScaleRetention,MSR)机制支持并行计算,这意味着可以同时处理所有的输入数据,从而大大提高了训练效率。在推理阶段,RetNet的设计使得可以循环地进行推断,这意味着在每个时间步,都可以利用前一步的输出作为当前步的输入,从而实现了低成本的推断。
RetNet的设计不仅提高了训练效率,还大大简化了推断过程。从RetNet的并行训练和循环推理可以发现它实际上是RNN和Transformer核心原则的融合:即REcurrent(循环)+self-attenTION(自注意力)=RETENTION(保留)。
整体来说,RetNet是借鉴了RNN和Transformer两者的优势,配备了三种处理范式——并行训练、循环和块状推理。它采用了Transformer的可并行化自注意力机制,也采用了一些非常巧妙的技巧,避免了Transformer的某些缺陷。
对RetNet而言,尽管它使用Transformer的自注意力块来并行训练并达到最先进的性能,但它并不受到推理成本和内存复杂性问题的影响。这归因于它调整过的自注意力模块,它用保留模块替换了该模块,加上它使用的循环推理范式,在推理时可以模仿自注意力。
二、RWKV(Receptance Weighted Key Value)
RWKV模型是一种基于Transformer结构的模型,由香港大学物理系毕业的彭博首次提出。
Transformer几乎彻底改变了所有的自然语言处理(NLP)任务,但由于内存和计算复杂度的影响,它们与序列长度成二次增长。相比之下,递归神经网络(RNNs)在内存和计算需求上表现出线性伸缩性,但由于并行化和可伸缩性方面的限制,很难匹配与Transformer相同的性能。因此作者提出了一种新的模型架构,接收加权键值(RWKV),它结合了Transformer的高效并行训练和rnn的高效推理。
RWKV的名字来源于timemixing和channelmixing之中的四个重要概念:
R:The Receptance vector acts as the receiver of past information.类似于LSTM的“门控单元”,用于接收以往信息。
W:The Weight signifies the positional weight decay vector, a trainable parameter within the model.位置权重衰减向量,是可训练的模型参数。
K:The Key vector performs a role analogous to K in traditional attention mechanisms.键向量,与传统自注意力中K的向量类似。
V:The Value vector functions similarly to V in conventional attention processes.值向量,与传统自注意力中V的向量类似。
在模型设计过程中,作者利用了一种线性注意机制,并允许将模型表示为Transformer或RNN,从而在训练过程中并行化计算,并在推理过程中保持恒定的计算和记忆复杂性。作者的将模型扩展到140亿个参数,这是迄今为止训练过的最大的密集RNN,这表明未来的工作可以利用这种架构来创建更有效的模型。这项工作为协调序列处理任务中的计算效率和模型性能之间的权衡提供了重要的一步。
RWKV模型是由堆叠的残余块组成的。每个块由一个Time-Mixing(时间混合)和一个Channel-Mixing(通道混合)子块组成,体现了循环结构来利用过去的信息。每个block包含两个主要部分:Time-Mixing和Channel-Mixing。在这两部分中,模型名字中的R,K,V在Time-Mixing里用到了,R,K在Channel-Mixing里用到了。
在这种结构中,所有涉及计算的线性投影向量(时间混合中的R,K,V,以及通道混合中的R',K',)都是通过当前和以前的时间步长输入之间的线性插值产生的,促进了token位移。
Time-Mixing:时间混合计算的向量是块的当前和先前输入的线性组合的线性投影:这部分的数学公式非常清楚,非常直观。r、k、v的计算与传统Attention机制类似,通过将当前输入与前一时刻输入做线性插值。
Channel-Mixing:Channel-Mixing的r和k和time mixing里面的r和k不同,都要重新算一遍,这里用r'和k'来表示。channel mixing的意思就是在特征维度上做融合。假设特征向量维度是d,那么每一个维度的元素都要接收其他维度的信息,来更新它自己。特征向量的每个维度就是一个“channel”(通道)。
从公式上来看,Time-Mixing与Channel-Mixing这两个模块有些类似RNN的Ot=f(Xt,Ot-1)形式。
在RWKV模型中,WKV算子的计算与无注意力Transformer(AFT)中使用的方法相似。然而,与AFT不同的是,W是一个成对矩阵,RWKV模型将W视为一个由相对位置修改的通道级向量。在RWKV模型中,这种循环行为是由WKV向量的时间依赖性更新来定义的,为了避免W的任何潜在退化,公式中引入了一个单独关注当前token的向量U。具体公式为下式:
这里简单的回忆一下Transformer中的单头自注意力的公式(为方便而省略了多头和缩放因子1/√dk):
将上式的softmax公式展开后,公式的形式则便成为为下式(核心的QK乘法是序列中每个token之间的成对注意分数的集合,可以分解为向量运算):
免注意力Transformer模型(Attention Free Transformer,AFT)不同于Transformer中原始的attention,无注意力Transformer(AFT)是一个替代MHA的插件,而不需要改变Transformer的其他架构方面。给定输入X,AFT首先将它们线性转换为Q = XWQ,K = XW K,V = XWV,然后执行以下操作:
其中是元素级的乘积;σq是应用于查询的非线性;w∈RT×T是学习到的成对位置偏差。对于每个目标位置t,AFT执行值的加权平均值,其结果与元素级乘法查询相结合。特别是,权重只是简单地由键和一组学习到的成对位置偏差组成。这提供了一个直接的优点,即不需要计算和存储昂贵的注意矩阵,同时像MHA那样保持查询和值之间的全局交互。
在模型设计上,RWKV采用了将Transformer的高效可并行训练与RNN的高效推理相结合的方式。利用了线性注意力机制,并允许将模型制定为Transformer或RNN,从而使得训练期间可以并行化计算,且在推理期间保持恒定的计算和内存复杂度,这诞生了第一个被扩展到数十亿参数的非Transformer架构。
RWKV结合了RNN和Transformer的最佳特点。在训练过程中,RWKV采用了Transformer类别的架构,支持大规模并行化,具有按token数线性缩放的注意力机制。对于推理过程,RWKV采用了与带有隐状态向量的RNN等效的设计。这种设计使得RWKV能够既继承了Transformer的优势,又能够消除长上下文长度的计算存储开销。
因此,相当于实际上拥有了一个在训练时类似于Transformer的模型,只是处理长上下文长度时的计算开销不高。在推理期间,模型需要的存储开销更小,并且可以隐式处理“无限”上下文长度(尽管在实际应用中,模型可能难以泛化到比训练中更长的上下文长度)。
通过对典型计算平台上的文本生成速度和内存需求的测试评估,包括CPU(x86)和GPU(NVIDIA A100 80GB)的内存需求。得出一个结论,那就是RWKV为LLM生成文本的累积时间与Transformer不同,RWKV随着token的增加在生成时间上表现出线性缩放。这一发现表明RWKV在生成文本效率方面是由于传统Transformer模型的。
实验表明,RWKV的设计精良,能够缓解Transformer所带来的内存瓶颈和二次方扩展问题,实现更有效的线性扩展,同时保留了使Transformer在这个领域占主导的一些性质;RWKV的性能与同样大小的Transformer相当,这表明未来的工作可以利用这种架构创建更有效的模型。这项工作是在解决序列处理任务中的计算效率和模型性能之间的权衡方面迈出的重要一步。
三、Mamba(这是一条会吐丝的曼巴蛇)
Mamba基于选择性状态空间模型(selectivestatespacemodel),该架构是Mamba论文作者AlbertGu此前主导研发的S4架构(StructuredStateSpacesforSequenceModeling)的一个简单泛化。可以有选择地决定关注还是忽略传入的输入。
至于模型的名字为什么起名为Mamba?作者AlbertGu表示主要是由于Mamba具有以下特征:
-速度快:序列长度线性缩放的简单递归与硬件感知设计的实现
-致命性:对序列建模问题具有致命的吸引力
-就连发出的「声音」都很像:其核心机制是结构化状态空间序列模型(S4)的最新演进——SSSS
Mamba提出了一种新的选择性状态空间模型,它改进了之前在几个轴上的工作,以实现Transformer在序列长度线性扩展的建模能力。主要技术体现在以下三方面:
1、选择机制。首先,确定先前模型的一个关键限制:以依赖于输入的方式有效地选择数据的能力(即关注或忽略特定的输入)。基于选择性复制和归纳头等重要合成任务的直觉,设计了一个基于输入的SSM参数的简单的选择机制。这允许模型过滤掉不相关的信息,并无限期地记住相关的信息。
2、硬件感知算法。这种简单的变化对模型的计算提出了技术挑战;事实上,所有先前的ssm模型必须是时间和输入不变的,才能提高计算效率。通过一种硬件感知算法来克服这一问题,该算法通过扫描而不是卷积反复计算模型,但没有实现扩展状态,以避免GPU内存层次结构的不同层次之间的IO访问。所得到的实现在理论上比以前的方法更快(与所有基于卷积的ssm的伪线性扩展相比)和在现代硬件上(在A100gpu上快3倍)。
3、主要架构。将先前的SSM体系结构与变压器的MLP块结合成一个块,简化了先前的深度序列模型体系结构,从而形成了包含选择性状态空间的简单而同质的体系结构设计(Mamba)。
以前的递归模型的缺点是它们的固定大小状态难以压缩上下文,Mamba的主要特点是引入了选择性SSM,这个小小的改变-只是让某些参数成为输入的函数-让它立即解决以前模型难以解决的有趣任务。它可以无限期地推断出重要的“关联召回”任务的解决方案!
结构化ssm通过高维潜在状态h(例如N = 4)独立地将输入x的每个通道(例如D = 5)映射到输出y。先前的ssm通过需要时不变性的巧妙替代计算路径来避免实现这种大型有效状态(DN,倍批大小B和序列长度L):(∆,a, B, C)参数随时间不变。我们的选择机制增加了依赖输入的动态,这也需要一个谨慎的硬件感知算法,只在GPU内存层次的更高级的客户端级别中实现扩展状态。
Mamba引入的「选择性SSM」是S4的简单泛化,可以选择性地关注或忽略输入。类似plus版本的门控机制,可以选择性的遗忘不需要的信息。状态空间模型(SSM)的特征,可以理解为类似RNN中固定大小的细胞状态。如果想实现更好的性能,就要求这种状态更大,并且更具表现力。不过需要注意的是,较大的状态会导致模型变慢。
具体地说,S4是一类用于深度学习的序列模型,与RNN、CNN和经典的状态空间模型(StateSpaceModel,SSM)广泛相关。SSM是独立的序列转换,可被整合到端到端神经网络架构中(SSM架构有时也称SSNN,它与SSM层的关系就像CNN与线性卷积层的关系一样)。
Mamba模型之所以能够媲美Transformer甚至比其性能更加优秀,主要在于它的选择性SSM层能够选择性地记住相关的token,同时忽略中间其他的token,因此能够完美地完成任务。
在任务处理过程中,Mamba会逐个观察token,然后改变隐藏状态,每次看到一个新token时都会更新隐藏状态。从某种意义上来说,这模仿了人脑处理信息的方式,就像阅读一句话或一段话,就像在大脑中存储一些信息。当你读完一个文档时,可能能够回答关于那个文档的问题,而无需再次参考该文档。所以,RNN就是这样工作的。它们处理文本,然后改变隐藏状态,隐藏状态是可以用来生成新token或对文档进行分类的表示。
Transformer存在二次方扩展特性,它在自注意力机制下的计算量会随着上下文长度的增加呈平方级增长,每个token都会与之前的各个token进行比较。比如上下文增加32倍时,计算量可能会增长1000倍,Mamba可以随上下文长度的增加实现线性扩展,其性能在实际数据中可提高到百万token长度序列,并实现5倍的推理吞吐量提升。而这些,都离不开选择性SSM。
相比Transformer随着输入token的增加,参数会呈现指数型的增加,Mamba参数的增加则随着token的增加呈线性增长的,这可以为模型的训练和应用节省更多的成本。
Transformer的注意力模块在模型能力表达方面确实非常有效,但是在另一些方面又很低效,因为它完全不压缩上下文。从这一点可以看出,自回归推理需要显式存储整个上下文(即KV缓存),这直接导致Transformer的线性时间推理和二次训练时间非常缓慢。
相比Transformer这类随着toeken的增加,模型的状态空间也不断增加的模型,以RNN思想为主的Mamba模型的状态是稳定的,这意味着推理时间是恒定的,并且训练的时间也将会是线性的。
Mamba的单块设计结合了H3块,这是大多数SSM架构的基础,具有现代神经网络中无处不在的MLP块。Mamba的设计上没有交错这两个块,而是简单地均匀地重复Mamba块。与H3块相比,Mamba用激活函数取代了乘法门。与MLP块相比,Mamba在主分支中添加了一个SSM。对于激活函数,使用了SiLU / Swish。
Mamba通过让SSM参数作为输入的函数,解决了其离散模态的弱点,允许模型根据当前token选择性地沿序列长度维度传播或忘记信息。Mamba具有快速的推理(吞吐量比Transformer高5倍)和序列长度线性缩放。在语言建模任务中,Mamba-3B模型在预训练和下游评估中均优于相同规模的Transformer,并且与其两倍大小的Transformer模型相媲美。
作为一个通用的序列处理模型,Mamba在语言、音频和基因组学等多个领域都获得了最先进的性能表现。在语言建模方面,Mamba-3B模型在预训练和后续评估中性能达了两倍参数量的Transformer模型性能。在音频波形和DNA序列建模方面,Mamba的表现优于SaShiMi、Hyena和Transformer等先前的SOTA模型。
以上的RetNet、RWKV、Mamba是有可能会在未来替代Transformer的模型架构,下一篇文章我们将继续剖析其它有可能会替代Transformer的模型架构。