1. 背景
本文将以Qwen2系列大模型为基础,讲解Qwen2模型技术架构及模型原理。
2. 编码
词表的设计可以影响训练的效率和下游任务的表现。Qwen系列模型采用的是tiktoken分词器,这是一种快速分词方法,该方法被使用在OpenAI系列模型中,tiktoen的核心逻辑同样是基于BPE算法,下面介绍下这两类算法。
2.1 BPE
把输入字符串分割为单词或子词(单词的部分)是自然语言处理过程中一项最基本的工作,这个过程是分词,存在较多算法,其中最为经典的是BPE(Byte Pair Encoding)算法。
- 核心代码
- bpe_train
函数定义:
def bpe_train(data: str, vocab_size: int, pat_str: str) -> dict[bytes, int]句子:data = "你好,qwen大模型"词表大小:vocab_size=275词切分正则:pat_str = r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
执行步骤:
step1:判断vocab_size < 2**8=256?是->报错,否继续。原因见notes1
step2:0-256,先填充到rank词表,此部分词表固定,下面截取部分内容
{b'\x00': 0, b'\x01': 1,...b'A': 65, b'B': 66,..., b'y': 121, b'z': 122,...,b'\xfe': 254, b'\xff': 255}
step3:把data分割为字节列表,即数据变成如下格式为[["你好"],[","],["qwen大模型"]]:
data = [[b'\xe4', b'\xbd', b'\xa0', b'\xe5', b'\xa5', b'\xbd'], [b'\xef', b'\xbc', b'\x8c'], [b'q', b'w', b'e', b'n', b'\xe5', b'\xa4', b'\xa7', b'\xe6', b'\xa8', b'\xa1', b'\xe5', b'\x9e', b'\x8b']]
step4:计算共同出现的字节对,然后把它从255开始往后计数增加到词表中,直到得到的词表rank等于vocab_size
共现字节对:
{
(b'\xe4', b'\xbd'): 1,
(b'\xbd', b'\xa0'): 1,
(b'q', b'w') :1,
(b'w', b'e') :1,
(b'e', b'n') :1,
...
}
更新后的rank词表:
{
b'\x00': 0,
b'\x01': 1,
...,
b'\xe4\xbd':256,
b'\xbd\xa0':257,
...
}
反复迭代,更新最终的rank词表:
{
b'\x00': 0,
b'\x01': 1,
...
b'\xe4\xbd': 256
b'\xe4\xbd\xa0': 257
b'\xe4\xbd\xa0\xe5': 258
b'\xe4\xbd\xa0\xe5\xa5': 259
b'\xe4\xbd\xa0\xe5\xa5\xbd': 260
b'\xef\xbc': 261
b'\xef\xbc\x8c': 262
b'qw': 263
b'qwe': 264
b'qwen': 265
b'qwen\xe5': 266
b'qwen\xe5\xa4': 267
b'qwen\xe5\xa4\xa7': 268
b'qwen\xe5\xa4\xa7\xe6': 269
b'qwen\xe5\xa4\xa7\xe6\xa8': 270
b'qwen\xe5\xa4\xa7\xe6\xa8\xa1': 271
b'qwen\xe5\xa4\xa7\xe6\xa8\xa1\xe5': 272
b'qwen\xe5\xa4\xa7\xe6\xa8\xa1\xe5\x9e': 273
b'qwen\xe5\xa4\xa7\xe6\xa8\xa1\xe5\x9e\x8b': 274
}step4:END
- 核心代码
- bpe_encode
函数定义:
def bpe_encode(mergeable_ranks: dict[bytes, int], input: bytes) -> list[int]:
步骤:
较为简单,不赘述。逻辑和训练极为相似,就是吧input的bytes挨个组成对后直接去查词表,如果有就记下来,最终成该字节对应的编码数字
实验:
```python
>data = "你好,qwen大模型"
>pat_str = r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
>mergeable_ranks = bpe_train(data=data, vocab_size=275, pat_str=pat_str)
>print(bpe_encode(mergeable_ranks, data.encode("utf-8")))
[260, 262, 274]
结果:
上述代码执行后,"你好,qwen大模型"的编码为[260, 262, 274],找到bpe_train中最终的词表
得到对应字节为:
b'\xe4\xbd\xa0\xe5\xa5\xbd': 260
b'qwen\xe5\xa4\xa7\xe6\xa8\xa1\xe5': 272
b'qwen\xe5\xa4\xa7\xe6\xa8\xa1\xe5\x9e\x8b': 274
再次解码:
b'\xe4\xbd\xa0\xe5\xa5\xbd'.decode("utf8")=你好
b'qwen\xe5\xa4\xa7\xe6\xa8\xa1\xe5'.decode("utf8") =error(编码器未经充分训练,错误忽略)
b'qwen\xe5\xa4\xa7\xe6\xa8\xa1\xe5\x9e\x8b'.decode("utf8")=qwen大模型
解释:
前面的训练只做了非常小的词表,并没有进行过滤和特殊处理,实际的编码比这个要复杂,会对特殊字符进行处理,此部分后面简单介绍。
要点解释
notes1:
- byte=8bit= [0-1][0-1][0-1][0-1][0-1][0-1][0-1][0-1],共计可以表示2**8=256个状态
- BPE编码是以字节为单位,所以,至少要有1个字节,也就是256个状态,每个状态可以认为是词表中某个具体的词或词的编码0-255
- 结论:至少1个字节->词表至少256
- 知识扩充- 众所周知,ASCII表和纯数字存在对应关系,英文属于拼音语言,a-zA-Z共计72个字符即可表示几乎所有文字,1byte就可以描述其基本字符,而中文,是象形文字,其编码较为复杂,往往一个中文汉字,就需要多字节。- UTF-8编码:一个中文汉字需要3个字节```python>zh = "你好">zh.encode("utf8")
b'\xe4\xbd\xa0\xe5\xa5\xbd'>zh[0].encode("utf8")
b'\xe4\xbd\xa0'>zh[1].encode("utf8")
b'\xe5\xa5\xbd'>list(zh.encode("utf8"))
[228, 189, 160, 229, 165, 189]>for v in list(s.encode("utf8")):... bin(v)
'0b11100100'
'0b10111101'
'0b10100000'
'0b11100101'
'0b10100101'
'0b10111101'```- gbk编码:一个中文汉字需要2个字节```python>zh = "你好">zh.encode("gbk")
b'\xc4\xe3\xba\xc3'>zh[0].encode("gbk")
b'\xc4\xe3'>zh[1].encode("gbk")
b'\xba\xc3'>list(zh.encode("gbk"))
[196, 227, 186, 195]>for v in list(s.encode("gbk")):... bin(v)
'0b11000100'
'0b11100011'
'0b10111010'
'0b11000011'```
2.2 tiktoken
前面介绍了BPE算法的核心逻辑,那么tiktoken作为qwen2大模型的编码算法,为什么采用它呢?其实最主要的原因就一个:快。
- 速度快
如上图,可以看到tiktoken相比较huggingface的同类开源的编码器快3-6倍。- Rust编程
tiktoken核心编码的部分,即train和encode都是通过Rust来实现的,Rust具备一些优势。- 性能强,没有编译器,直接就是机器码,其速度和C++差不多,甚至超越C++,比python快2-3倍
- 并行计算,更好的支持并行计算,而不受python等全局锁限制
- 算法优化
BPE算法有多种实现逻辑,前面讲的是最简单的naive算法,其实现有多种,tiktoken就对此部分进行了优化。- BPE核心算法优化
_byte_pair_merge函数实现了BPE的核心逻辑,通过维护一个parts向量来追踪可能合并的字节对。
- BPE核心算法优化
- Rust编程
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {// This is a vector of (start, rank).// The rank is of the pair starting at position start.let mut parts = Vec::with_capacity(piece.len() + 1);// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE// the way we currently do, this is equivalent. An easy way to break this would be to decouple// merge priority from token index or to prevent specific token merges.let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);for i in 0..piece.len() - 1 {let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);if rank < min_rank.0 {min_rank = (rank, i);}parts.push((i, rank));}parts.push((piece.len() - 1, Rank::MAX));parts.push((piece.len(), Rank::MAX));let get_rank = {#[inline(always)]|parts: &Vec<(usize, Rank)>, i: usize| {if (i + 3) < parts.len() {// Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted// parts[i + 1], see comment in the main loop.*ranks.get(&piece[parts[i].0..parts[i + 3].0]).unwrap_or(&Rank::MAX)} else {Rank::MAX}}};// If you have n parts and m merges, this does O(mn) work.// We could do something with a heap and do O(m log n) work.// n is often very small so considerations like cache-locality outweigh the algorithmic// complexity downsides of the `parts` vector.while min_rank.0 != Rank::MAX {let i = min_rank.1;// Update parts[i] and parts[i - 1] before removing parts[i + 1], since// `parts.remove(i + 1)` will thrash the cache.if i > 0 {parts[i - 1].1 = get_rank(&parts, i - 1);}parts[i].1 = get_rank(&parts, i);parts.remove(i + 1);min_rank = (Rank::MAX, usize::MAX);for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {if rank < min_rank.0 {min_rank = (rank, i);}}}parts
}
- 特殊字符处理
在_encode_native和_encode_unstable_native函数中,代码区分了普通字符和特殊字符的处理,允许在编码过程中包含用户定义的特殊字符。
函数定义:
_encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize)
核心代码:
step1:通过正则从开始索引匹配特殊字符
next_special = special_regex.find_from_pos(text, start_find).unwrap();
step2:挨个遍历text,找到符合定义的特殊字符跳出,不符合继续找;一旦找到next_special在allowed_special先是跳出,然后直接进行编码
```rust
let mut next_special;
let mut start_find = start;
loop {// Find the next allowed special token, if anynext_special = special_regex.find_from_pos(text, start_find).unwrap();match next_special {Some(m) => {if allowed_special.contains(&text[m.start()..m.end()]) {break;}start_find = m.start() + 1;}None => break,}
}
...
match next_special {Some(m) => {let piece = m.as_str();let token = self.special_tokens_encoder[piece];ret.push(token);start = m.end();last_piece_token_len = 0;}None => break,
}
- 缓存策略
核心代码:使用缓存返回一个正则表达式对象,该对象用于匹配传入的特殊token集合中的任意一个token
@functools.lru_cache(maxsize=128)
def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":inner = "|".join(regex.escape(token) for token in tokens)return regex.compile(f"({inner})")
- 哈希表
见前面_byte_pair_merge算法,使用hashmap提高了检索速度 - 支持扩展
tiktoken具备扩展能力