結果
問題 | No.1796 木上のクーロン |
ユーザー | 👑 Nachia |
提出日時 | 2021-12-18 16:35:47 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 9,797 ms / 10,000 ms |
コード長 | 3,585 bytes |
コンパイル時間 | 278 ms |
コンパイル使用メモリ | 81,920 KB |
実行使用メモリ | 265,564 KB |
最終ジャッジ日時 | 2024-09-19 22:48:56 |
合計ジャッジ時間 | 51,179 ms |
ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
テストケース
テストケース | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 35 ms | 54,288 KB |
testcase_01 | AC | 37 ms | 53,848 KB |
testcase_02 | AC | 44 ms | 61,464 KB |
testcase_03 | AC | 38 ms | 54,152 KB |
testcase_04 | AC | 37 ms | 54,704 KB |
testcase_05 | AC | 38 ms | 53,836 KB |
testcase_06 | AC | 37 ms | 53,796 KB |
testcase_07 | AC | 36 ms | 53,512 KB |
testcase_08 | AC | 98 ms | 76,720 KB |
testcase_09 | AC | 109 ms | 76,604 KB |
testcase_10 | AC | 113 ms | 77,124 KB |
testcase_11 | AC | 102 ms | 76,600 KB |
testcase_12 | AC | 122 ms | 77,460 KB |
testcase_13 | AC | 96 ms | 76,532 KB |
testcase_14 | AC | 102 ms | 76,824 KB |
testcase_15 | AC | 124 ms | 77,560 KB |
testcase_16 | AC | 127 ms | 77,384 KB |
testcase_17 | AC | 117 ms | 77,012 KB |
testcase_18 | AC | 100 ms | 76,752 KB |
testcase_19 | AC | 121 ms | 77,360 KB |
testcase_20 | AC | 1,499 ms | 107,500 KB |
testcase_21 | AC | 1,584 ms | 106,128 KB |
testcase_22 | AC | 3,350 ms | 142,132 KB |
testcase_23 | AC | 3,179 ms | 138,984 KB |
testcase_24 | AC | 5,158 ms | 179,024 KB |
testcase_25 | AC | 4,679 ms | 181,336 KB |
testcase_26 | AC | 6,266 ms | 214,824 KB |
testcase_27 | AC | 6,527 ms | 228,140 KB |
testcase_28 | AC | 9,797 ms | 265,564 KB |
testcase_29 | AC | 8,651 ms | 251,324 KB |
testcase_30 | AC | 2,901 ms | 183,452 KB |
testcase_31 | AC | 2,938 ms | 189,244 KB |
testcase_32 | AC | 9,545 ms | 255,304 KB |
testcase_33 | AC | 6,204 ms | 232,564 KB |
testcase_34 | AC | 5,647 ms | 224,388 KB |
testcase_35 | AC | 9,373 ms | 258,720 KB |
testcase_36 | AC | 9,120 ms | 259,268 KB |
ソースコード
# yukicoder No.1796 (Coulomb on a tree), PyPy3 solution.
#
# For every vertex p the program prints
#     sum_q Q[q] * (N!)^2 / (dist(p, q) + 1)^2   (mod 998244353).
# Centroid decomposition splits the vertex pairs; at each centroid the charge
# histogram by distance is correlated with C[i] = (N!)^2 / (i + 1)^2 via NTT.
import sys

MOD = 998244353
NTT_G = 3  # primitive root modulo MOD

# ntt_perm[k] is the bit-reversal permutation of range(2**k), grown on demand.
ntt_perm = [[0]]


def ntt(A, g):
    """Transform list ``A`` in place with a cyclic number-theoretic transform.

    ``A`` is first zero-padded (in place) to the next power of two ``n``.
    With ``g = NTT_G`` this is the forward transform; with the modular inverse
    of ``NTT_G`` it is the inverse transform *without* the ``1/n`` scaling.
    Returns None.
    """
    mod = MOD
    log_n = 0
    while (1 << log_n) < len(A):
        log_n += 1
    n = 1 << log_n
    A += [0] * (n - len(A))
    while len(ntt_perm) <= log_n:
        prev = ntt_perm[-1]
        ntt_perm.append([2 * x for x in prev] + [2 * x + 1 for x in prev])
    perm = ntt_perm[log_n]
    for i in range(n):
        j = perm[i]
        if i < j:
            A[i], A[j] = A[j], A[i]
    half = 1
    while half < n:
        # q generates the (2*half)-th roots of unity (or their inverses).
        q = pow(g, (mod - 1) // (2 * half), mod)
        # One twiddle table per level instead of a running product per element.
        tw = [1] * half
        for k in range(1, half):
            tw[k] = tw[k - 1] * q % mod
        step = 2 * half
        for start in range(0, n, step):
            for k in range(half):
                lo = start + k
                hi = lo + half
                t = A[hi] * tw[k] % mod  # computed once per butterfly
                u = A[lo]
                A[lo] = (u + t) % mod
                A[hi] = (u - t) % mod
        half = step


def centroid_depths(n, adj):
    """Return ``cdep`` with ``cdep[v]`` = depth of ``v`` in the centroid tree.

    The first centroid has depth 0.  ``adj`` is the adjacency list of a tree
    on vertices ``0..n-1``.
    """
    parent = [-1] * n
    order = [0]
    for p in order:
        for e in adj[p]:
            if parent[p] == e:
                continue
            parent[e] = p
            order.append(e)
    # size[v] = subtree size of v under the current rooting of its component;
    # a vertex already chosen as a centroid is marked with size 0.
    size = [1] * n
    for p in order[-1:0:-1]:
        size[parent[p]] += size[p]
    cdep = [0] * n
    todo = [(0, -1)]  # (entry vertex of a component, centroid above it)
    for s, above in todo:
        # Walk toward a neighbour holding more than half of the component,
        # re-rooting at each step so `size` stays valid for the new root.
        while True:
            heavy = -1
            for e in adj[s]:
                if size[e] * 2 > size[s]:
                    heavy = e
            if heavy == -1:
                break
            size[s], size[heavy] = size[s] - size[heavy], size[s]
            s = heavy
        size[s] = 0
        if above != -1:
            cdep[s] = cdep[above] + 1
        for e in adj[s]:
            if size[e] != 0:
                todo.append((e, s))
    return cdep


def _kernel_transforms(n):
    """Return ``T`` where ``T[d]`` is the forward NTT of ``C[0:2**d] / 2**d``.

    ``C[i] = (n!)**2 / (i + 1)**2 mod MOD``.  Folding the ``1/2**d`` factor in
    here lets callers skip the normalisation after the inverse transform.
    ``d`` goes high enough to cover every convolution made by ``solve``.
    """
    log_max = 0
    while (1 << log_max) < n + 6:
        log_max += 1
    log_max += 1
    size_max = 1 << log_max
    k0 = 1
    for i in range(2, n + 1):
        k0 = k0 * i % MOD
    k0 = k0 * k0 % MOD
    inv = [1] * (size_max + 1)
    for i in range(2, size_max + 1):
        inv[i] = MOD - MOD // i * inv[MOD % i] % MOD
    C = [k0 * inv[i + 1] % MOD * inv[i + 1] % MOD for i in range(size_max)]
    transforms = []
    for d in range(log_max + 1):
        inv_size = pow(1 << d, MOD - 2, MOD)
        t = [c * inv_size % MOD for c in C[: 1 << d]]
        ntt(t, NTT_G)
        transforms.append(t)
    return transforms


def solve(n, charges, edges):
    """Return the answer for every vertex.

    Args:
        n: number of vertices (>= 1).
        charges: ``charges[v]`` is the charge Q of vertex ``v`` (0-indexed).
        edges: ``n - 1`` pairs ``(u, v)`` of 0-indexed endpoints forming a tree.

    Returns:
        List whose ``p``-th entry is
        ``sum_q charges[q] * (n!)**2 / (dist(p, q) + 1)**2`` modulo ``MOD``.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    cdep = centroid_depths(n, adj)
    kernel = _kernel_transforms(n)
    inv_g = pow(NTT_G, MOD - 2, MOD)
    dist = [0] * n  # scratch: BFS distance from the last sigma() source
    par = [0] * n  # scratch: BFS parent

    def sigma(s, dep):
        """BFS from ``s`` over vertices with ``cdep >= dep``.

        Leaves ``dist[v]`` = distance from ``s`` for every visited ``v`` and
        returns ``(sig, comp)``: ``comp`` lists the visited vertices in BFS
        order and ``sig[j] = sum_{q in comp} charges[q] * C[dist[q] + j]`` for
        ``0 <= j < len(sig)``, where ``len(sig) >= len(comp) + 2``.
        """
        dist[s] = 0
        par[s] = -1
        comp = [s]
        hist = [0]  # hist[d] = total charge at distance d from s
        for p in comp:
            d = dist[p]
            if len(hist) <= d:
                hist.append(0)
            hist[d] = (hist[d] + charges[p]) % MOD
            for e in adj[p]:
                if e == par[p] or cdep[e] < dep:
                    continue
                par[e] = p
                dist[e] = d + 1
                comp.append(e)
        half, lg = 1, 0
        while half < len(comp) + 2:
            half *= 2
            lg += 1
        # Store hist reversed around index `half`: the cyclic convolution with
        # C then holds sum_i hist[i] * C[i + j] at index half + j, and any
        # wrap-around lands only in indices below `half`.
        buf = [0] * (2 * half)
        for i, h in enumerate(hist):
            buf[half - i] = h
        ntt(buf, NTT_G)
        kern = kernel[lg + 1]
        for i in range(2 * half):
            buf[i] = buf[i] * kern[i] % MOD
        ntt(buf, inv_g)
        return buf[half:], comp

    ans = [0] * n
    for s in range(n):
        dep_s = cdep[s]
        sig_s, _ = sigma(s, dep_s)
        for nx in adj[s]:
            if cdep[nx] <= dep_s:
                continue
            # Pairs inside nx's sub-component were counted through s; remove
            # them (they are handled at deeper centroids).
            sig_nx, sub = sigma(nx, dep_s + 1)
            for p in sub:
                d = dist[p] + 1  # dist(s, p)
                ans[p] += sig_s[d] - sig_nx[d + 1]
        ans[s] += sig_s[0]
    return [a % MOD for a in ans]


def main():
    """Read the problem input from stdin and print one answer per line."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    charges = [int(x) for x in data[1:n + 1]]
    edges = [
        (int(data[n + 1 + 2 * i]) - 1, int(data[n + 2 + 2 * i]) - 1)
        for i in range(n - 1)
    ]
    print("\n".join(map(str, solve(n, charges, edges))))


if __name__ == "__main__":
    main()