結果
問題 | No.263 Common Palindromes Extra |
ユーザー | HIR180 |
提出日時 | 2020-04-29 02:38:38 |
言語 | C++17 (gcc 12.3.0 + boost 1.83.0) |
結果 |
AC
|
実行時間 | 1,729 ms / 2,000 ms |
コード長 | 3,867 bytes |
コンパイル時間 | 1,345 ms |
コンパイル使用メモリ | 106,504 KB |
実行使用メモリ | 126,288 KB |
最終ジャッジ日時 | 2024-05-04 21:02:11 |
合計ジャッジ時間 | 6,163 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 89 ms
73,160 KB |
testcase_01 | AC | 59 ms
69,888 KB |
testcase_02 | AC | 60 ms
70,016 KB |
testcase_03 | AC | 161 ms
78,020 KB |
testcase_04 | AC | 1,712 ms
108,748 KB |
testcase_05 | AC | 1,729 ms
111,708 KB |
testcase_06 | AC | 99 ms
73,728 KB |
testcase_07 | AC | 1,314 ms
122,704 KB |
testcase_08 | AC | 1,402 ms
124,668 KB |
testcase_09 | AC | 461 ms
126,288 KB |
testcase_10 | AC | 466 ms
126,288 KB |
testcase_11 | AC | 396 ms
120,724 KB |
ソースコード
#pragma GCC optimize("Ofast") #include <iostream> #include <algorithm> #include <vector> #include <functional> #include <cstring> using namespace std; #define rep(i, n) for(int i=0;i<n;i++) typedef long long ll; int par[1000005],ran2[1000005],a[1000005],b[1000005]; int sa[1000005],lcp[1000005],ran[1000005],rad[2000005],tmp[1000005],rui[1000005]; int n,k; string S="",s1,s2; ll res,cur; void init() { for(int i=0;i<1000005;i++) par[i] = i; for(int i=0;i<1000005;i++) ran2[i] = a[i] = b[i] = 0; } int find(int x) { if(x == par[x]) return x; else return par[x] = find(par[x]); } void unite(int x,int y) { x = find(x); y = find(y); if(x == y) return; if(ran2[x] < ran2[y]) { par[x] = y; cur += 1LL*a[x]*b[y]+1LL*a[y]*b[x]; a[y] += a[x]; b[y] += b[x]; } else { par[y] = x; cur += 1LL*a[x]*b[y]+1LL*a[y]*b[x]; a[x] += a[y]; b[x] += b[y]; if(ran2[x] == ran2[y]) ran2[x]++; } } bool same(int x,int y) { return find(x) == find(y); } void add_a(int x) { x = find(x); cur += b[x]; a[x]++; } void add_b(int x) { x = find(x); cur += a[x]; b[x]++; } void manacher() { string str(2*n+1,'#'); for(int i=0;i<n;i++) str[i*2+1] = S[i]; int i = 0, j = 0; for(;i<2*n+1;) { while(i-j >= 0 && i+j < 2*n+1 && str[i-j] == str[i+j]) j++; rad[i] = j; int k = 1; while(i-k >= 0 && rad[i]-k > rad[i-k]) { rad[i+k] = rad[i-k]; ++k; } i += k; j = max(j-k,0); } } void construct_sa(string S){ int n = S.size(); rep(i, n){ sa[i] = i; } sort(sa, sa+n, [&](int a, int b){ return S[a] == S[b] ? a > b : S[a] < S[b]; }); for(int i=1;i<n;i++) ran[sa[i]] = ran[sa[i-1]] + (S[sa[i-1]] != S[sa[i]]); for(int k = 1; k < n; k <<= 1){ int nxt = 0; memset(rui, 0, sizeof(rui)); rep(i, n) rui[ran[i]+1]++; rep(i, n) rui[i+1] += rui[i]; //empty for(int i=n-k;i<n;i++){ tmp[rui[ran[i]]++] = i; } rep(i, n){ if(sa[i] < k) continue; tmp[rui[ran[sa[i]-k]]++] = sa[i]-k; } rep(i, n) sa[i] = tmp[i]; tmp[sa[0]] = 0; rep(i, n-1) tmp[sa[i+1]] = tmp[sa[i]] + (ran[sa[i]] != ran[sa[i+1]] || max(sa[i], sa[i+1])+k >= n || ran[sa[i]+k] != ran[sa[i+1]+k]); rep(i, n) ran[i] = tmp[i]; } //consider empty string for(int i=n;i>0;i--) sa[i] = sa[i-1]; sa[0] = n; rep(i,n+1) ran[sa[i]] = i; } void construct_lcp() { int h = 0; lcp[sa[0]] = 0; for(int i=0;i<n;i++) { int j = sa[ran[i]-1]; if(h) h--; while(i+h < n && j+h < n && S[i+h] == S[j+h]) h++; lcp[ran[i]-1] = h; } } vector<int>query[1000005]; vector<int>in[1000005]; int main() { ios::sync_with_stdio(0); cin >> s1 >> s2; S = s1+"$"+s2; n = S.size(); construct_sa(S); construct_lcp(); manacher(); //odd cur = 0; init(); for(int i=0;i<n;i++) { query[lcp[i]].push_back(i); } for(int i=0;i<n;i++) { int f = rad[i*2+1]-1; in[f].push_back(i); } int m = n; if(m%2==0) { m--; } for(int i=n;i>(m+1)/2;i--) { for(int j=0;j<query[i].size();j++) { unite(sa[query[i][j]],sa[query[i][j]+1]); } } for(int i=m;i>=1;i-=2) { for(int j=0;j<in[i].size();j++) { if(in[i][j] < s1.size()) add_a(in[i][j]); else if(in[i][j] > s1.size()) add_b(in[i][j]); } for(int j=0;j<query[(i+1)/2].size();j++) { unite(sa[query[(i+1)/2][j]],sa[query[(i+1)/2][j]+1]); } res += cur; } //even cur = 0; init(); for(int i=0;i<1000005;i++) in[i].clear(); for(int i=0;i<=2*n;i+=2) { int f = rad[i]-1; in[f].push_back(i/2); } m = n; if(m%2==1) { m--; } for(int i=n;i>m/2;i--) { for(int j=0;j<query[i].size();j++) { unite(sa[query[i][j]],sa[query[i][j]+1]); } } for(int i=m;i>=2;i-=2) { for(int j=0;j<in[i].size();j++) { if(in[i][j] < s1.size()) add_a(in[i][j]); else if(in[i][j] > s1.size()) add_b(in[i][j]); } for(int j=0;j<query[i/2].size();j++) { unite(sa[query[i/2][j]],sa[query[i/2][j]+1]); } res += cur; } cout << res << endl; }