
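The Transformer Model

The Transformer, proposed by Vaswani et al. (2017), is a deep architecture built entirely on attention mechanisms. It dispenses with the convolutional and recurrent layers used by earlier sequence models. It was designed for sequence-to-sequence learning and has since been widely applied in language, vision, speech, and reinforcement learning. The architecture computes over all positions in parallel and keeps the path between any input position and any output position short, which makes it highly efficient for tasks on sequence data.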

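Model Components

Multi-Head Self-Attention

  • In the self-attention layers of both the encoder and the decoder, the queries, keys, and values are all derived from the same input. Using several heads lets the model capture different aspects of that same input in parallel.
  • The encoder uses self-attention to build its input representations. The decoder uses masked self-attention so that each output position depends only on earlier positions, which preserves the autoregressive property.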
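The following is a minimal NumPy sketch of multi-head self-attention for a single unbatched sequence. It rests on several assumptions: the weight matrices are random stand-ins, `d_model` is divisible by `num_heads`, and the helper names (`multi_head_self_attention`, `split_heads`) are hypothetical. Setting `causal=True` applies the decoder-style mask, which stops each position from attending to later positions.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def split_heads(M, num_heads):
    # (seq_len, d_model) -> (num_heads, seq_len, d_model // num_heads)
    seq_len, d_model = M.shape
    return M.reshape(seq_len, num_heads, d_model // num_heads).transpose(1, 0, 2)


def multi_head_self_attention(X, W_q, W_k, W_v, W_o, num_heads, causal=False):
    """X: (seq_len, d_model); each W_*: (d_model, d_model)."""
    seq_len, d_model = X.shape
    d_head = d_model // num_heads
    # Self-attention: queries, keys and values are all projections of the same X.
    Q = split_heads(X @ W_q, num_heads)
    K = split_heads(X @ W_k, num_heads)
    V = split_heads(X @ W_v, num_heads)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (num_heads, seq_len, seq_len)
    if causal:
        # Decoder masking: position t may not attend to any position after t.
        future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    weights = softmax(scores, axis=-1)
    heads = weights @ V  # (num_heads, seq_len, d_head)
    # Concatenate the heads and apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o, weights


rng = np.random.default_rng(0)
d_model, num_heads, seq_len = 8, 2, 5
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
out, w = multi_head_self_attention(X, W_q, W_k, W_v, W_o, num_heads, causal=True)
print(out.shape)  # (5, 8)
```

With the mask on, the upper triangle of every head's weight matrix is zero, so no position draws on later tokens. The PyTorch `MultiHeadAttention` class in the example section performs the same computation with an extra batch dimension. It masks by valid lengths rather than with an explicit triangular matrix.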

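Positional Encoding

  • Adds information about the position of each token in the sequence, compensating for the model's lack of recurrence.
  • Uses sine and cosine functions of different frequencies.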
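Here is a short NumPy sketch of the sinusoidal encoding from Vaswani et al. (2017), $PE_{(pos,2i)} = \sin(pos/10000^{2i/d_{\text{model}}})$ and $PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d_{\text{model}}})$. It assumes an even `d_model`, and the function name is illustrative only. The PyTorch example later in this article delegates this step to `d2l.PositionalEncoding`.

```python
import numpy as np


def sinusoidal_positional_encoding(max_len, d_model):
    """Return a (max_len, d_model) array of sinusoidal position encodings."""
    pos = np.arange(max_len)[:, None]              # (max_len, 1)
    i2 = np.arange(0, d_model, 2)[None, :]          # even feature indices 2i
    angles = pos / np.power(10000.0, i2 / d_model)  # (max_len, d_model // 2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)  # even dimensions use sine
    pe[:, 1::2] = np.cos(angles)  # odd dimensions use cosine
    return pe


pe = sinusoidal_positional_encoding(max_len=50, d_model=16)
print(pe.shape)   # (50, 16)
print(pe[0, :4])  # [0. 1. 0. 1.]  (position 0: sin 0 = 0, cos 0 = 1)
```

Each pair of dimensions oscillates at its own frequency. The lowest dimensions change fastest with position, and the highest change slowest.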

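Encoder and Decoder Layers

  • Both the encoder and the decoder are stacks of identical layers. Each encoder layer has two sub-layers: multi-head self-attention and a position-wise fully connected feed-forward network. Each decoder layer adds a third sub-layer, encoder-decoder attention, between these two.
  • Each sub-layer is wrapped in a residual connection, followed by layer normalization.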

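Encoder-Decoder Attention

  • In the decoder, an additional encoder-decoder attention layer lets the decoder attend to the relevant parts of the input sequence.
  • The queries come from the previous decoder layer, while the keys and values come from the encoder output.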
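The following NumPy sketch shows how the shapes work out in encoder-decoder attention. It is simplified to a single head, one unbatched sentence, and random weights, and its names are hypothetical. With 3 decoder positions and 6 encoder positions, the attention-weight matrix has one row per target position and one column per source position. Source-side padding masks are omitted for brevity.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(dec_states, enc_outputs, W_q, W_k, W_v):
    """dec_states: (T_dec, d); enc_outputs: (T_enc, d)."""
    Q = dec_states @ W_q   # queries come from the decoder
    K = enc_outputs @ W_k  # keys come from the encoder output
    V = enc_outputs @ W_v  # values come from the encoder output
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))  # (T_dec, T_enc)
    return weights @ V, weights


rng = np.random.default_rng(0)
d = 8
enc_outputs = rng.normal(size=(6, d))  # 6 source positions
dec_states = rng.normal(size=(3, d))   # 3 target positions so far
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
out, weights = cross_attention(dec_states, enc_outputs, W_q, W_k, W_v)
print(weights.shape, out.shape)  # (3, 6) (3, 8)
```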

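Position-wise Feed-Forward Networks

  • A fully connected feed-forward network is applied to each position separately and identically. It consists of two linear transformations with a ReLU activation in between: $\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$.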
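This is a NumPy sketch of $\max(0, xW_1 + b_1)W_2 + b_2$ with small illustrative sizes; the original paper uses $d_{\text{model}} = 512$ and an inner size of 2048. The final check shows what "position-wise" means. Processing the whole sequence at once gives the same result as processing each position separately with the same weights.

```python
import numpy as np


def position_wise_ffn(X, W1, b1, W2, b2):
    """Apply FFN(x) = max(0, x W1 + b1) W2 + b2 to every row (position) of X."""
    return np.maximum(0.0, X @ W1 + b1) @ W2 + b2


rng = np.random.default_rng(0)
d_model, d_ff, seq_len = 8, 32, 5
W1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)
X = rng.normal(size=(seq_len, d_model))

Y = position_wise_ffn(X, W1, b1, W2, b2)
# Same weights, each position processed on its own: the results match.
Y_rowwise = np.stack([position_wise_ffn(x, W1, b1, W2, b2) for x in X])
print(Y.shape, np.allclose(Y, Y_rowwise))  # (5, 8) True
```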

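Residual Connections and Layer Normalization

  • In both the encoder and the decoder, each sub-layer's output is added to that sub-layer's input (residual connection). The sum is then normalized (layer normalization), giving $\text{LayerNorm}(x + \text{Sublayer}(x))$.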
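This is a minimal NumPy sketch of the post-norm residual pattern $\text{LayerNorm}(x + \text{Sublayer}(x))$. A random array stands in for the sub-layer output, and layer normalization is written by hand with the learnable `gamma`/`beta` fixed at 1 and 0. The `AddNorm` class in the PyTorch example also applies dropout to the sub-layer output before the addition, which this sketch leaves out.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each position's feature vector to zero mean and unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


def add_norm(x, sublayer_out, gamma, beta):
    """Post-norm residual connection: LayerNorm(x + Sublayer(x))."""
    return layer_norm(x + sublayer_out, gamma, beta)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
sublayer_out = 3.0 * rng.normal(size=(4, 8)) + 5.0  # stand-in for an attention/FFN output
y = add_norm(x, sublayer_out, gamma=np.ones(8), beta=np.zeros(8))
print(np.allclose(y.mean(axis=-1), 0), np.allclose(y.std(axis=-1), 1, atol=1e-3))
```

The offset and scale of the sub-layer output are removed per position, which keeps the input to the next sub-layer in a consistent range.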

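Application to Machine Translation

  • The Transformer has achieved notable success in machine translation, thanks to its ability to process sequences efficiently.
  • It is typically trained on parallel datasets of paired source and target sentences.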
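When training on sentence pairs, the decoder is usually fed the target sentence shifted right by one position, with a beginning-of-sequence token in front (teacher forcing). The unshifted target serves as the label. The pure-Python sketch below illustrates this for one sentence; the special-token ids, the function name, and the padding scheme are hypothetical. d2l's `train_seq2seq`, used in the example section, appears to build its decoder input in a similar way.

```python
BOS, EOS, PAD = 1, 2, 0  # hypothetical special-token ids


def make_decoder_io(tgt_tokens, num_steps):
    """Build (decoder_input, label) from one target sentence for teacher forcing."""
    # Label: target tokens + <eos>, truncated or padded to num_steps.
    label = (tgt_tokens + [EOS])[:num_steps]
    label += [PAD] * (num_steps - len(label))
    # Decoder input: <bos> followed by the label shifted right by one position.
    dec_input = [BOS] + label[:-1]
    return dec_input, label


dec_input, label = make_decoder_io([7, 8, 9], num_steps=6)
print(dec_input)  # [1, 7, 8, 9, 2, 0]
print(label)      # [7, 8, 9, 2, 0, 0]
```

At every position the decoder sees only the tokens before the one it must predict. Combined with the causal mask, this lets all positions be trained in a single parallel pass.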

Mathematical Foundations

  • Attention function (scaled dot-product attention):
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where $Q$, $K$, and $V$ are the query, key, and value matrices, and $d_k$ is the dimension of the keys.
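The following is a small worked example of the attention formula with one query and two keys in $d_k = 4$ dimensions. The vectors are arbitrary illustrative values.

$$q = (1, 1, 1, 1),\quad k_1 = (1, 1, 1, 1),\quad k_2 = (1, -1, 1, -1),\quad d_k = 4$$

$$\frac{q k_1^T}{\sqrt{d_k}} = \frac{4}{2} = 2, \qquad \frac{q k_2^T}{\sqrt{d_k}} = \frac{0}{2} = 0$$

$$\text{softmax}(2, 0) = \left(\frac{e^2}{e^2 + 1}, \frac{1}{e^2 + 1}\right) \approx (0.881, 0.119)$$

The output for this query is therefore approximately $0.881\,v_1 + 0.119\,v_2$. Without the $1/\sqrt{d_k}$ factor, the same scores would give $\text{softmax}(4, 0) \approx (0.982, 0.018)$. The scaling keeps the softmax from saturating as $d_k$ grows.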

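  • Several scaling and normalization steps keep activations in a stable range, which matters especially in deeper models. Token embeddings are multiplied by $\sqrt{d_{\text{model}}}$ before the positional encodings are added, and attention scores are scaled by $1/\sqrt{d_k}$. Every sub-layer output also passes through a residual connection and layer normalization.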

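Challenges and Innovations

  • Because self-attention has quadratic complexity in the sequence length, processing long input sequences can be computationally expensive.
  • Variants known collectively as "efficient Transformers" address this by approximating the attention mechanism or sparsifying its connections.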
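Counting query-key scores per head and per layer makes the quadratic cost easy to see: full self-attention computes $n^2$ of them. For comparison, the sketch below also counts scores under a sliding-window (banded) mask, where each query sees only keys within distance `w`. This is one illustrative sparsification pattern chosen for this example, not the method of any particular efficient-Transformer paper, and `w = 64` is arbitrary.

```python
def full_attention_entries(n):
    """Query-key scores per head per layer for full self-attention."""
    return n * n


def sliding_window_entries(n, w):
    """Scores when each query attends only to keys within distance w (band mask)."""
    # Row i keeps keys max(0, i - w) .. min(n - 1, i + w).
    return sum(min(n - 1, i + w) - max(0, i - w) + 1 for i in range(n))


for n in (512, 1024, 2048):
    print(n, full_attention_entries(n), sliding_window_entries(n, w=64))
```

Doubling `n` multiplies the full count by 4 but the windowed count by only about 2.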

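Example

The following is a complete PyTorch implementation of the Transformer, adapted from the d2l (Dive into Deep Learning) example, with a focus on the encoder and decoder components. It covers setup, model construction, and training on the small English-French dataset loaded by `d2l.load_data_nmt`; evaluation is not included. The code relies on helpers from the `d2l` package (`DotProductAttention`, `PositionalEncoding`, `EncoderDecoder`, `load_data_nmt`, `train_seq2seq`). `train_seq2seq` with this signature is found in d2l releases such as 0.17.x, so the installed version may matter.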

import math
import torch
from torch import nn
from d2l import torch as d2l


class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network."""
    def __init__(self, num_hiddens, ffn_num_hiddens):
        super().__init__()
        # Expand each position from num_hiddens to ffn_num_hiddens, then project back.
        self.dense1 = nn.Linear(num_hiddens, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, num_hiddens)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))


class AddNorm(nn.Module):
    """Residual connection followed by layer normalization."""
    def __init__(self, normalized_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # X: sub-layer input, Y: sub-layer output.
        return self.ln(self.dropout(Y) + X)


def transpose_qkv(X, num_heads):
    """Split the hidden dimension into heads so all heads run in one batched call."""
    # (batch_size, num_steps, num_hiddens)
    # -> (batch_size, num_heads, num_steps, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1).permute(0, 2, 1, 3)
    # -> (batch_size * num_heads, num_steps, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse transpose_qkv: merge the heads back into the hidden dimension."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]).permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False):
        super().__init__()
        self.num_heads = num_heads
        # Scaled dot-product attention over 3-D (batched) inputs.
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transpose_qkv, shape of queries, keys, values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        #  num_hiddens / num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Repeat each example's valid lengths once per head
            # (batch-major order, matching transpose_qkv).
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        # output: (batch_size * num_heads, no. of queries, num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)


class TransformerEncoderBlock(nn.Module):
    """Transformer encoder block."""
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(num_hiddens, num_heads, dropout,
                                            use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens, dropout)

    def forward(self, X, valid_lens):
        # Self-attention sub-layer, then feed-forward sub-layer, each with AddNorm.
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))


class TransformerEncoder(nn.Module):
    """Transformer encoder."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_layers, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.ModuleList([
            TransformerEncoderBlock(num_hiddens, ffn_num_hiddens, num_heads,
                                    dropout)
            for _ in range(num_layers)])

    def forward(self, X, valid_lens, *args):
        # Scale embeddings by sqrt(num_hiddens) before adding positional encodings.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X


class TransformerDecoderBlock(nn.Module):
    """Transformer decoder block (the i-th block in the decoder)."""
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i,
                 use_bias=False):
        super().__init__()
        self.i = i
        self.attention1 = MultiHeadAttention(num_hiddens, num_heads, dropout,
                                             use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.attention2 = MultiHeadAttention(num_hiddens, num_heads, dropout,
                                             use_bias)
        self.addnorm2 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        self.addnorm3 = AddNorm(num_hiddens, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # state[2][self.i] caches the decoder representations this block has
        # already processed; during step-by-step prediction new tokens are
        # appended to it. During training all target positions arrive at once.
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values

        if self.training:
            batch_size, num_steps, _ = X.shape
            # dec_valid_lens: (batch_size, num_steps); row values 1..num_steps
            # let position t attend only to positions 1..t (causal mask).
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # Masked self-attention over the target sequence.
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention: queries from the decoder,
        # keys and values from the encoder output.
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state


class TransformerDecoder(nn.Module):
    """Transformer decoder."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_layers, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.ModuleList()
        for i in range(num_layers):
            self.blks.append(TransformerDecoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        # [encoder outputs, encoder valid lengths, per-block key/value cache]
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for blk in self.blks:
            X, state = blk(X, state)
        return self.dense(X), state


num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
ffn_num_hiddens, num_heads = 64, 4
lr, num_epochs, device = 0.005, 50, d2l.try_gpu()
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

encoder = TransformerEncoder(
    len(src_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout)
model = d2l.EncoderDecoder(encoder, decoder)
# train_seq2seq takes the target device as its last argument.
d2l.train_seq2seq(model, train_iter, lr, num_epochs, tgt_vocab, device)

References and Useful Links