目的
基于transformers查看字典树是怎么生成的呢,输入字符串text是怎么在字典树中进行分割的,一起来看一下
参考链接
wikipedia
代码
def traversal_states_second(states, offsets, start, current, text, skip, reset, to_remove):# trie_pointer存在最后终止符,需要重置和存储结果在offsets中# lookahead要匹配最长的# 比如extra_id_1 vs extra_id_100这种,先找到extra_id_1,还是会继续匹配的# "[CLS]", "L" 匹配CLSfor lookstart, looktrie_pointer in states.items():# if lookstart in to_remove:# 额外处理一下的,处理那个bug# continueif lookstart > start:# This partial match is later, we can stop lookingbreak# 这个匹配是后面的结果,可以停止看了"""trie = Trie()trie.add("喜欢你")trie.add("欢你")trie.split("我喜欢你一起玩")# ['我', '喜欢你', '一起玩']# 当匹配到char“一”的时候,states=OrderedDict([(1, {'': 1}), (2, {'': 1})])# start=1时,trie_pointer = {'': 1} 则'' in trie_pointer为true# 开始重新遍历states# lookstart=2时,lookstart>start,这里对应的token "欢你" 已经是后面的匹配了,停止"""elif lookstart < start:# This partial match is earlier, the trie pointer# was already updated, so index is + 1# 这是前面的匹配,trie_pointer已经更新了,所以index=当前+1lookahead_index = current + 1end = current + 1"""trie = Trie()trie.add("喜欢你")trie.add("欢")trie.split("我喜欢你一起玩")# ['我', '喜欢你', '一起玩']# 当匹配到char“你”的时候,states=OrderedDict([(1, {'你': {'': 1}}), (2, {'': 1})])# start=2时,trie_pointer = {'': 1} 则'' in trie_pointer为true# 开始重新遍历states# lookstart=1时,lookstart<start# 虽然是在"欢"时找到了对应的终止符,但是前面的匹配"喜欢"也不能不管了,所以要看一下下面的char是否匹配"""else:# Here lookstart == start and# looktrie_pointer == trie_pointer# It wasn't updated yet so indices are current oneslookahead_index = current# 当前token对应的位置end = current# """trie = Trie()trie.add("喜欢你")trie.split("我喜欢你一起玩")# ['我', '喜欢你', '一起玩']# 当匹配到char“一”的时候,states=OrderedDict([(1, {'': 1})])# start=1时,trie_pointer = {'': 1} 则'' in trie_pointer为true# 开始重新遍历states# lookstart=1时,lookstart=start# 当前token并未验证是否在trie_pointer中,所以需要记录当前char对应的位置"""next_char = text[lookahead_index] if lookahead_index < len(text) else None# 下一个字符,避免extra_id_1 vs extra_id_100这种情况if "" in looktrie_pointer:# 找到终止符了,对应了lookstart == start这种情况start = lookstartend = lookahead_indexskip = lookahead_indexwhile next_char in looktrie_pointer:# 对应lookstart<start这种情况,要把后面的token相同的都找到# 这里start、end值都是有了,可能不更新,也可能更新多轮# 这里我没仔细找例子,先把整个流程跑通,重新找demolooktrie_pointer = looktrie_pointer[next_char]lookahead_index += 1# 后面的char位置if "" in looktrie_pointer:# 找到终止符了start = lookstartend = lookahead_indexskip = lookahead_indexif lookahead_index == len(text):# 我咋没把这种情况测试出来呢# End of stringprint ('End of string')breaknext_char = text[lookahead_index]# End lookahead# Storing and resettingoffsets.append(start)offsets.append(end)reset = Truereturn offsets, skip, reset, to_removedef traversal_states(states, current_char, offsets, current, text, skip, reset, to_remove):for start, trie_pointer in states.items():# 遍历statesif "" in trie_pointer:# 这个指针到结束位置了offsets, skip, reset, to_remove = traversal_states_second(states, offsets, start, current, text, skip, reset, to_remove)breakelif current_char in trie_pointer:# 当前字符在trie_pointer中,则更新指针,将states更新trie_pointer = trie_pointer[current_char]# 喜 欢 和 你states[start] = trie_pointer# 慢慢的找到这个单词的剩余字符else:# 当前字符不能匹配到trie_pointer,需要停止追踪这个匹配# 因为python迭代器的工作方式,不能直接在这个循环中执行这个操作to_remove.add(start)# 因为这里的原因,所以会造成['ab c', 'd']的情况return offsets, skip, reset, to_removeimport bisect
import itertools
import re
import unicodedata
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union, overloadclass Trie:"""Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one passLoose reference https://en.wikipedia.org/wiki/Trie"""def __init__(self, *args):self.data = {}self._tokens = set()self._termination_char = ""self.update(*args)def update(self, *args):"""Updates the Trie with new tokens provided as arguments.新单词作为参数更新Triedemo:trie = Trie(("hello","you"))trie.data# {'h': {'e': {'l': {'l': {'o': {'': 1}}}}}, 'y': {'o': {'u': {'': 1}}}}trie._tokens# {'hello', 'you'}Args:*args: Variable number of words to be added to the Trie.添加到Trie的单词"""for token in tuple(*args):self.add(token)def add(self, word: str):"""Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.The special key `""` in `self._termination_char` is used to represent termination.This function is idempotent, adding twice the same word will leave the trie unchangedExample:```python>>> trie = Trie()>>> trie.add("Hello 友達")>>> trie.data{"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}>>> trie.add("Hello")>>> trie.data{"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}```"""if not word:# ''或者为None# Prevent empty stringreturnself._tokens.add(word)# 要把word添加到self._tokens中ref = self.datafor char in word:ref[char] = ref.setdefault(char, {})# 这一步会更新self.dataref = ref[char]# 这一步不会更新self.dataref[self._termination_char] = 1def split(self, text: str) -> List[str]:"""Will look for the words added to the trie within `text`. Output is the original string splitted along theboundaries of the words found.This trie will match the longest possible word first !Example:```python>>> trie = Trie()>>> trie.split("[CLS] This is a extra_id_100")["[CLS] This is a extra_id_100"]>>> trie.add("[CLS]")>>> trie.add("extra_id_1")>>> trie.add("extra_id_100")>>> trie.split("[CLS] This is a extra_id_100")["[CLS]", " This is a ", "extra_id_100"]```"""# indexes are counted left of the chars index.# "hello", index 0, is left of h, index 1 is between h and e.# index 5 is right of the "o".# States are going to capture every possible start (indexes as above)# as keys, and have as values, a pointer to the position in the trie# where we're at. This is a partial match for now.# This enables to keep track of multiple matches while we're iterating# the string# If the trie contains, "blowing", and "lower" and we encounter the# string "blower", we need to split into ["b", "lower"].# This is where we need to keep track of multiple possible starts.# indexes是字符index的左计数states = OrderedDict()# offsets包含了每一个需要分割的分块,强制在位置0和位置len(text)进行分割offsets = [0]# This is used by the lookahead which needs to skip over# some text where the full match exceeded the place in the initial# for loop# 这是由lookahead使用的,它需要跳过一些完全匹配超出初始 for 循环中位置的文本skip = 0# Main loop, Giving this algorithm O(n) complexity# 主循环,O(n) 复杂的# current是当前位置,current_char是当前字符for current, current_char in enumerate(text):if skip and current < skip:continueto_remove = set()# 将停止匹配的状态,都放到to_remove,停止追踪reset = False# 当找到一个匹配,就丢掉一切,这是一个贪心算法,会匹配到第一个发现的tokenoffsets, skip,reset, to_remove = traversal_states(states, current_char, offsets, current, text, skip, reset, to_remove)# 遍历states,这一步会遍历每个statesif reset:# 找到匹配时,就丢掉一切,即重置statesstates = {}else:# 没有找到匹配,将停止匹配的状态从states中删除,停止追踪for start in to_remove:del states[start]# If this character is a starting character within the trie# start keeping track of this partial match.# 当前字符在self.data中,要开始追踪了# skip代表了最后一个找到的char对应的位置,因为在off_1中也会遍历char位置,所以设置了这样一个位置指针# 贪心算法,前面token已经把这个char使用上了,后面即使有一样的,也不会管了"""trie = Trie()trie.add("喜欢你一")trie.add("一起玩")trie.split("我喜欢你一起玩")# ['我', '喜欢你一', '起玩']"""# self.data = {'喜': {'欢': {'和': {'你': {'': 1}}}, '爱': {'': 1}}, '欢': {'喜': {'': 1}}}# '喜' in self.data == Trueif current >= skip and current_char in self.data:states[current] = self.data[current_char]# 根据current_char在self.data中找到对应的value,即剩余的字符串 # We have a cut at the end with states.for start, trie_pointer in states.items():if "" in trie_pointer:# This is a final match, we need to reset and# store the results in `offsets`.# 最后的匹配了,将结果保存在offsets中end = len(text)offsets.append(start)offsets.append(end)# Longest cut is always the one with lower start so the first# item so we need to break.break# 这里没有看懂return self.cut_text(text, offsets)def cut_text(self, text, offsets):# We have all the offsets now, we just need to do the actual splitting.# We need to eventually add the first part of the string and the eventual# last part.offsets.append(len(text))# 把最开始和最后的部分都要加上tokens = []start = 0for end in offsets:if start > end:print ("There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"" anyway.")continueelif start == end:# This might happen if there's a match at index 0# we're also preventing zero-width cuts in case of two# consecutive matchescontinuetokens.append(text[start:end])start = endreturn tokens
测试用例
新建
trie = Trie(("sea","see"))
trie._tokens# {'sea', 'see'}
trie.data# {'s': {'e': {'a': {'': 1}, 'e': {'': 1}}}}
# ->s->e->a
# ->s->e->e
添加
trie = Trie()
trie.add("Hello 友達")
trie.data
分割
trie = Trie()
trie.split("[CLS] This is a extra_id_100")
# ["[CLS] This is a extra_id_100"]trie = Trie()
trie.add("喜欢你")
trie.split("我喜欢你一起玩")trie = Trie()
trie.add("喜欢你一")
trie.add("一起玩")
trie.add("你一起玩")
trie.split("我喜欢你一起玩")
# ['我', '喜欢你一', '起玩']trie = Trie()
trie.add("abc")
trie.add("b")
trie.split("ab cd")
# ['ab c', 'd']
# 这里有个bug,对应了traversal_states_second中的这里# if lookstart in to_remove:# 额外处理一下的,处理那个bug# continue
额外说几句
查看这个代码的过程中,发现了一个bug,在git上提交了,还没有得到回复,可能是我自己想错了。
这里的代码逻辑比较散,需要对应多种特殊情况,希望能画出流程图来。