結果

問題 No.1796 木上のクーロン
ユーザー 👑 NachiaNachia
提出日時 2021-12-19 01:08:15
言語 Java21
(openjdk 21)
結果
AC  
実行時間 6,481 ms / 10,000 ms
コード長 5,234 bytes
コンパイル時間 2,936 ms
コンパイル使用メモリ 84,804 KB
実行使用メモリ 125,840 KB
最終ジャッジ日時 2023-10-20 03:04:14
合計ジャッジ時間 68,198 ms
ジャッジサーバーID
(参考情報)
judge12 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 141 ms
57,596 KB
testcase_01 AC 142 ms
57,312 KB
testcase_02 AC 147 ms
57,436 KB
testcase_03 AC 140 ms
57,564 KB
testcase_04 AC 142 ms
57,716 KB
testcase_05 AC 140 ms
57,768 KB
testcase_06 AC 143 ms
57,248 KB
testcase_07 AC 139 ms
57,592 KB
testcase_08 AC 227 ms
58,172 KB
testcase_09 AC 242 ms
58,428 KB
testcase_10 AC 245 ms
60,216 KB
testcase_11 AC 242 ms
58,624 KB
testcase_12 AC 246 ms
60,312 KB
testcase_13 AC 225 ms
58,340 KB
testcase_14 AC 238 ms
58,444 KB
testcase_15 AC 266 ms
60,468 KB
testcase_16 AC 261 ms
60,640 KB
testcase_17 AC 249 ms
60,172 KB
testcase_18 AC 248 ms
60,232 KB
testcase_19 AC 251 ms
60,308 KB
testcase_20 AC 1,615 ms
79,924 KB
testcase_21 AC 1,593 ms
80,148 KB
testcase_22 AC 2,244 ms
91,812 KB
testcase_23 AC 1,979 ms
91,156 KB
testcase_24 AC 2,560 ms
115,528 KB
testcase_25 AC 2,539 ms
115,072 KB
testcase_26 AC 3,081 ms
112,900 KB
testcase_27 AC 3,051 ms
113,112 KB
testcase_28 AC 6,481 ms
122,060 KB
testcase_29 AC 6,417 ms
111,172 KB
testcase_30 AC 2,766 ms
112,448 KB
testcase_31 AC 2,623 ms
111,556 KB
testcase_32 AC 3,314 ms
112,228 KB
testcase_33 AC 4,489 ms
125,840 KB
testcase_34 AC 4,356 ms
123,524 KB
testcase_35 AC 4,829 ms
112,340 KB
testcase_36 AC 4,772 ms
113,280 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.InputMismatchException;
import java.util.NoSuchElementException;
import java.util.Scanner;

public class Main {

	/** NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1. */
	static final int MOD = 998244353;
	/** Primitive root used for the forward NTT. */
	static final long NTTg = 3;

	/**
	 * Computes {@code a^i mod MOD} by recursive binary exponentiation.
	 *
	 * @param a base; expected in {@code [0, MOD)} so that {@code a*a} fits in a long
	 * @param i non-negative exponent
	 * @return {@code a^i mod MOD}
	 */
	static long powm(long a, long i) {
		if(i == 0) return 1;
		long r = powm(a*a%MOD,i/2);
		if(i%2 == 1) r=r*a%MOD;
		return r;
	}

	/** Inverse of the primitive root, used to run the inverse transform. */
	static final long invNTTg = powm(NTTg, MOD-2);

	/**
	 * In-place iterative radix-2 number-theoretic transform modulo {@link #MOD}.
	 *
	 * <p>The transform is unnormalized: running it with {@code invNTTg} after {@code NTTg}
	 * yields the input multiplied by {@code A.length}. Callers must divide by the length
	 * themselves (this file folds that factor into {@code nttC}).
	 *
	 * @param A coefficients in {@code [0, MOD)}; length must be a power of two dividing 2^23
	 * @param g {@code NTTg} for the forward transform, {@code invNTTg} for the inverse
	 */
	static void NTT(long[] A, long g) {
		int N = A.length;
		// Bit-reversal permutation: i tracks the bit-reversed counterpart of j.
		for(int i=0, j=0; j<N; j++) {
			if(i < j) {
				long b = A[i];
				A[i] = A[j];
				A[j] = b;
			}
			for(int k=N>>1; k>(i^=k); k>>=1);
		}
		// Butterflies over blocks of size 2i with twiddle q = g^((MOD-1)/(2i)).
		for(int i=1; i<N; i<<=1) {
			long q = powm(g,(MOD-1)/i/2), qj = 1;
			for(int j=0; j<i; j++) {
				for(int k=j; k<N; k+=i*2) {
					long l = A[k];
					long r = A[k+i] * qj % MOD;
					A[k] = l+r;
					if(A[k] >= MOD) A[k] -= MOD;
					A[k+i] = l+MOD-r;
					if(A[k+i] >= MOD) A[k+i] -= MOD;
				}
				qj = qj * q % MOD;
			}
		}
	}

	/** Number of vertices. */
	int N;
	/** Charge of each vertex, normalized into {@code [0, MOD)}. */
	long Q[];
	/** Adjacency lists (0-based vertex ids). */
	int E[][];

	/** Centroid-decomposition depth of each vertex (the top-level centroid has depth 0). */
	int cdep[];
	/** Parent centroid of each vertex in the centroid tree, or -1 for the root centroid. */
	int cp[];
	/** Centroids in the order they were chosen (BFS order over the centroid tree). */
	int cbfs[];

	/**
	 * Builds the centroid decomposition of the tree, filling {@code cdep}, {@code cp} and
	 * {@code cbfs}.
	 *
	 * <p>Subtree sizes are computed once for the tree rooted at vertex 0. Afterwards, each
	 * pending component is represented by a vertex {@code s} that is the root of that
	 * component, with {@code Z[s]} equal to the component size. Walking toward a neighbour
	 * holding more than half of the component re-roots at that neighbour in O(1):
	 * {@code Z[s] -= Z[nx]; Z[nx] += Z[s];}. When no such neighbour exists, {@code s} is the
	 * centroid. Setting {@code Z[s] = 0} removes it, and each remaining neighbour becomes
	 * the root of its own sub-component.
	 */
	void centroid_decomposition() {
		cdep = new int[N];
		cp = new int[N]; for(int i=0; i<N; i++) cp[i] = -1;
		cbfs = new int[N];
		int P[] = new int[N]; for(int i=0; i<N; i++) P[i] = -1;
		int I[] = new int[N];
		int Iidx = 0; I[Iidx++] = 0;
		// BFS from vertex 0. I doubles as the queue; head chases the tail index Iidx.
		for(int head = 0; head < Iidx; head++) {
			int p = I[head];
			for(int e : E[p]) if(P[p] != e) { P[e] = p; I[Iidx++] = e; }
		}
		int Z[] = new int[N]; for(int i=0; i<N; i++) Z[i] = 1;
		for(int i=N-1; i>=1; i--) Z[P[I[i]]] += Z[I[i]];

		// I is reused as the queue of component roots; cdP holds each one's parent centroid.
		int cdP[] = new int[N];
		Iidx = 0;
		cdP[Iidx] = -1;
		I[Iidx++] = 0;
		for(int i=0; i<N; i++) {
			int s = I[i];
			int par = cdP[i];
			while(true) {
				int nx = -1;
				// Removed centroids have Z == 0, so they are never selected here.
				for(int e : E[s]) if(Z[e]*2 > Z[s]) nx = e;
				if(nx == -1) break;
				Z[s] -= Z[nx]; Z[nx] += Z[s];
				s = nx;
			}
			cbfs[i] = s;
			Z[s] = 0;
			if(par != -1) { cdep[s] = cdep[par]+1; cp[s] = par; }
			for(int e : E[s]) if(Z[e] != 0) {
				cdP[Iidx] = s;
				I[Iidx++] = e;
			}
		}
	}



	/** (N!)^2 mod MOD; makes every weight (N!)^2 / (dist+1)^2 an integer for dist < N. */
	long k0;
	/** log2 of the largest transform length that {@code nttC} provides. */
	int max_ntt_size_log;
	/** {@code 1 << max_ntt_size_log}. */
	int max_ntt_size;
	/** Modular inverses of 1..max_ntt_size. */
	long inv_mod[];
	/** Distance kernel: {@code C[i] = k0 / (i+1)^2 mod MOD}. */
	long C[];
	/** {@code nttC[d]} is the forward NTT of {@code C[0..2^d)} pre-scaled by {@code 1/2^d}. */
	long nttC[][];

	/**
	 * Minimal buffered tokenizer for whitespace-separated, optionally negative decimal ints.
	 * Used instead of {@link Scanner}, whose regex-based parsing is slow for ~3N tokens.
	 */
	private static final class IntReader {
		private final InputStream in;
		private final byte[] buf = new byte[1 << 16];
		private int len = 0;
		private int ptr = 0;

		IntReader(InputStream in) {
			this.in = in;
		}

		/** Returns the next byte as 0..255, or -1 at end of stream. */
		private int readByte() {
			if(ptr == len) {
				try {
					len = in.read(buf, 0, buf.length);
				} catch(IOException e) {
					throw new UncheckedIOException(e);
				}
				ptr = 0;
				if(len <= 0) {
					len = 0;
					return -1;
				}
			}
			return buf[ptr++] & 0xFF;
		}

		/**
		 * Reads the next integer token.
		 *
		 * @throws NoSuchElementException if the stream ends before a token
		 * @throws InputMismatchException if the token does not start with a digit
		 */
		int nextInt() {
			int c = readByte();
			while(c != -1 && c <= ' ') c = readByte();
			if(c == -1) throw new NoSuchElementException("unexpected end of input");
			boolean negative = c == '-';
			if(negative) c = readByte();
			if(c < '0' || c > '9') throw new InputMismatchException("expected an integer");
			int value = 0;
			while(c >= '0' && c <= '9') {
				value = value * 10 + (c - '0');
				c = readByte();
			}
			return negative ? -value : value;
		}
	}

	/**
	 * Reads N, the N charges, and N-1 edges (1-based endpoints) from standard input.
	 * Builds {@code Q} and the adjacency lists {@code E}.
	 */
	void input() {
		IntReader in = new IntReader(System.in);
		N = in.nextInt();
		Q = new long[N];
		// Normalize into [0, MOD) so later "(x + Q) % MOD" sums can never go negative.
		for(int i=0; i<N; i++) Q[i] = Math.floorMod(in.nextInt(), MOD);
		{
			// Two passes: count degrees first so each adjacency list is sized exactly.
			int edge_idx[] = new int[N];
			int edge_u[] = new int[N-1];
			int edge_v[] = new int[N-1];
			for(int i=0; i<N-1; i++) {
				int u = in.nextInt() - 1;
				int v = in.nextInt() - 1;
				edge_idx[u]++;
				edge_idx[v]++;
				edge_u[i] = u;
				edge_v[i] = v;
			}
			E = new int[N][];
			for(int i=0; i<N; i++) E[i] = new int[edge_idx[i]];
			for(int i=0; i<N-1; i++) {
				int u = edge_u[i];
				int v = edge_v[i];
				E[u][--edge_idx[u]] = v;
				E[v][--edge_idx[v]] = u;
			}
		}
	}

	/**
	 * Precomputes {@code k0}, {@code inv_mod}, the kernel {@code C}, and its transforms
	 * {@code nttC} for every power-of-two length up to {@code max_ntt_size}.
	 */
	void init_modint() {
		max_ntt_size_log = 0;
		while((1 << max_ntt_size_log) < N+6) max_ntt_size_log++;
		// One extra doubling: sigma_tree convolves at twice its padded distance range.
		max_ntt_size_log++;
		max_ntt_size = 1 << max_ntt_size_log;
		k0 = 1;
		for(int i=1; i<=N; i++) k0 = k0 * i % MOD;
		k0 = k0 * k0 % MOD;

		inv_mod = new long[max_ntt_size+1];
		inv_mod[1] = 1;
		for(int i=2; i<=max_ntt_size; i++) inv_mod[i] = MOD - MOD / i * inv_mod[MOD%i] % MOD;

		C = new long[max_ntt_size];
		for(int i=0; i<max_ntt_size; i++) C[i] = k0 * inv_mod[i+1] % MOD * inv_mod[i+1] % MOD;

		nttC = new long[max_ntt_size_log+1][];
		for(int d=0; d<=max_ntt_size_log; d++) {
			// Fold the inverse-NTT normalization 1/2^d into the kernel once.
			long inv_ntt_size = powm(1<<d, MOD-2);
			nttC[d] = new long[1<<d];
			for(int i=0; i<1<<d; i++) nttC[d][i] = C[i] * inv_ntt_size % MOD;
			NTT(nttC[d], NTTg);
		}
	}

	/** BFS scratch: distance from the BFS source. */
	int bfsbuf_dist[];
	/** BFS scratch: BFS-tree parent (-1 for the source). */
	int bfsbuf_parent[];
	/** BFS scratch: visited vertices in BFS order. */
	int bfsbuf_I[];
	/** Number of valid entries in {@code bfsbuf_I} after the last {@code sigma_tree} call. */
	int bfsbuf_I_size;

	/**
	 * Computes distance-weighted charge sums from {@code s} over its component.
	 *
	 * <p>The component is the set of vertices reachable from {@code s} through vertices with
	 * {@code cdep >= dep}. The returned array {@code sigma} has length {@code 2Z} with
	 * {@code Z >= maxdist + 3}. For {@code 0 <= j < Z}:
	 * {@code sigma[j] = sum over q of Q[q] * C[dist(s,q) + j]}, computed as one cyclic
	 * convolution whose size rules out wrap-around. Only the first {@code Z} entries are
	 * meaningful.
	 *
	 * <p>Side effect: leaves the component's BFS order and distances in {@code bfsbuf_*}
	 * for the caller to iterate.
	 */
	long[] sigma_tree(int s, int dep) {
		if(cdep[s] < dep) {
			// s lies outside the component: report an empty BFS so callers iterate nothing.
			bfsbuf_I_size = 0;
			return new long[1];
		}
		bfsbuf_dist[s] = 0;
		bfsbuf_parent[s] = -1;
		bfsbuf_I_size = 0;
		bfsbuf_I[bfsbuf_I_size++] = s;
		int maxdist = 0;

		for(int i=0; i<bfsbuf_I_size; i++) {
			int p = bfsbuf_I[i];
			maxdist = Math.max(maxdist, bfsbuf_dist[p]);
			for(int e : E[p]) if(bfsbuf_parent[p] != e) {
				if(cdep[e] < dep) continue;
				bfsbuf_parent[e] = p;
				bfsbuf_dist[e] = bfsbuf_dist[p] + 1;
				bfsbuf_I[bfsbuf_I_size++] = e;
			}
		}

		// dfreq[i] = total charge at distance i from s.
		long dfreq[] = new long[maxdist+1];
		for(int i=0; i<bfsbuf_I_size; i++) {
			int p = bfsbuf_I[i];
			int d = bfsbuf_dist[p];
			dfreq[d] = (dfreq[d] + Q[p]) % MOD;
		}

		int Z = 1;
		int d = 0;
		while(Z < dfreq.length+2) { Z*=2; d++; }
		long res[] = new long[Z*2];
		// Reversed placement turns the convolution into a correlation: res[Z+j] = sum dfreq[i]*C[i+j].
		for(int i=0; i<dfreq.length; i++) res[Z-i] = dfreq[i];
		NTT(res, NTTg);
		for(int i=0; i<Z*2; i++) res[i] = res[i] * nttC[d+1][i] % MOD;
		NTT(res, invNTTg);
		for(int i=0; i<Z; i++) res[i] = res[i+Z];
		return res;
	}

	/**
	 * Reads input, then for every vertex p prints
	 * {@code sum over q of Q[q] * (N!)^2 / (dist(p,q)+1)^2 mod MOD}, one value per line.
	 *
	 * <p>Every vertex s is the centroid of its depth-{@code cdep[s]} component. Sums from
	 * that component are added to s itself and to every vertex p in a child component
	 * (entered through neighbour nx). For p, {@code sigma_s[dist(s,p)]} assumes the path to
	 * each q goes through s. Subtracting nx's own sums shifted by one removes exactly the
	 * pairs (p, q) that lie in the same child component, where that assumption is wrong.
	 */
	void solve() {
		input();
		init_modint();
		centroid_decomposition();

		bfsbuf_dist = new int[N];
		bfsbuf_parent = new int[N];
		bfsbuf_I = new int[N];

		long ans[] = new long[N];
		for(int s=0; s<N; s++) {
			int dep_s = cdep[s];
			long[] sigma_s = sigma_tree(s, dep_s);
			for(int nx : E[s]) if(cdep[nx] > dep_s) {
				long[] sigma_nx = sigma_tree(nx, dep_s + 1);
				for(int i=0; i<bfsbuf_I_size; i++) {
					int p = bfsbuf_I[i];
					int d = bfsbuf_dist[p] + 1; // dist(s, p)
					// Each term lies in (0, 2*MOD). With O(log N) additions per vertex, the
					// long accumulator cannot overflow before the final reduction.
					ans[p] += sigma_s[d] - sigma_nx[d+1] + MOD;
				}
			}
			ans[s] += sigma_s[0];
		}
		for(int i=0; i<N; i++) ans[i] %= MOD;

		PrintWriter wt = new PrintWriter(System.out);
		for(int i=0; i<N; i++) {
			wt.println(ans[i]);
		}
		wt.flush();
	}



	/** Program entry point: solves one test read from standard input. */
	public static void main(String[] args) {
		Main solver = new Main();
		solver.solve();
	}
}
0