目录
1.导包
2.读取本地数据
3.定义函数:数据预处理
4.定义函数:词元化
5.统计每句话的长度的分布情况
6. 获取词汇表
7. 截断或者填充文本序列
8.将机器翻译的文本序列转换成小批量tensor
9.加载数据
10.知识点个人理解
1.导包
#导包
import os
import torch
import dltools
2.读取本地数据
#读取本地数据
with open('./fra-eng/fra.txt', 'r', encoding='utf-8') as f:raw_text = f.read() #一次读取所有数据print(raw_text[:75])
Go. Va ! Hi. Salut ! Run! Cours ! Run! Courez ! Who? Qui ? Wow! Ça alors !
3.定义函数:数据预处理
#数据预处理
def preprocess_nmt(text):#判断标点符号前面是否有空格def no_space(char, prev_char):return char in set(',.!?') and prev_char != ' '#替换识别不了的字符,替换不正常显示的空格,将大写字母变成小写text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()#在单词和标点之间插入空格out = [' '+ char if i>0 and no_space(char, text[i-1]) else char for i, char in enumerate(text)]return ''.join(out) #合并out#测试:数据预处理
text = preprocess_nmt(raw_text)
print(text[:80])
go . va ! hi . salut ! run ! cours ! run ! courez ! who ? qui ? wow ! ça alors !
4.定义函数:词元化
#定义函数:词元化
def tokenize_nmt(text, num_examples=None):"""text:传入的数据文本num_examples=None:样本数量为空,判断数据集中剩余的数据量是否满足一批所取的数据量"""source, target = [], []#以换行符号\n划分每一行for i, line in enumerate(text.split('\n')):#if num_examples 表示不是空,相当于 if num_examples != Noneif num_examples and i > num_examples:break#从每一行数据中 以空格键tab分割数据parts = line.split('\t') #将英文与对应的法语分割开if len(parts) == 2: #单词文本与标点符号两个元素source.append(parts[0].split(' ')) #用空格分割开单词文本与标点符号两个元素target.append(parts[1].split(' '))return source, target#测试词元化代码
source, target = tokenize_nmt(text)
source[:6], target[:6]
([['go', '.'],['hi', '.'],['run', '!'],['run', '!'],['who', '?'],['wow', '!']],[['va', '!'],['salut', '!'],['cours', '!'],['courez', '!'],['qui', '?'],['ça', 'alors', '!']])
5.统计每句话的长度的分布情况
#统计每句话的长度的分布情况
def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):dltools.set_figsize() #创建一个适当的画布_,_,patches = dltools.plt.hist([[len(l) for l in xlist], [len(l) for l in ylist]])dltools.plt.xlabel(xlabel) #添加x标签dltools.plt.ylabel(ylabel) #添加y标签for patch in patches[1].patches: #为patches[1]的柱体添加斜线patch.set_hatch('/')dltools.plt.legend(legend) #添加标注#测试代码:统计每句话的长度的分布情况
show_list_len_pair_hist(['source', 'target'], '# tokens per sequence', 'count', source, target)
6. 获取词汇表
#获取词汇表
src_vocab = dltools.Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])
len(src_vocab)
10012
7. 截断或者填充文本序列
def truncate_pad(line, num_steps, padding_token):"""line:传入的数据num_steps:子序列长度padding_token:需要填充的词元"""if len(line) > num_steps:return line[:num_steps] #太长就截断#太短就补充return line + [padding_token] * (num_steps - len(line)) #填充#测试
#source[0]表示英文单词
truncate_pad(src_vocab[source[0]], 10, src_vocab['<pad>'])
[47, 4, 1, 1, 1, 1, 1, 1, 1, 1]
8.将机器翻译的文本序列转换成小批量tensor
def build_array_nmt(lines, vocab, num_steps):#通过vocab拿到line的索引lines = [vocab[l] for l in lines]#每个序列结束之后+一个'eos'lines = [l + [vocab['eos']] for l in lines]#对每一行文本 截断或者填充文本序列,再转化为tensorarray = torch.tensor([truncate_pad(l, num_steps, vocab['<pad>']) for l in lines])#获取有效长度valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)return array, valid_len
9.加载数据
def load_data_nmt(batch_size, num_steps, num_examples=600):# 需要返回数据集的迭代器和词表text = preprocess_nmt(raw_text)source, target = tokenize_nmt(text, num_examples)src_vocab = dltools.Vocab(source, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])tgt_vocab = dltools.Vocab(target, min_freq=2, reserved_tokens=['<pad>', '<bos>', '<eos>'])src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)data_iter = dltools.load_array(data_arrays, batch_size)return data_iter, src_vocab, tgt_vocab
#测试代码
train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)for X, X_valid_len, Y, Y_valid_len in train_iter:print('X:', X.type(torch.int32))print('X的有效长度:', X_valid_len)print('Y:', Y.type(torch.int32))print('Y的有效长度:',Y_valid_len)break
X: tensor([[17, 20, 4, 0, 1, 1, 1, 1],[ 7, 84, 4, 0, 1, 1, 1, 1]], dtype=torch.int32) X的有效长度: tensor([4, 4]) Y: tensor([[ 11, 61, 144, 4, 0, 1, 1, 1],[ 6, 33, 17, 4, 0, 1, 1, 1]], dtype=torch.int32) Y的有效长度: tensor([5, 5])
10.知识点个人理解