結果
問題 | No.1796 木上のクーロン |
ユーザー | 👑 |
提出日時 | 2021-12-18 17:12:51 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 7,035 ms / 10,000 ms |
コード長 | 3,764 bytes |
コンパイル時間 | 178 ms |
コンパイル使用メモリ | 82,404 KB |
実行使用メモリ | 233,728 KB |
最終ジャッジ日時 | 2024-09-19 22:46:44 |
合計ジャッジ時間 | 55,252 ms |
ジャッジサーバーID (参考情報) | judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 34 |
ソースコード
import sys

MOD = 998244353
NTTg = 3  # primitive root of MOD (998244353 = 119 * 2**23 + 1)
invNTTg = pow(NTTg, MOD - 2, MOD)

# ntt_perm[k] caches the bit-reversal permutation of range(2**k); it is
# grown lazily by ntt() and shared across calls.
ntt_perm = [[0]]


def ntt(A, g):
    """Apply an in-place, unnormalised number-theoretic transform to ``A``.

    ``A`` is first zero-padded (in place) to the next power of two, permuted
    into bit-reversed order, and then combined with iterative radix-2
    butterflies that use ``g**((MOD-1)/(2*half))`` as the root of unity of
    each stage.  Passing ``NTTg`` gives the forward transform; passing
    ``invNTTg`` gives the inverse transform *without* the final division by
    ``len(A)`` (callers must account for that scale themselves).
    """
    log_n = 0
    while (1 << log_n) < len(A):
        log_n += 1
    n = 1 << log_n
    A += [0] * (n - len(A))
    while len(ntt_perm) < log_n + 1:
        prev = ntt_perm[-1]
        ntt_perm.append([2 * x for x in prev] + [2 * x + 1 for x in prev])
    rev = ntt_perm[log_n]
    for i in range(n):
        if i < rev[i]:
            A[i], A[rev[i]] = A[rev[i]], A[i]
    half = 1
    while half < n:
        w = pow(g, (MOD - 1) // half // 2, MOD)
        for start in range(0, n, 2 * half):
            wj = 1
            for k in range(start, start + half):
                a = A[k]
                b = A[k + half] * wj
                A[k] = (a + b) % MOD
                A[k + half] = (a - b) % MOD
                wj = wj * w % MOD
        half *= 2


def centroid_decomposition(n, adj):
    """Return ``cdep`` where ``cdep[v]`` is the level at which vertex ``v``
    is picked as a centroid (the first centroid has level 0).

    ``adj`` is a 0-indexed adjacency list of a tree with ``n >= 1`` vertices.
    The component in which ``s`` was picked is the set of vertices with
    ``cdep >= cdep[s]`` reachable from ``s`` through such vertices.
    """
    # Root the tree at vertex 0 (BFS order) and compute subtree sizes.
    parent = [-1] * n
    order = [0]
    for p in order:
        for e in adj[p]:
            if parent[p] == e:
                continue
            parent[e] = p
            order.append(e)
    size = [1] * n
    for p in reversed(order[1:]):
        size[parent[p]] += size[p]

    cdep = [0] * n
    # Each entry: (some vertex of a component still to split,
    #              centroid of the enclosing component, or -1 for the root).
    pending = [(0, -1)]
    for s, par in pending:
        # Walk toward a neighbour that holds more than half of the current
        # component; the swap re-roots the stored sizes at the new position
        # so size[v] keeps meaning "vertices on v's side".
        while True:
            nx = -1
            for e in adj[s]:
                if size[e] * 2 > size[s]:
                    nx = e
            if nx == -1:
                break
            size[s], size[nx] = size[s] - size[nx], size[s]
            s = nx
        size[s] = 0  # removed: a zero size also keeps later walks out of s
        if par != -1:
            cdep[s] = cdep[par] + 1
        for e in adj[s]:
            if size[e] != 0:
                pending.append((e, s))
    return cdep


def _kernel_tables(n):
    """Return ``(C, nttC)`` for a tree with ``n`` vertices.

    ``C[i] = (n!)**2 / (i + 1)**2 (mod MOD)`` for every index below the
    maximum transform size (the trailing entry stays 0).  ``nttC[d]`` is the
    forward NTT of ``C[:2**d]`` pre-multiplied by ``2**-d``, which absorbs
    the normalisation that the unnormalised inverse ``ntt`` leaves out.
    """
    max_log = 0
    while (1 << max_log) < n + 6:
        max_log += 1
    max_log += 1  # convolutions use buffers twice the distance-profile size
    max_size = 1 << max_log

    k0 = 1
    for i in range(1, n + 1):
        k0 = k0 * i % MOD
    k0 = k0 * k0 % MOD

    inv_mod = [1] * (max_size + 1)
    for i in range(2, max_size + 1):
        inv_mod[i] = MOD - MOD // i * inv_mod[MOD % i] % MOD

    C = [0] * (max_size + 1)
    for i in range(max_size):
        C[i] = k0 * inv_mod[i + 1] % MOD * inv_mod[i + 1] % MOD

    nttC = []
    for d in range(max_log + 1):
        scale = pow(1 << d, MOD - 2, MOD)
        row = [C[i] * scale % MOD for i in range(1 << d)]
        ntt(row, NTTg)
        nttC.append(row)
    return C, nttC


def solve(n, weights, edges):
    """Return ``ans`` with ``ans[p] = sum_v weights[v] * (n!)**2 /
    (dist(p, v) + 1)**2 (mod MOD)`` for every vertex ``p``.

    ``weights`` is indexed by 0-based vertex; ``edges`` holds the ``n - 1``
    tree edges as 1-indexed ``(u, v)`` pairs, exactly as in the input.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u - 1].append(v - 1)
        adj[v - 1].append(u - 1)
    cdep = centroid_decomposition(n, adj)
    C, nttC = _kernel_tables(n)

    # Scratch BFS buffers reused by every sigma_tree call; dist[] is read
    # by the main loop right after the call that filled it.
    dist = [0] * n
    parent = [0] * n

    def sigma_tree(root, dep):
        """BFS from ``root`` over vertices with ``cdep >= dep``.

        Returns ``(values, order)`` where ``values[t] = sum_v weights[v] *
        C[t + dist(root, v)]`` over the visited vertices, with at least
        (max distance + 3) entries, and ``order`` lists the visited vertices.
        ``dist`` holds their distances from ``root`` until the next call.
        """
        if cdep[root] < dep:
            return [0], []
        dist[root] = 0
        parent[root] = -1
        order = [root]
        dfreq = [0]  # dfreq[d] = weight sum of vertices at distance d
        for p in order:
            d = dist[p]
            if len(dfreq) <= d:
                dfreq.append(0)
            dfreq[d] = (dfreq[d] + weights[p]) % MOD
            for e in adj[p]:
                if parent[p] == e or cdep[e] < dep:
                    continue
                parent[e] = p
                dist[e] = d + 1
                order.append(e)

        m = len(dfreq)
        if m == 1:
            return [dfreq[0] * C[t] % MOD for t in range(3)], order
        if m == 2:
            return [(dfreq[0] * C[t] + dfreq[1] * C[t + 1]) % MOD
                    for t in range(4)], order

        half, log = 1, 0
        while half < m + 2:
            half, log = half * 2, log + 1
        # Placing dfreq reversed around index `half` turns the cyclic
        # convolution into the correlation sum_i dfreq[i] * C[t + i] at
        # position half + t; half >= m + 2 rules out wrap-around.
        buf = [0] * (2 * half)
        for i, f in enumerate(dfreq):
            buf[half - i] = f
        ntt(buf, NTTg)
        kernel = nttC[log + 1]
        for i in range(2 * half):
            buf[i] = buf[i] * kernel[i] % MOD
        ntt(buf, invNTTg)
        return buf[half:], order

    ans = [0] * n
    for s in range(n):
        dep_s = cdep[s]
        sigma_s, _ = sigma_tree(s, dep_s)
        for nx in adj[s]:
            if cdep[nx] <= dep_s:
                continue  # nx is outside the component that s splits
            sigma_nx, sub = sigma_tree(nx, dep_s + 1)
            for p in sub:
                d = dist[p] + 1  # dist(s, p); dist[] is now measured from nx
                # All pairs routed through s, minus pairs inside nx's part.
                ans[p] += sigma_s[d] - sigma_nx[d + 1]
        ans[s] += sigma_s[0]
    return [a % MOD for a in ans]


def main():
    """Read N, the N weights and N-1 edges from stdin; print one answer per line."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    weights = [int(x) for x in data[1:1 + n]]
    rest = data[1 + n:]
    edges = [(int(rest[2 * i]), int(rest[2 * i + 1])) for i in range(n - 1)]
    print("\n".join(map(str, solve(n, weights, edges))))


if __name__ == "__main__":
    main()