深入理解transformer


1.动机

2.transformer算法讲解

2.1总体架构

深入理解transformer-2024-07-21-15-38-31
如图所示,transformer首先将输入进行Word Embedding编码,然后通过位置编码加入位置信息。后续部分分为N个编码器和N个解码器
编码器内部由多头注意力机制前馈神经网络组成,解码器内部也由前馈神经网络多头注意力机制组成,但是加入了过去输出编码后通过掩膜多头注意力机制的输入。解码器输出的结果经过线性层和softmax得到最后的输出结果的概率。

2.1.1 Word Embedding(词嵌入)

由于神经网络能接受的输入只能是数字,所以需要一种方法将单词进行编码。如下图所示
深入理解transformer-2024-07-21-15-58-20
词嵌入就是将「不可计算」「非结构化」的词转化为「可计算」「结构化」的向量,且将其映射到一个低维、连续的向量空间。
如图所示,将单词作为token(最小语义单元),对how,are,you分别进行编码,如果认为每个单词之间都没有任何关系就可以采用独热编码,N个单词编码为N个维度中互不相交的向量。但是显然单词之间是有关系的,如下图
深入理解transformer-2024-07-21-16-02-43
那么如何初始化这些词向量的值具体可以选择随机初始化或者word2vec方法(用神经网络去预测上下文来学习词嵌入)。方法关系如下图
深入理解transformer-2024-07-21-16-17-00

2.1.2 位置编码

transformer不像RNN那样一个词一个词的输出,而是通过位置编码来加入位置信息。

ppos,2i=sin(pos100002i/d)ppos,2i+1=cos(pos100002i/d)\begin{aligned} p_{pos, 2 i} & =\sin \left(\frac{pos}{10000^{2 i / d}}\right)\\ p_{pos, 2 i+1} & =\cos \left(\frac{pos}{10000^{2 i / d}}\right) \end{aligned}

其中,pos表示token在序列中的位置,i表示位置向量里的第i个元素,d表示嵌入维度。
之所以用正余弦函数,因为其线性性质和周期性(相对位置重要性)。
位置嵌入矩阵PP,词嵌入矩阵EE,则输出X=X+PX = X + P

2.1.3 编码器

2.1.3.1 多头注意力机制

深入理解transformer-2024-07-21-20-20-25
首先是子注意力机制计算,注意力汇聚公式如下

f(q,K,V)=i=1nα(q,k)vf(q,K,V)=\sum_{i=1}^n \alpha\left(q, k\right) v

其中,qq为查询,(K,V)(K,V)为键值对,α(q,K)\alpha(q,K)为注意力权重。
深入理解transformer-2024-07-21-20-23-38
采用注意力评分函数a()a()对查询和键之间进行建模。

α(q,ki)=softmax(a(q,ki))\alpha\left(\mathbf{q}, \mathbf{k}_i\right)=\operatorname{softmax}\left(a\left(\mathbf{q}, \mathbf{k}_i\right)\right)

深入理解transformer-2024-07-21-20-49-43
在transformer中的注意力权重计算公式如下。

 Attention (Q,K,V)=softmax(QKTdk)V\text { Attention }(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V

在transformer中采用自注意力机制的思想得到QKV。
深入理解transformer-2024-07-21-21-00-20
然后将所有得到的注意力输出连接起来进入一个线性层得到一个最终的值就是多头注意力机制。

MultiHead(Q,K,V)=Concat( head 1,, head h)WO where head i=Attention(QWiQ,KWiK,VWiV)\begin{aligned} & \operatorname{MultiHead}(Q, K, V)=\operatorname{Concat}\left(\text { head }_1, \ldots, \text { head }_{\mathrm{h}}\right) W^O \\ & \text { where head }_{\mathrm{i}}=\operatorname{Attention}\left(Q W_i^Q, K W_i^K, V W_i^V\right) \\ & \end{aligned}

2.1.3.2 残差和LayNorm
  • 残差网络
    引入resnet的思想(残差网络)为了让网络可以做深。
    深入理解transformer-2024-07-21-21-08-50
    如图,有A,B,C,D个网络

XDin =XAout +C(B(XAout ))X_{\text {Din }}=X_{\text {Aout }}+C\left(B\left(X_{\text {Aout }}\right)\right)

根据后向传播的链式法则,

LXAout =LXDin XDin XAout \frac{\partial L}{\partial X_{\text {Aout }}}=\frac{\partial L}{\partial X_{\text {Din }}} \frac{\partial X_{\text {Din }}}{\partial X_{\text {Aout }}}

由此可以得到

LXAout =LXDin [1+XDin XCXCXBXBXAout ]\frac{\partial L}{\partial X_{\text {Aout }}}=\frac{\partial L}{\partial X_{\text {Din }}}\left[1+\frac{\partial X_{\text {Din }}}{\partial X_C} \frac{\partial X_C}{\partial X_B} \frac{\partial X_B}{\partial X_{\text {Aout }}}\right]

即使括号中第二项出现梯度消失,也能让该网络的梯度能够维持在第一项的水平。

  • Layer Normalization
    对每层的输出进行归一化,使训练更加稳定。

2.1.3 解码器

深入理解transformer-2024-07-21-21-26-32

  • 掩膜注意力机制
    注意到译码器下方输入的结果,在训练过程中如果把所有的输出都输入到注意力机制中,那么就会导致模型看到未来的信息,这是不合理的。因此,在训练过程中,需要将当前位置之后的信息进行掩膜,使其无法参与计算。
  • 与编码器交互
    编码器生成K,V矩阵,解码器生成Q矩阵。
    深入理解transformer-2024-07-21-21-33-08
    深入理解transformer-2024-07-21-21-33-54

2.2实践

3.总结

编码器可以帮助解码器关注输入中的适当词汇。
解码器将先前输出的列表作为输入,以及包含来自输入的注意力信息,让解码器决定哪个编码器输入是相关的焦点,
为了防止解码器查看未来的标记,采用Mask来掩盖。

4.参考资料

[1] Transformer从零详细解读
[2] 超强动画,一步一步深入浅出解释Transformer原理!
[3] 直观解释注意力机制,Transformer的核心
[4] 动手学深度学习
[5] 【Transformer模型】曼妙动画轻松学,形象比喻贼好记
[6] 原始论文


文章作者: sdj
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 sdj !
  目录