KMP算法

·   ·   ·   ·

  ·   ·


KMP算法 (Knute-Morris-Pratt Algorithm) 是一种字符串匹配算法,由D.E.Knuth,J.H.Morris和V.R.Pratt提出,是一种可以在文本串 $s$ 中快速查找模式串 $p$ 的一种算法

暴力查找

要想知道KMP算法是如何减少字符串查找的复杂度,让我们先看基础的暴力方法是如何大量进行重复操作的。所谓暴力查找,就是逐个字符的进行匹配(比较 $s[i]$ 和 $p[j]$),如果当前字符匹配成功,就继续判断下一个字符
($ ++i,++j$);如果匹配失败,就将 $i$ 回溯,$j$ 改为 0($ i = i - j + 1,j = 0$)。代码如下:

// 暴力匹配
int i = 0, j = 0;
while(i < s.length())
{
    if(s[i] == p[i])
        ++i, ++j;
    else
        i = i - j + 1, j = 0;
    if(j == p.length())
    {
        cout << i - j << endl;
        i = i - j + 1;
        j = 0;
    }
}

假如我们有字符串 s = "acacaba",模式串 p = "acaba",进行暴力匹配的过程就是如下:

从头开始匹配,第一位都是 $a$,匹配成功

第 2~3 个字符也匹配成功,继续下一步

下一位,匹配失败,回溯($ i = 3 - 3 + 1 = 1,j = 0$)

匹配失败,继续尝试

下一位,匹配成功

直到结尾都匹配成功

设两个字符串的长度分别为 $n$,$m$。则暴力方法的时间复杂度最坏为 $O(nm)$,因为 $i$ 的回溯花去了太多时间,但如果不进行回溯,又将 $i$ 置为 0 ,很可能会缺漏

为了能让 $j$ 每次被赋予一个合适的值,我们引入 PMT(Partical Match Table,部分匹配表)

部分匹配表 PMT

$j$ 应该被赋予的值,只和模式串自身有关。每个模式串都对应一张 PMT,比如,"acacaba" 对应的 PMT 如下:

0

1

2

3

4

5

6

p

a

c

a

c

a

b

a

PMT

0

0

1

2

3

0

1

简单来说,PMT$[i]$ 就是从 $p[0]$ 开始往后数、同时从 $p[i]$ 往前数的相同的位数,也就是 $p$ 中真前缀和真后缀的交集最长能有多少位(必然少于 $p$ 的长度)

为什么 PMT 能用来确定 $j$ 指针的位置呢?让我们回到上面暴力算法第一次失去匹配时候的情形:

这时,虽然 'c''b' 没有匹配上,但我们可以保持 $i$ 指针不变,而将 $j$ 指针左移。因为我们注意到,"aca" 已经匹配成功了,它拥有一个前缀 "a",以及一个后缀 "a" ,所以我们可以将画线部分利用起来,变成下面这样:

实际上我们这时候正是令 j = PMT[j - 1],我们再看下面的这个例子:

发生失配,我们令 j = PMT[j - 1] (=3)

这次仍然不匹配,我们继续进行操作:

这次成功匹配,当然,我们并不是总是可以成功进行匹配,有可能 $j$ 指针一路减到 0 的时候,$s[i]$ 仍然不等于 $p[j]$,这时候我们不再移动 $j$ 指针

用代码实现以上过程:

for(int i = 0,j = 0; i < s.length(); ++i)
{
    while(j && s[i] != p[j]) // 当不匹配时,将 j 指针的位置改为 pmt[j - 1]
        j = pmt[j - 1];

    if(s[i] == p[i]) // 匹配时自增
        ++j;

    if(j == p.length()) // 如果走到了模式串 p 的最后一位,证明匹配成功
    {
        // some operations
        j = pmt[j - 1];
    } 
}

在许多文章中也会使用到 next 数组,即将 PMT 数组整体向右移一位(特别的,令 next[0] = -1),表示在那一位失配时应跳转到的索引。也就是按照 i -> next[i] -> next[next[i]] -> ... 的顺序跳转,原理和实现其实相差不多

计算PMT

所以,我们还要解决最后的一个问题,如何求出 PMT ?如果用暴力方法直接进行求解,时间复杂度是 $O(m^{2})$
一种简单且优雅的做法是,让模式串 $p$ 在 错开一位 后,自己和自己进行匹配(也就是用前缀去匹配后缀)
由于我们易得 pmt[0] = 0,而之后的每一位都将在匹配过程中记录 $j$ 得到

我们以模式串 ”ababcabaa” 为例

匹配失败,所以 pmt[1] = -1 + 1 = 0i 指针后移

接下来匹配成功,j 指针右移,可知 pmt[2] = 1,然后将两个指针都右移

依然匹配成功,$j$ 指针右移,pmt[3] = 2

下一位失配,因为前面我们已经算出来 pmt[2 - 1] 的值,所以我们也可以像匹配字符串的时候一样使用
pmt[2 - 1] = pmt[1] = 0,所以退回开头的位置

$j$ 指针已经到了开头,仍未匹配成功,所以不再移动,pmt[4] = j = 0

接下来也按这种方法操作:

在最后一位的时候失配,这次我们先令 j = pmt[j - 1] = 1

再次进行匹配,匹配成功,pmt[i] = j = 1

自此,我们通过一趟自我匹配,求出了 PMT,代码如下:

// pmt[0] = 0;
for(int i = 1, j = 0; i < p.length(); ++i)
{
    while(j && p[i] != p[j])
        j = pmt[i];
    if(p[i] == p[j])
        j++;
    pmt[i] = j;
}

Knuth常数优化

以上的算法被称为 MP 算法,在一般的题目中都足以使用,而 KMP 算法还有一个由 Knuth 提出的常数优化,不过一般不太用得上,在这也介绍一下

我们可以看出,中间的几次跳转毫无意义,我们明知道 da 是不能匹配的,但还是做了三次无用的操作。我们可以在计算 pmt 的时候做一些小改动来避免这样的情况

在上图这种情况下,我们匹配到这一步的时候应该令 pmt[i] = ++j = 2,但是可以发现,p[i + 1]p[j + 1] 是同样的字符 'a' 。也就是说,在稍后进行匹配的时候,如果指针 $j = 4$ 的时候失配( "ababa" 无法匹配),那么在指针 $j = 2$ 的时候也会失配( "aba" 也无法匹配,因为跳转过去后,还是使用 ‘a’ 做匹配)。

所以我们可以直接将路径压缩,让 pmt[i] = pmt[j] (pmt[2 - 1]) ,而不是 ++j,从而直接跳过指针 $j = 2$ 的情况

不过这样求出的数组已经不符合 PMT 的性质,所以我在实现中使用 nextval 代替

不论是 MP 算法还是 KMP 算法,其时间复杂度都为 $O(n + m)$,这是因为在这两个算法中,无论是 ++i 还是 ++j 操作,最多都只进行了 $n + m$ 次,虽然 $j$ 在该过程中有所减小,但 $j$ 在任何时刻都不可能减到 $-1$,所以减小的次数也不可能超过 $n + m + 1$

代码实现

const int N = 1e6 + 10;
string s, p;
int pmt[N];
int nextval[N];

void GetPMT()
{
    for(int i = 1, j = 0; i < p.length(); ++i)
    {
        while(j && p[i] != p[j])
            j = pmt[j - 1];
        if(p[i] == p[j])
            ++j;
        pmt[i] = j;
        // pmt[i] = p[i] == p[j] ? ++j : j;
    }
}

void GetNextval()
{
    for(int i = 1, j = 0; i < p.length(); ++i)
    {
        while(j && p[i] != p[j])
            j = nextval[j - 1];
        if(p[i] == p[j])
            if(p[i + 1] == p[j + 1])
                nextval[i] = nextval[j++];
            else
                nextval[i] = ++j;
        else
            nextval[i] = j;
        // nextval[i] = p[i] == p[j] ? (p[i + 1] == p[j + 1] ? nextval[j++] : ++j) : j;
    }
}

void KMP()
{
    for(int i = 0, j = 0; i < s.length(); ++i)
    {
        while(j && s[i] != p[j])
            j = nextval[j - 1];
        if(s[i] == p[j])
            ++j;
        if(j == p.length())
        {
            cout << i - j + 2 << endl;
            j = nextval[j - 1];
        }
    }
}

int main()
{
    cin >> s >> p;
    GetNextval();
    KMP();
    return 0;
}