結果

問題 No.263 Common Palindromes Extra
ユーザー HIR180HIR180
提出日時 2020-04-29 02:34:12
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 1,720 ms / 2,000 ms
コード長 3,836 bytes
コンパイル時間 1,167 ms
コンパイル使用メモリ 99,556 KB
実行使用メモリ 126,152 KB
最終ジャッジ日時 2023-08-17 14:12:48
合計ジャッジ時間 9,519 ms
ジャッジサーバーID
(参考情報)
judge15 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 65 ms
75,044 KB
testcase_01 AC 33 ms
72,228 KB
testcase_02 AC 34 ms
72,456 KB
testcase_03 AC 137 ms
79,276 KB
testcase_04 AC 1,616 ms
108,480 KB
testcase_05 AC 1,720 ms
111,756 KB
testcase_06 AC 70 ms
75,340 KB
testcase_07 AC 1,239 ms
122,512 KB
testcase_08 AC 1,328 ms
124,504 KB
testcase_09 AC 456 ms
126,152 KB
testcase_10 AC 461 ms
126,108 KB
testcase_11 AC 381 ms
120,556 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <algorithm>
#include <vector>
#include <functional>
#include <cstring>
using namespace std;
#define rep(i, n) for(int i=0;i<n;i++)

typedef long long ll;

int par[1000005],ran2[1000005],a[1000005],b[1000005];
int sa[1000005],lcp[1000005],ran[1000005],rad[2000005],tmp[1000005],rui[1000005];
int n,k;
string S="",s1,s2;
ll res,cur;

void init()
{
	for(int i=0;i<1000005;i++) par[i] = i;
	for(int i=0;i<1000005;i++) ran2[i] = a[i] = b[i] = 0;
}

int find(int x)
{
	if(x == par[x]) return x;
	else return par[x] = find(par[x]);
}

void unite(int x,int y)
{
	x = find(x);
	y = find(y);
	
	if(x == y) return;
	
	if(ran2[x] < ran2[y])
	{
		par[x] = y;
		cur += 1LL*a[x]*b[y]+1LL*a[y]*b[x];
		a[y] += a[x];
		b[y] += b[x];
	}
	else
	{
		par[y] = x;
		cur += 1LL*a[x]*b[y]+1LL*a[y]*b[x];
		a[x] += a[y];
		b[x] += b[y];
		if(ran2[x] == ran2[y]) ran2[x]++;
	}
}

bool same(int x,int y)
{
	return find(x) == find(y);
}

void add_a(int x)
{
	x = find(x);
	cur += b[x];
	a[x]++;
}

void add_b(int x)
{
	x = find(x);
	cur += a[x];
	b[x]++;
}

void manacher()
{
	string str(2*n+1,'#');
	for(int i=0;i<n;i++) str[i*2+1] = S[i];
	
	int i = 0, j = 0;
	
	for(;i<2*n+1;)
	{
		while(i-j >= 0 && i+j < 2*n+1 && str[i-j] == str[i+j]) j++;
		rad[i] = j;
		int k = 1;
		while(i-k >= 0 && rad[i]-k > rad[i-k])
		{
			rad[i+k] = rad[i-k];
			++k;
		}
		i += k;
		j = max(j-k,0);
	}
}


void construct_sa(string S){
	int n = S.size();
	rep(i, n){
		sa[i] = i;
	}
	sort(sa, sa+n, [&](int a, int b){
		return S[a] == S[b] ? a > b : S[a] < S[b];
	});
	for(int i=1;i<n;i++) ran[sa[i]] = ran[sa[i-1]] + (S[sa[i-1]] != S[sa[i]]);
	
	for(int k = 1; k < n; k <<= 1){
		int nxt = 0;
		memset(rui, 0, sizeof(rui));
		rep(i, n) rui[ran[i]+1]++;
		rep(i, n) rui[i+1] += rui[i];
		//empty
		for(int i=n-k;i<n;i++){
			tmp[rui[ran[i]]++] = i;
		}
		rep(i, n){
			if(sa[i] < k) continue;
			tmp[rui[ran[sa[i]-k]]++] = sa[i]-k;
		}
		rep(i, n) sa[i] = tmp[i]; tmp[sa[0]] = 0;
		rep(i, n-1) tmp[sa[i+1]] = tmp[sa[i]] + (ran[sa[i]] != ran[sa[i+1]] || max(sa[i], sa[i+1])+k >= n || ran[sa[i]+k] != ran[sa[i+1]+k]);
		rep(i, n) ran[i] = tmp[i];
	}
	//consider empty string
	for(int i=n;i>0;i--) sa[i] = sa[i-1]; sa[0] = n;
	rep(i,n+1) ran[sa[i]] = i;
}
void construct_lcp()
{
	int h = 0;
	lcp[sa[0]] = 0;
	
	for(int i=0;i<n;i++)
	{
		int j = sa[ran[i]-1];
		if(h) h--;
		while(i+h < n && j+h < n && S[i+h] == S[j+h]) h++;
		lcp[ran[i]-1] = h;
	}
}
vector<int>query[1000005];
vector<int>in[1000005];

int main()
{
    ios::sync_with_stdio(0);
	cin >> s1 >> s2;
	S = s1+"$"+s2;
	n = S.size();
	construct_sa(S);
	construct_lcp();
	manacher();
	
	//odd
	cur = 0;
	init();
	for(int i=0;i<n;i++)
	{
		query[lcp[i]].push_back(i);
	}
	
	for(int i=0;i<n;i++)
	{
		int f = rad[i*2+1]-1;
		in[f].push_back(i);
	}
	int m = n;
	if(m%2==0)
	{
		m--;
	}
	for(int i=n;i>(m+1)/2;i--)
	{
		for(int j=0;j<query[i].size();j++)
		{
			unite(sa[query[i][j]],sa[query[i][j]+1]);
		}
	}
	for(int i=m;i>=1;i-=2)
	{
		for(int j=0;j<in[i].size();j++)
		{
			if(in[i][j] < s1.size()) add_a(in[i][j]);
			else if(in[i][j] > s1.size()) add_b(in[i][j]);
		}
		for(int j=0;j<query[(i+1)/2].size();j++)
		{
			unite(sa[query[(i+1)/2][j]],sa[query[(i+1)/2][j]+1]);
		}
		res += cur;
	}
	
	//even
	cur = 0;
	init();
	for(int i=0;i<1000005;i++) in[i].clear();
	for(int i=0;i<=2*n;i+=2)
	{
		int f = rad[i]-1;
		in[f].push_back(i/2);
	}
	m = n;
	if(m%2==1)
	{
		m--;
	}
	for(int i=n;i>m/2;i--)
	{
		for(int j=0;j<query[i].size();j++)
		{
			unite(sa[query[i][j]],sa[query[i][j]+1]);
		}
	}
	for(int i=m;i>=2;i-=2)
	{
		for(int j=0;j<in[i].size();j++)
		{
			if(in[i][j] < s1.size())  add_a(in[i][j]);
			else if(in[i][j] > s1.size()) add_b(in[i][j]);
		}
		for(int j=0;j<query[i/2].size();j++)
		{
			unite(sa[query[i/2][j]],sa[query[i/2][j]+1]);
		}
		res += cur;
	}
	cout << res << endl;
}
0