結果
問題 | No.1796 木上のクーロン |
ユーザー | 👑 Nachia |
提出日時 | 2021-12-18 17:12:51 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 7,035 ms / 10,000 ms |
コード長 | 3,764 bytes |
コンパイル時間 | 178 ms |
コンパイル使用メモリ | 82,404 KB |
実行使用メモリ | 233,728 KB |
最終ジャッジ日時 | 2024-09-19 22:46:44 |
合計ジャッジ時間 | 55,252 ms |
ジャッジサーバーID (参考情報) | judge3 / judge4 |
(要ログイン)
テストケース
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 37 ms | 53,940 KB |
testcase_01 | AC | 37 ms | 53,440 KB |
testcase_02 | AC | 37 ms | 54,580 KB |
testcase_03 | AC | 36 ms | 54,468 KB |
testcase_04 | AC | 37 ms | 54,048 KB |
testcase_05 | AC | 35 ms | 53,980 KB |
testcase_06 | AC | 34 ms | 53,264 KB |
testcase_07 | AC | 33 ms | 53,624 KB |
testcase_08 | AC | 82 ms | 76,536 KB |
testcase_09 | AC | 98 ms | 76,376 KB |
testcase_10 | AC | 97 ms | 76,364 KB |
testcase_11 | AC | 88 ms | 76,840 KB |
testcase_12 | AC | 106 ms | 77,648 KB |
testcase_13 | AC | 80 ms | 76,364 KB |
testcase_14 | AC | 95 ms | 76,460 KB |
testcase_15 | AC | 121 ms | 77,412 KB |
testcase_16 | AC | 116 ms | 77,484 KB |
testcase_17 | AC | 63 ms | 72,912 KB |
testcase_18 | AC | 68 ms | 73,624 KB |
testcase_19 | AC | 104 ms | 76,660 KB |
testcase_20 | AC | 614 ms | 100,940 KB |
testcase_21 | AC | 652 ms | 100,968 KB |
testcase_22 | AC | 1,346 ms | 125,788 KB |
testcase_23 | AC | 1,288 ms | 124,392 KB |
testcase_24 | AC | 1,975 ms | 153,144 KB |
testcase_25 | AC | 2,209 ms | 153,772 KB |
testcase_26 | AC | 2,930 ms | 181,512 KB |
testcase_27 | AC | 2,650 ms | 181,232 KB |
testcase_28 | AC | 7,035 ms | 230,912 KB |
testcase_29 | AC | 6,857 ms | 233,728 KB |
testcase_30 | AC | 876 ms | 180,312 KB |
testcase_31 | AC | 1,033 ms | 181,276 KB |
testcase_32 | AC | 3,395 ms | 178,596 KB |
testcase_33 | AC | 3,714 ms | 211,012 KB |
testcase_34 | AC | 3,413 ms | 208,072 KB |
testcase_35 | AC | 5,325 ms | 208,976 KB |
testcase_36 | AC | 5,047 ms | 201,552 KB |
ソースコード
"""Solver for yukicoder No.1796 "木上のクーロン" (Coulomb on a tree).

For every vertex p it prints

    sum over v of Q[v] * (N!)^2 / (dist(p, v) + 1)^2   (mod 998244353).

Each pair (p, v) is charged to the centroid of lowest decomposition level on
the p-v path.  Per centroid, one NTT convolution of the component's depth
histogram with the kernel C[k] = (N!)^2 / (k + 1)^2 yields the sums for every
distance offset at once.
"""
import sys

MOD = 998244353

# ntt_perm[k] is the bit-reversal permutation of range(1 << k), grown lazily.
ntt_perm = [[0]]


def ntt(A, g):
    """Apply an unscaled radix-2 number-theoretic transform to A in place.

    A is first zero-padded in place to the next power of two.  Twiddle
    factors are powers of g: pass a primitive root of MOD (3) for the
    forward transform and its inverse for the backward transform.  The
    backward transform is NOT divided by len(A); callers fold that factor in.
    """
    logN = 0
    while (1 << logN) < len(A):
        logN += 1
    N = 1 << logN
    A += [0] * (N - len(A))
    while len(ntt_perm) < logN + 1:
        prev = ntt_perm[-1]
        ntt_perm.append([2 * x for x in prev] + [2 * x + 1 for x in prev])
    X = ntt_perm[logN]
    for i in range(N):
        if i < X[i]:
            A[i], A[X[i]] = A[X[i]], A[i]
    i = 1
    while i < N:
        # Root of unity of order 2*i for this butterfly stage.
        q = pow(g, (MOD - 1) // i // 2, MOD)
        for j in range(0, N, i * 2):
            qj = 1
            for k in range(j, j + i):
                t = A[k + i] * qj
                A[k], A[k + i] = (A[k] + t) % MOD, (A[k] - t) % MOD
                qj = qj * q % MOD
        i *= 2


def _centroid_depths(n, E):
    """Return cdep, where cdep[v] is the level at which v is removed as a centroid.

    The first centroid has level 0; centroids of the pieces it leaves behind
    have level 1, and so on.  E is the 0-indexed adjacency list of a tree.
    """
    # Root the tree at 0 (BFS order) and compute subtree sizes.
    P = [-1] * n
    order = [0]
    for p in order:
        for e in E[p]:
            if P[p] == e:
                continue
            P[e] = p
            order.append(e)
    Z = [1] * n
    for p in order[-1:0:-1]:
        Z[P[p]] += Z[p]

    cdep = [0] * n
    work = [(0, -1)]  # (any vertex of a component, centroid that split it off)
    for s, par in work:
        # Walk toward the heavy neighbour.  Re-rooting at nx keeps the
        # invariant: Z[s] = component size, Z[e] = size of e's side.
        while True:
            nx = -1
            for e in E[s]:
                if Z[e] * 2 > Z[s]:
                    nx = e
            if nx == -1:
                break
            Z[s], Z[nx] = Z[s] - Z[nx], Z[s]
            s = nx
        Z[s] = 0  # mark removed; removed vertices never look heavy
        if par != -1:
            cdep[s] = cdep[par] + 1
        for e in E[s]:
            if Z[e] != 0:
                work.append((e, s))
    return cdep


def solve(N, Q, edges):
    """Return the per-vertex answers for a tree with N vertices.

    Args:
        N: number of vertices.
        Q: charges Q[0..N-1] (reduced modulo MOD).
        edges: N-1 pairs (u, v), 1-indexed.

    Returns:
        ans[p] = sum_v Q[v] * (N!)^2 * inverse((dist(p, v) + 1)^2) mod MOD.
    """
    E = [[] for _ in range(N)]
    for u, v in edges:
        E[u - 1].append(v - 1)
        E[v - 1].append(u - 1)
    cdep = _centroid_depths(N, E)

    NTTg = 3
    invNTTg = pow(NTTg, MOD - 2, MOD)
    # sigma_tree convolves at size 2Z with Z >= histogram length + 2 <= N + 2,
    # so a table up to 2 * nextpow2(N + 6) covers every request.
    max_log = 0
    while (1 << max_log) < N + 6:
        max_log += 1
    max_log += 1
    max_size = 1 << max_log

    k0 = 1
    for i in range(1, N + 1):
        k0 = k0 * i % MOD
    k0 = k0 * k0 % MOD  # (N!)^2 makes every 1/(k+1)^2 coefficient integral
    inv_mod = [1] * (max_size + 1)
    for i in range(2, max_size + 1):
        inv_mod[i] = MOD - MOD // i * inv_mod[MOD % i] % MOD
    C = [0] * (max_size + 1)
    for i in range(max_size):
        C[i] = k0 * inv_mod[i + 1] % MOD * inv_mod[i + 1] % MOD
    # nttC[d]: forward transform of C[0 .. 2^d - 1], pre-divided by 2^d so the
    # unscaled inverse transform in sigma_tree needs no extra pass.
    nttC = []
    for d in range(max_log + 1):
        inv_size = pow(1 << d, MOD - 2, MOD)
        row = [C[i] * inv_size % MOD for i in range(1 << d)]
        ntt(row, NTTg)
        nttC.append(row)

    dist = [0] * N
    parent = [0] * N

    def sigma_tree(s, dep):
        """BFS from s over vertices with cdep >= dep.

        Returns (sigma, order).  sigma[t] = sum over visited v of
        Q[v] * C[dist(s, v) + t] (mod MOD), valid for t < len(sigma), and
        len(sigma) >= max BFS distance + 3.  order is the BFS visit order;
        dist[] holds each visited vertex's distance from s afterwards.
        """
        if cdep[s] < dep:
            return [0], []
        dist[s] = 0
        parent[s] = -1
        order = [s]
        dfreq = [0]  # dfreq[k] = total charge at distance k
        for p in order:
            d = dist[p]
            if len(dfreq) <= d:
                dfreq.append(0)
            dfreq[d] = (dfreq[d] + Q[p]) % MOD
            for e in E[p]:
                if parent[p] == e or cdep[e] < dep:
                    continue
                parent[e] = p
                order.append(e)
                dist[e] = d + 1
        m = len(dfreq)
        # Tiny histograms: direct evaluation is cheaper than an NTT.
        if m == 1:
            return [dfreq[0] * C[t] % MOD for t in range(3)], order
        if m == 2:
            return [(dfreq[0] * C[t] + dfreq[1] * C[t + 1]) % MOD
                    for t in range(4)], order
        Z = 1
        d = 0
        while Z < m + 2:
            Z *= 2
            d += 1
        # Placing dfreq reversed at offset Z turns the cyclic convolution
        # into a correlation: res[Z + t] = sum_i dfreq[i] * C[i + t].
        res = [0] * (Z * 2)
        for i in range(m):
            res[Z - i] = dfreq[i]
        ntt(res, NTTg)
        kernel = nttC[d + 1]
        for i in range(Z * 2):
            res[i] = res[i] * kernel[i] % MOD
        ntt(res, invNTTg)
        return res[Z:], order

    ans = [0] * N
    for s in range(N):
        dep_s = cdep[s]
        sigma_s, _ = sigma_tree(s, dep_s)
        for nx in E[s]:
            if cdep[nx] <= dep_s:
                continue  # not part of s's component
            sigma_nx, sub = sigma_tree(nx, dep_s + 1)
            for p in sub:
                d = dist[p] + 1  # dist(p, s)
                # sigma_s counts every v as if its path ran through s;
                # subtract the pairs inside the same piece, whose true path
                # does not (they are handled at a deeper level).
                ans[p] += sigma_s[d] - sigma_nx[d + 1]
        ans[s] += sigma_s[0]
    return [a % MOD for a in ans]


def main():
    """Read N, the charges, and N-1 edges from stdin; print one answer per line."""
    data = sys.stdin.buffer.read().split()
    N = int(data[0])
    Q = [int(x) for x in data[1:N + 1]]
    it = map(int, data[N + 1:N + 1 + 2 * (N - 1)])
    edges = list(zip(it, it))  # consecutive tokens form (u, v)
    print("\n".join(map(str, solve(N, Q, edges))))


if __name__ == "__main__":
    main()