結果

問題 No.263 Common Palindromes Extra
ユーザー heno239heno239
提出日時 2020-05-01 16:59:54
言語 C++14 (gcc 12.3.0 + boost 1.83.0)
結果
TLE  
実行時間 -
コード長 5,384 bytes
コンパイル時間 2,508 ms
コンパイル使用メモリ 145,824 KB
実行使用メモリ 100,788 KB
最終ジャッジ日時 2024-06-06 16:57:05
合計ジャッジ時間 8,942 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 19 ms 6,816 KB
testcase_01 AC 2 ms 5,376 KB
testcase_02 AC 2 ms 5,376 KB
testcase_03 AC 47 ms 8,020 KB
testcase_04 AC 231 ms 27,248 KB
testcase_05 AC 196 ms 26,892 KB
testcase_06 AC 25 ms 6,816 KB
testcase_07 AC 969 ms 58,428 KB
testcase_08 AC 909 ms 58,072 KB
testcase_09 TLE -
testcase_10 TLE -
testcase_11 AC 145 ms 26,788 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

// GCC code-generation hints.
// The pragma keyword is `optimize`. The former spelling `#pragma GCC optimization (...)`
// is not a GCC pragma: it is ignored, producing at most a -Wunknown-pragmas warning.
// As a result, the O3 / unroll-loops requests previously had no effect.
#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")
#include<iostream>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<algorithm>
#include<functional>
#include<iomanip>
#include<queue>
#include<ciso646>
#include<random>
#include<map>
#include<set>
#include<bitset>
#include<stack>
#include<unordered_map>
#include<utility>
#include<cassert>
#include<complex>
#include<numeric>
using namespace std;

//#define int long long

// Integer shorthands.
using ll = long long;
using ul = unsigned long long;
using ui = unsigned int;
using ld = long double;

// Pair shorthands.
using P = pair<int, int>;
using LP = pair<ll, ll>;
using LDP = pair<ld, ld>;

// Numeric constants. `mod` is the default modulus of mod_pow; INF is mod squared.
const ll mod = 786433;
const ll INF = mod * mod;
const ld eps = 1e-12;
const ld pi = acos(-1.0);

// Debugging pause: reads one more non-whitespace character from stdin.
#define stop char nyaa;cin>>nyaa;
// Loop shorthands:
//   rep  : i = 0 .. n-1        per  : i = n-1 .. 0
//   Rep  : i = sta .. n-1      rep1 : i = 1 .. n
//   per1 : i = n .. 1          Rep1 : i = sta .. n
#define rep(i,n) for(int i=0;i<n;i++)
#define per(i,n) for(int i=n-1;i>=0;i--)
#define Rep(i,sta,n) for(int i=sta;i<n;i++)
#define rep1(i,n) for(int i=1;i<=n;i++)
#define per1(i,n) for(int i=n;i>=1;i--)
#define Rep1(i,sta,n) for(int i=sta;i<=n;i++)
// Whole-container iterator range, e.g. sort(all(v)).
#define all(v) (v).begin(),(v).end()

// Computes x^n mod m by binary exponentiation and returns a value in [0, m).
// x may be any value, including negative or >= m. It is reduced into [0, m) first,
// so every product stays below m*m. This requires m > 0 and (m-1)^2 to fit in a
// long long (m below about 3.03e9).
// n <= 0 yields 1 % m (negative exponents are not supported).
ll mod_pow(ll x, ll n, ll m = mod) {
	x %= m;
	if (x < 0) x += m;          // normalise negative bases
	ll res = 1 % m;             // correct result (0) when m == 1
	while (n > 0) {
		if (n & 1) res = res * x % m;
		x = x * x % m;
		n >>= 1;
	}
	return res;
}
struct rolling_hash {
private:
	int sz;
	vector<LP> node;
	vector<LP> r;
	ll t = 999999937;
	ll m = 1000000009;
	ll m2 = 1000000007;
	ll invt, invt2;
public:
	rolling_hash(const string &s) {
		int n = s.length();
		sz = n;
		node.resize(n + 1); r.resize(n + 1);
		node[0] = { 0,0 };

		invt = mod_pow(t, m - 2, m);
		invt2 = mod_pow(t, m2 - 2, m2);
		ll a = 1;
		ll a2 = 1;
		rep(i, n) {
			//r[i] = a;
			int z = s[i] - 'a';
			node[i + 1].first = node[i].first + a * z;
			node[i + 1].first %= m;
			node[i + 1].second = node[i].second + a2 * z;
			node[i + 1].second %= m2;
			a = a * t%m;
			a2 = a2 * t%m2;
		}
		a = 1; a2 = 1;
		rep(i, n) {
			r[i].first = a; a = invt * a%m;
			r[i].second = a2; a2 = invt2 * a2%m2;
		}
	}
	P calc(int le, int len) {
		LP ret = { node[le + len].first - node[le].first,node[le + len].second - node[le].second };
		if (ret.first < 0)ret.first += m;
		if (ret.second < 0)ret.second += m2;
		ret.first = ret.first*r[le].first % m;
		ret.second = ret.second*r[le].second % m2;
		return ret;
	}
};

// Manacher's algorithm for odd-length palindromes.
// On return r has s.size() entries. r[i] is the largest k such that
// s[i-k+1 .. i+k-1] is a palindrome (the center counts, so r[i] >= 1).
void manacher(const string &s, vector<int> &r) {
	const int n = static_cast<int>(s.size());
	r.resize(n);
	int center = 0;
	int radius = 0;
	while (center < n) {
		// Extend the palindrome around `center` as far as the characters match.
		while (center - radius >= 0 && center + radius < n
			&& s[center - radius] == s[center + radius]) {
			++radius;
		}
		r[center] = radius;

		// Copy mirrored radii while the mirror's palindrome lies strictly inside.
		int step = 1;
		for (; center - step >= 0 && center + step < n && step + r[center - step] < radius; ++step) {
			r[center + step] = r[center - step];
		}
		center += step;
		radius -= step;
	}
}

vector<pair<P, ll>> ps(string &s) {
	int n = s.size();
	rolling_hash rs(s);

	map<P, int> used;

	vector<pair<P, P>> edges;

	//odd
	{
		vector<int> c;
		manacher(s, c);
		rep(i, n) {
			int le = i + 1 - c[i];
			int ri = i - 1 + c[i];
			P cur = rs.calc(le, ri - le + 1);
			bool pass = false;
			if (used.find(cur) != used.end())pass = true;
			used[cur]++;
			if (pass)continue;
			while (le + 1 <= ri - 1) {
				le++; ri--;
				P nex = rs.calc(le, ri - le + 1);
				edges.push_back({ cur,nex });
				if (used.find(nex) != used.end())break;
				used[nex] = 0;
				cur = nex;
			}
		}
	}
	//even
	{
		string ori; ori.push_back('#');
		rep(i, n) {
			ori.push_back(s[i]);
			ori.push_back('#');
		}
		vector<int> c;
		manacher(ori, c);
		rep(i, n) {
			int le = i - c[2 * i] / 2;
			int ri = i - 1 + c[2 * i] / 2;
			if (le > ri)continue;
			P cur = rs.calc(le, ri - le + 1);
			bool pass = false;
			if (used.find(cur) != used.end())pass = true;
			used[cur]++;
			if (pass)continue;
			while (le + 1 <= ri - 1) {
				le++; ri--;
				P nex = rs.calc(le, ri - le + 1);
				edges.push_back({ cur,nex });
				if (used.find(nex) != used.end())break;
				used[nex] = 0;
				cur = nex;
			}
		}
	}


	vector<P> exis;
	vector<ll> dp;
	for (pair<P, int> p : used) {
		exis.push_back(p.first);
		dp.push_back(p.second);
	}

	vector<int> cnt(exis.size());
	vector<int> nex(exis.size(), -1);


	for (pair<P, P> p : edges) {
		int l = lower_bound(all(exis), p.first) - exis.begin();
		int r = lower_bound(all(exis), p.second) - exis.begin();
		nex[l] = r;
		cnt[r]++;
	}


	queue<int> q;
	rep(i, exis.size())if (cnt[i] == 0)q.push(i);
	while (!q.empty()) {
		int id = q.front(); q.pop();
		if (nex[id] >= 0) {
			int to = nex[id];
			dp[to] += dp[id];
			cnt[to]--;
			if (cnt[to] == 0) {
				q.push(to);
			}
		}
	}

	vector<pair<P, ll>> res;
	rep(i, exis.size()) {
		res.push_back({ exis[i],dp[i] });
	}
	return res;
}
void solve() {
	string s, t; cin >> s >> t;
	vector<pair<P, ll>> dps = ps(s), dpt = ps(t);
	/*for (pair<P, ll> p: dps) {
	cout << p.first.first << " " << p.first.second << " " << p.second << "\n";
	}
	for (pair<P, ll> p : dpt) {
	cout << p.first.first << " " << p.first.second << " " << p.second << "\n";
	}*/
	ll ans = 0;
	int id1 = 0, id2 = 0;
	while (id1 < dps.size() && id2 < dpt.size()) {
		P mi = min(dps[id1].first, dpt[id2].first);
		ll s1 = 0, s2 = 0;
		while (id1 < dps.size() && dps[id1].first == mi) {
			s1 += dps[id1].second; id1++;
		}
		while (id2 < dpt.size() && dpt[id2].first == mi) {
			s2 += dpt[id2].second; id2++;
		}
		ans += s1 * s2;
	}
	cout << ans << "\n";
}

// Entry point: unsynchronised iostreams for fast input, then a single test case.
// The trailing `stop` debug pause was removed. It read one more character from
// stdin after the answer had been printed.
signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	solve();
	return 0;
}
0