P6216 回文匹配

·   ·   ·   ·

  ·   ·


P6216 回文匹配

题目描述

对于一对字符串 $(s_1,s_2)$,若 $s_1$ 的长度为奇数的子串 $(l,r)$ 满足 $(l,r)$ 是回文的,那么 $s_1$ 的“分数”会增加 $s_2$ 在 $(l,r)$ 中出现的次数。

现在给出一对 $(s_1,s_2)$,请计算出 $s_1$ 的“分数”。

答案对 $2 ^ {32}$ 取模。

输入格式

第一行两个整数,$n,m$,表示 $s_1$ 的长度和 $s_2$ 的长度。

第二行两个字符串,$s_1,s_2$。

输出格式

一行一个整数,表示 $s_1$ 的分数。

样例 #1

样例输入 #1
10 2
ccbccbbcbb bc
样例输出 #1
4

样例 #2

样例输入 #2
20 2
cbcaacabcbacbbabacca ba
样例输出 #2
4

提示

【样例解释】

对于样例一:

子串 $(1,5)$ 中 $s_2$ 出现了一次,子串 $(2,4)$ 中 $s_2$ 出现了一次。

子串 $(7,9)$ 中 $s_2$ 出现了一次,子串 $(6,10)$ 中 $s_2$ 出现了一次。

【数据范围】

  • 对于 $100\%$ 的数据:$1 \le n,m \le 3 \times 10 ^ 6$,字符串中的字符都是小写字母。

分析

KMP 算法预处理出每一个字串 $s_2$ 的位置,Manacher 算法求出每个点的回文半径,然后我们考虑每个点的贡献

对于每一个点 $i$ 来说,其对最后总分数的贡献是以该点为中心的最长的回文串中 $s_2$ 出现的次数 $+$ 以该点为中心的次长的回文串中 $s_2$ 出现的次数 $+\cdots+$ 该点组成的长度为 $1$ 的回文串中 $s_2$ 出现的次数

如回文串 abcbcba,我们查找串 $s_2 = bc$,那么我们用 $p[i]$ 记录从 $s_1[0]$ 到 $s_1[i]$ 中 $s_2$ 出现的次数,那么对于最长的回文串 abcbcba ,只要用 $p[6] - p[0]$ 就可以获取这个串中 $s_2$ 出现的次数

对于上面的整个回文串来说,最后的结果就是 $p[6] - p[0] + p[5] - p[1] + p[4] - p[2] + p[3] - p[3]$,容易想到,可以用前缀和对其进行优化。所以,我们求出 $p[i]$ 的前缀和 $sum[i]$,对于每一个点来说,我们只要求出 $sum[i + d[i] - 1] - sum[i] + sum[i - 1] - sum[i - d[i]]$ 的值并求和即可

算法实现

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define ll long long

using namespace std;

const int N = 4e6 + 10;
char s[N], p[N];
ll nextval[N], t[N], plc[N], n, m, ans;

void sum()
{
    for (ll i = 2; i <= n; ++i)
        plc[i] += plc[i - 1];
}

inline void getNext()
{
    int j = 0;
    for (int i = 2; i <= n; ++i)
    {
        while (j && p[j + 1] != p[i])
            j = nextval[j];
        if (p[j + 1] == p[i])
            ++j;
        nextval[i] = j;
    }
}

inline void KMP()
{
    getNext();
    long long int j = 0;
    for (register long long int i = 1; i <= n; ++i)
    {
        while (j && p[j + 1] != s[i])
            j = nextval[j];
        if (p[j + 1] == s[i])
            ++j;
        if (j == m)
        {
            j = nextval[j];
            ++plc[i - m + 1];
        }
    }
}

void manacher()
{
    for (int i = 1, l = 1, r = 0; i <= n; ++i)
    {
        ll ml = l + r - i;
        t[i] = max(min(t[ml], 1LL * r - i + 1), 1LL);
        if (ml - t[ml] < l)
        {
            while (i - t[i] > 0 && i + t[i] <= n && s[i + t[i]] == s[i - t[i]])
                t[i]++;
            l = i - t[i] + 1, r = i + t[i] - 1;
        }
    }
}

void getAns()
{
    ll mid = (m + 1) >> 1;
    for (int i = mid; i <= n; ++i)
    {
        if ((t[i] << 1) - 1 < m)
            continue;
        ans += (plc[i - m + t[i]] - plc[i - m + mid - 1]);
        ans -= (plc[i - mid] - plc[i - t[i] - 1]);
    }
    ll MOD = pow(2, 32);
    cout << ans % MOD << endl;
}

void solve()
{
    scanf("%lld %lld", &n, &m);
    scanf("%s %s", s + 1, p + 1);
    s[0] = '#';
    KMP();
    sum();
    sum();
    manacher();
    getAns();
}

int main()
{
    int T = 1;
    // cin >> T;
    while (T--)
    {
        solve();
    }
    return 0;
}