【leetcode】最长公共子路径问题
滚动hash
滚动哈希(rolling hash)也叫 Rabin-Karp 字符串哈希算法,它是将某个字符串看成某个进制下的整数,并将其对应的十进制整数作为hash值。
滚动hash算法的推导
假设有一个长度为n的数组a[0],a[1],a[2],…a[n-1],数组中的最大值为ma, 我们选取进制k满足k>ma,将数组a看成是n位k进制整数,那么其对应的10进制整数为:
∑ i = 0 n − 1 a [ i ] ∗ k n − 1 − i \sum_{i=0}^{n-1} a[i] * k^{n-1-i} i=0∑n−1a[i]∗kn−1−i
这样一来,在子数组长度固定的前提下,给定进制 k,子数组与其十进制值满足「一一对应」的关系,即不会有两个不同的子数组,它们的十进制值相同。因此滚动哈希得到的哈希值是可以表示原子数组的。
滚动哈希的一大优势在于,如果我们需要求出一个数组中长度为 len 的所有子数组的哈希值,需要的时间仅为线性,即如果我们已经计算出数组中以 j 开始的子数组的哈希值:
h a s h ( j ) = ∑ i = 0 l e n − 1 a [ j + i ] ∗ k l e n − 1 − i hash(j) = \sum_{i=0}^{len-1} a[j+i] * k^{len-1-i} hash(j)=i=0∑len−1a[j+i]∗klen−1−i
那么要计算以 j+1 开始的子数组的哈希值,我们通过公式推导:
h a s h ( j + 1 ) = ∑ i = 0 l e n − 1 a [ j + 1 + i ] ∗ k l e n − 1 − i = ∑ i = 1 l e n a [ j + i ] ∗ k l e n − i = k ( ∑ i = 1 l e n a [ j + i ] ∗ k l e n − 1 − i ) = k ( h a s h ( j ) − a [ j ] ∗ k l e n − 1 + a [ j + l e n ] ∗ k − 1 ) = k ∗ h a s h ( j ) − a [ j ] ∗ k l e n + a [ j + l e n ] \begin{aligned} hash(j+1) &= \sum_{i=0}^{len-1} a[j+1+i] * k^{len-1-i} \\ &= \sum_{i=1}^{len} a[j+i]*k^{len-i} \\ &= k(\sum_{i=1}^{len} a[j+i]*k^{len-1-i}) \\ &= k(hash(j) - a[j]*k^{len-1} + a[j+len]*k^{-1}) \\ &= k*hash(j) - a[j]*k^{len} + a[j+len] \end{aligned} hash(j+1)=i=0∑len−1a[j+1+i]∗klen−1−i=i=1∑lena[j+i]∗klen−i=k(i=1∑lena[j+i]∗klen−1−i)=k(hash(j)−a[j]∗klen−1+a[j+len]∗k−1)=k∗hash(j)−a[j]∗klen+a[j+len]
就可以在 ϕ ( 1 ) \phi(1) ϕ(1)的时间内得到该值。
利用滚动hash算法计算最长公共子路径的代码示例如下:
上述代码的执行效率较低,以下代码通过二分法优化,可以有效降低代码的时间复杂度:
def longest_common_subpath_2(n: int, paths: List[List[int]]) -> int:mod = (10 ** 9 + 7) * (10 ** 9 + 9)base = 10 ** 6 + 3# get min len of pathsmin_len = len(min(paths, key=lambda x: len(x)))def check(x: int) -> bool:k = pow(base, x, mod)hash_values = defaultdict(int)for path in paths:cnt = Counter()hash_value = 0for i in range(x):hash_value = (hash_value * base + path[i]) % modcnt[hash_value] += 1hash_values[hash_value] += 1for i in range(x, len(path)):hash_value = (hash_value * base + path[i] - path[i - x] * k) % modif hash_value not in cnt:cnt[hash_value] += 1hash_values[hash_value] += 1return max(hash_values.values(), default=0) == len(paths)l, r, ans = 1, min_len, 0while l <= r:mid = (l + r) >> 1if check(mid):ans = midl = mid + 1else:r = mid - 1return ans