結果

問題 No.1649 Manhattan Square
ユーザー Drice27149
提出日時 2021-08-13 22:59:51
言語 C++14
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 915 ms / 3,000 ms
コード長 2,934 bytes
コンパイル時間 882 ms
コンパイル使用メモリ 64,640 KB
実行使用メモリ 39,296 KB
最終ジャッジ日時 2024-10-03 22:11:18
合計ジャッジ時間 32,709 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 43
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cstdio>
#include <algorithm>
#include <map>
#include <cstring>
const long long mod = 998244353; 
int x[200005], y[200005];
int t[400005];
std::map<int,int> ids; int tot = 0;
struct Element {
	int x,y;
};
Element a[400005];
long long bit[8][400005];
int all;

void preWork(int n){
	std::sort(t+1,t+1+n);
	for(int i = 1; i <= n; i++){
		int u = t[i];
		if(!ids.count(u)) ids[u] = ++tot;
	}
}

void change(int p,long long v,int n,long long bit[]){
	while(p<=n){
		bit[p] += v;
		if(bit[p] >= mod) bit[p] -= mod;
		p += p&-p;
	}
}

long long ask(int p,long long bit[]){
	long long res = 0;
	while(p){
		res += bit[p];
		if(res >= mod) res -= mod;
		p -= p&-p;
	}
	return res;
}

void updateAns(int i,long long& ans, long long& x, long long& y, long long& xy,long long d){
	int u = ids[a[i].y];
	long long sx = ask(u,bit[0]);
	long long sy = ask(u,bit[1]);
	long long sxy = ask(u,bit[2]);
	long long cnt = ask(u,bit[3]);
	long long mxy = a[i].x*1ll*a[i].y%mod;
	long long mx = a[i].x, my = a[i].y;
	long long po = (cnt*mxy%mod - mx*sy%mod - my*sx%mod + sxy) % mod;
	if(po<0) po += mod;
	ans = (ans + 2ll*d*po)%mod;
		
	sx = x - sx; if(sx<0) sx += mod;
 	sy = y - sy; if(sy<0) sy += mod;
	sxy = xy - sxy; if(sxy<0) sxy += mod;
	if(d==1) cnt = (i-1-cnt);
	else cnt = (all-i-cnt);
	po = (cnt*mxy%mod - mx*sy%mod - my*sx%mod + sxy) % mod;
	ans = (ans - 2ll*d*po)%mod;
	if(ans < 0) ans += mod;
		
	change(u,mx,tot,bit[0]);
	change(u,my,tot,bit[1]);
	change(u,mxy,tot,bit[2]);
	change(u,1,tot,bit[3]);
		
	x = (x + mx)%mod;
	y = (y + my)%mod;
	xy = (xy + mxy)%mod;
}

long long power(long long a,long long b){
	long long res = 1;
	while(b){
		if(b&1) res = res*a%mod;
		a = a*a%mod;
		b /= 2;
	}
	return res;
}

int main(){
	int n;
	scanf("%d",&n);
	all = n;
	int size = 0;
	for(int i = 1; i <= n; i++){
		scanf("%d%d",&x[i],&y[i]);
		t[++size] = x[i];
		t[++size] = y[i];
		a[i].x = x[i], a[i].y = y[i];
	}
	preWork(size);
	std::sort(a+1,a+1+n,[](Element& u, Element& v){
		return u.x < v.x;
	});
	// 0: x, 1: y, 2: x*y
	// 3: cnt
	long long ans = 0;
	long long x = 0, y = 0, xy = 0;
	for(int i = 1; i <= n; i++){
		updateAns(i,ans,x,y,xy,1ll);
	}
	//printf("ans = %lld\n",ans);
	for(int i = 0; i < 4; i++) memset(bit[i],0,sizeof(bit[i]));
	x = 0, y = 0, xy = 0;
	long long sxx = 0, syy = 0;
	for(int i = n; i >= 1; i--){
		updateAns(i,ans,x,y,xy,-1ll);
		sxx = (sxx + a[i].x*1ll*a[i].x)%mod;
		syy = (syy + a[i].y*1ll*a[i].y)%mod;
		//printf("i = %d, ans = %lld\n",i,ans);
	}
	//printf("ans = %lld\n",ans);
	//printf("x = %lld, y = %lld, xy = %lld\n",x,y,xy);
	for(int i = 1; i <= n; i++){
		long long xx = (a[i].x*1ll*a[i].x)%mod;
		long long cnt = n;
		ans = (ans + xx*cnt%mod + sxx - 2ll*a[i].x*x)%mod;
		if(ans < 0) ans += mod;
		long long yy = (a[i].y*1ll*a[i].y)%mod;
		ans = (ans + yy*cnt%mod + syy - 2ll*a[i].y*y)%mod;
		if(ans < 0) ans += mod;
	}
	long long inv = power(2,mod-2);
	printf("%lld\n",ans*inv%mod);
	return 0;
}
0