题目描述
对于一对字符串 $(s_1,s_2)$,若 $s_1$ 的长度为奇数的子串 $(l,r)$ 满足 $(l,r)$ 是回文的,那么 $s_1$ 的“分数”会增加 $s_2$ 在 $(l,r)$ 中出现的次数。
现在给出一对 $(s_1,s_2)$,请计算出 $s_1$ 的“分数”。
答案对 $2 ^ {32}$ 取模。
输入格式
第一行两个整数,$n,m$,表示 $s_1$ 的长度和 $s_2$ 的长度。
第二行两个字符串,$s_1,s_2$。
输出格式
一行一个整数,表示 $s_1$ 的分数。
样例 #1
样例输入 #1
10 2
ccbccbbcbb bc
样例输出 #1
4
样例 #2
样例输入 #2
20 2
cbcaacabcbacbbabacca ba
样例输出 #2
4
提示
【样例解释】
对于样例一:
子串 $(1,5)$ 中 $s_2$ 出现了一次,子串 $(2,4)$ 中 $s_2$ 出现了一次。
子串 $(7,9)$ 中 $s_2$ 出现了一次,子串 $(6,10)$ 中 $s_2$ 出现了一次。
【数据范围】
- 对于 $100\%$ 的数据:$1 \le n,m \le 3 \times 10 ^ 6$,字符串中的字符都是小写字母。
分析
KMP 算法预处理出每一个字串 $s_2$ 的位置,Manacher 算法求出每个点的回文半径,然后我们考虑每个点的贡献
对于每一个点 $i$ 来说,其对最后总分数的贡献是以该点为中心的最长的回文串中 $s_2$ 出现的次数 $+$ 以该点为中心的次长的回文串中 $s_2$ 出现的次数 $+\cdots+$ 该点组成的长度为 $1$ 的回文串中 $s_2$ 出现的次数
如回文串 abcbcba
,我们查找串 $s_2 = bc$,那么我们用 $p[i]$ 记录从 $s_1[0]$ 到 $s_1[i]$ 中 $s_2$ 出现的次数,那么对于最长的回文串 abcbcba
,只要用 $p[6] - p[0]$ 就可以获取这个串中 $s_2$ 出现的次数
对于上面的整个回文串来说,最后的结果就是 $p[6] - p[0] + p[5] - p[1] + p[4] - p[2] + p[3] - p[3]$,容易想到,可以用前缀和对其进行优化。所以,我们求出 $p[i]$ 的前缀和 $sum[i]$,对于每一个点来说,我们只要求出 $sum[i + d[i] - 1] - sum[i] + sum[i - 1] - sum[i - d[i]]$ 的值并求和即可
算法实现
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#define ll long long
using namespace std;
const int N = 4e6 + 10;
char s[N], p[N];
ll nextval[N], t[N], plc[N], n, m, ans;
void sum()
{
for (ll i = 2; i <= n; ++i)
plc[i] += plc[i - 1];
}
inline void getNext()
{
int j = 0;
for (int i = 2; i <= n; ++i)
{
while (j && p[j + 1] != p[i])
j = nextval[j];
if (p[j + 1] == p[i])
++j;
nextval[i] = j;
}
}
inline void KMP()
{
getNext();
long long int j = 0;
for (register long long int i = 1; i <= n; ++i)
{
while (j && p[j + 1] != s[i])
j = nextval[j];
if (p[j + 1] == s[i])
++j;
if (j == m)
{
j = nextval[j];
++plc[i - m + 1];
}
}
}
void manacher()
{
for (int i = 1, l = 1, r = 0; i <= n; ++i)
{
ll ml = l + r - i;
t[i] = max(min(t[ml], 1LL * r - i + 1), 1LL);
if (ml - t[ml] < l)
{
while (i - t[i] > 0 && i + t[i] <= n && s[i + t[i]] == s[i - t[i]])
t[i]++;
l = i - t[i] + 1, r = i + t[i] - 1;
}
}
}
void getAns()
{
ll mid = (m + 1) >> 1;
for (int i = mid; i <= n; ++i)
{
if ((t[i] << 1) - 1 < m)
continue;
ans += (plc[i - m + t[i]] - plc[i - m + mid - 1]);
ans -= (plc[i - mid] - plc[i - t[i] - 1]);
}
ll MOD = pow(2, 32);
cout << ans % MOD << endl;
}
void solve()
{
scanf("%lld %lld", &n, &m);
scanf("%s %s", s + 1, p + 1);
s[0] = '#';
KMP();
sum();
sum();
manacher();
getAns();
}
int main()
{
int T = 1;
// cin >> T;
while (T--)
{
solve();
}
return 0;
}