目录
1.Transformer架构图
2.代码实现
2.1创建类:实现基于位置的前馈网络
2.2创建 残差&LN层标准归一化的类
2.3编码器block
2.4创建编码器
2.5创建解码器
2.6transformer解码器部分
3.知识点个人理解
1.Transformer架构图
2.代码实现
2.1创建类:实现基于位置的前馈网络
# Position-wise feed-forward network (position-wise FFN).
class PositionWiseFFN(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied at every position.

    ``nn.Linear`` acts on the last dimension only, so the same weights are
    applied independently to each position of a (batch, steps, features)
    input.

    Args:
        ffn_num_inputs: size of the input's last dimension.
        ffn_num_hiddens: size of the hidden layer.
        ffn_num_outputs: size of the output's last dimension.
    """

    def __init__(self, ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_inputs, ffn_num_hiddens)  # hidden layer
        self.relu = nn.ReLU()
        # Output layer: deliberately no activation after it.
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        """Map X of shape (..., ffn_num_inputs) to (..., ffn_num_outputs)."""
        return self.dense2(self.relu(self.dense1(X)))
# Smoke-test PositionWiseFFN: 6 input features -> 4 hidden units -> 8 outputs.
ffn = PositionWiseFFN(6, 4, 8)
ffn.eval()
# Dummy input of shape (batch_size, num_steps, num_hiddens) = (2, 3, 6).
X = torch.ones(2, 3, 6)
ffn(X).shape  # the last dimension becomes ffn_num_outputs
torch.Size([2, 3, 8])
2.2创建 残差&LN层标准归一化的类
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Computes ``LayerNorm(X + Dropout(Y))``, where X is a sublayer's input and
    Y is that sublayer's output.

    Args:
        normalized_shape: trailing shape normalized by ``nn.LayerNorm``.
        dropout: dropout probability applied to Y before the residual sum.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Return LayerNorm(Dropout(Y) + X).

        X and Y must have the same shape for the residual sum.
        """
        return self.ln(self.dropout(Y) + X)
# For the residual connection, X and Y must have the same shape.
add_norm = AddNorm([6, 8], dropout=0.2)
add_norm.eval()
X = torch.ones((2, 6, 8), dtype=torch.float32)
Y = torch.ones((2, 6, 8), dtype=torch.float32)
add_norm(X, Y)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.]]],grad_fn=<NativeLayerNormBackward0>)
2.3编码器block
class EncoderBlock(nn.Module):
    """One Transformer encoder block.

    Multi-head self-attention followed by AddNorm, then a position-wise FFN
    followed by AddNorm.

    Args:
        key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias: forwarded to ``dltools.MultiHeadAttention``.
        norm_shape: normalized shape for both AddNorm layers.
        ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs: FFN layer sizes;
            ffn_num_outputs must match the FFN input's last dimension so the
            residual sum lines up.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs,
                 num_heads, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.attention = dltools.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads,
            dropout, use_bias)
        self.addnorm1 = dltools.AddNorm(norm_shape, dropout)
        self.ffn = dltools.PositionWiseFFN(
            ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs)
        self.addnorm2 = dltools.AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # Self-attention: X is used as queries, keys and values.
        # valid_lens is forwarded to the attention layer (presumably to mask
        # padded positions -- see dltools.MultiHeadAttention).
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
# Dummy batch: 2 sequences, 100 steps, 24 features each.
X = torch.ones(2, 100, 24)
# Valid lengths handed to the attention layer, one per sequence.
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(
    key_size=24, query_size=24, value_size=24, num_hiddens=24,
    norm_shape=[100, 24], ffn_num_inputs=24, ffn_num_hiddens=48,
    ffn_num_outputs=24, num_heads=8, dropout=0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape  # the block preserves the input shape
torch.Size([2, 100, 24])
2.4创建编码器
# Transformer encoder
class TransformerEncoder(dltools.Encoder):
    """Token embedding + positional encoding + a stack of EncoderBlocks.

    After ``forward``, ``self.attention_weights[i]`` holds the attention
    weights recorded by block i's attention layer.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 ffn_num_outputs, num_heads, num_layers, dropout,
                 use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = dltools.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                'block' + str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             ffn_num_outputs, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        """Encode token ids X into (batch, steps, num_hiddens) features."""
        # Scale the embeddings by sqrt(num_hiddens) before the positional
        # encoding is added; intended to help training converge.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
# Two-layer encoder over a 200-token vocabulary.
encoder = TransformerEncoder(
    vocab_size=200, key_size=24, query_size=24, value_size=24,
    num_hiddens=24, norm_shape=[100, 24], ffn_num_input=24,
    ffn_num_hiddens=48, ffn_num_outputs=24, num_heads=8, num_layers=2,
    dropout=0.3)
encoder.eval()
# Dummy token ids: 2 sequences of 100 steps.
X = torch.ones(2, 100, dtype=torch.long)
encoder(X, valid_lens).shape  # (batch, steps, num_hiddens)
torch.Size([2, 100, 24])
2.5创建解码器
# Decoder block
class DecoderBlock(nn.Module):
    """The i-th block of the Transformer decoder.

    Sublayers, each followed by AddNorm:
      1. masked multi-head self-attention over the decoder inputs,
      2. multi-head attention with the encoder outputs as keys/values,
      3. position-wise feed-forward network.

    Args:
        i: index of this block; selects this block's cache slot state[2][i].
        Remaining arguments configure the attention, AddNorm and FFN layers.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 num_heads, dropout, i, **kwargs):
        super().__init__(**kwargs)
        self.i = i
        # Sublayer 1: self-attention over the decoder inputs.
        self.attention1 = dltools.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = dltools.AddNorm(norm_shape, dropout)
        # Sublayer 2: encoder-decoder attention.
        self.attention2 = dltools.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = dltools.AddNorm(norm_shape, dropout)
        # Sublayer 3: position-wise FFN.
        self.ffn = dltools.PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, ffn_num_outputs)
        self.addnorm3 = dltools.AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """Run the block on X.

        Args:
            X: input of shape (batch_size, num_steps, num_hiddens).
            state: [enc_outputs, enc_valid_lens, cache]; cache[self.i] is None
                or the inputs this block has received so far.

        Returns:
            (output, state), with state[2][self.i] updated in place.
        """
        enc_outputs, enc_valid_lens = state[0], state[1]
        # Nothing cached yet for this block: the keys/values are just X.
        # This is the case whenever the state is freshly initialized (e.g.
        # when a whole target sequence is fed at once during training, or at
        # the first prediction step).
        if state[2][self.i] is None:
            key_values = X
        else:
            # Step-by-step prediction: append the current step's input X to
            # the inputs cached from earlier steps so self-attention sees them.
            key_values = torch.cat((state[2][self.i], X), dim=1)
        # Cache this block's inputs (not its outputs) for the next step.
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Causal mask for teacher forcing: dec_valid_lens has shape
            # (batch_size, num_steps) and every row is [1, 2, ..., num_steps],
            # so (per dltools' masked attention) query t sees keys 1..t only.
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            # Prediction: key_values only holds steps up to the current one,
            # so no extra masking is applied.
            dec_valid_lens = None
        # Masked self-attention, then residual + LayerNorm.
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention: decoder queries, encoder keys/values.
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
# Display the valid lengths (defined in the encoder-block demo) reused below.
valid_lens
tensor([3, 2])
# Decoder block with index 0, i.e. it owns slot 0 of the state cache.
decoder_blk = DecoderBlock(
    key_size=24, query_size=24, value_size=24, num_hiddens=24,
    norm_shape=[100, 24], ffn_num_input=24, ffn_num_hiddens=24,
    ffn_num_outputs=24, num_heads=8, dropout=0.5, i=0)
decoder_blk.eval()
X = torch.ones(2, 100, 24)
# state = [encoder outputs, encoder valid lengths, one empty cache slot]
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
result1, result2 = decoder_blk(X, state)
(result1.shape, result2[2][0].shape)
两者的形状都是：
torch.Size([2, 100, 24])
2.6transformer解码器部分
# Transformer decoder
class TransformerDecoder(dltools.AttentionDecoder):
    """Token embedding + positional encoding + a stack of DecoderBlocks,
    followed by a dense layer projecting to vocabulary logits.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 ffn_num_outputs, num_heads, num_layers, dropout, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_embedding = dltools.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                'block' + str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             ffn_num_outputs, num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Return [enc_outputs, enc_valid_lens, cache] with one empty cache
        slot per decoder block."""
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        """Decode token ids X; return (vocabulary logits, updated state)."""
        X = self.pos_embedding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # Row 0: self-attention weights per block;
        # row 1: encoder-decoder attention weights per block.
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        """Attention weights recorded during the last forward pass."""
        return self._attention_weights
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()
# ffn_num_input / ffn_num_outputs must match num_hiddens so the residual
# additions inside the blocks line up.
ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads = 32, 64, 32, 4
key_size, query_size, value_size = 32, 32, 32
# LayerNorm over the hidden dimension only (same value as the former [32]).
norm_shape = [num_hiddens]

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)

encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
    num_heads, num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
    num_heads, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
# Run predictions
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    # predict_seq2seq returns a 2-tuple whose first element is the translated
    # text; only that text is printed and scored.
    translation, _ = dltools.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation, fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000 i lost . => ("j'ai perdu .", []), bleu 1.000 he's calm . => ('il est calme .', []), bleu 1.000 i'm home . => ('je suis chez moi .', []), bleu 1.000
3.知识点个人理解