結果

問題 No.1796 木上のクーロン
ユーザー 👑 NachiaNachia
提出日時 2021-12-18 17:12:51
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 7,035 ms / 10,000 ms
コード長 3,764 bytes
コンパイル時間 178 ms
コンパイル使用メモリ 82,404 KB
実行使用メモリ 233,728 KB
最終ジャッジ日時 2024-09-19 22:46:44
合計ジャッジ時間 55,252 ms
ジャッジサーバーID (参考情報): judge3 / judge4
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 37 ms 53,940 KB
testcase_01 AC 37 ms 53,440 KB
testcase_02 AC 37 ms 54,580 KB
testcase_03 AC 36 ms 54,468 KB
testcase_04 AC 37 ms 54,048 KB
testcase_05 AC 35 ms 53,980 KB
testcase_06 AC 34 ms 53,264 KB
testcase_07 AC 33 ms 53,624 KB
testcase_08 AC 82 ms 76,536 KB
testcase_09 AC 98 ms 76,376 KB
testcase_10 AC 97 ms 76,364 KB
testcase_11 AC 88 ms 76,840 KB
testcase_12 AC 106 ms 77,648 KB
testcase_13 AC 80 ms 76,364 KB
testcase_14 AC 95 ms 76,460 KB
testcase_15 AC 121 ms 77,412 KB
testcase_16 AC 116 ms 77,484 KB
testcase_17 AC 63 ms 72,912 KB
testcase_18 AC 68 ms 73,624 KB
testcase_19 AC 104 ms 76,660 KB
testcase_20 AC 614 ms 100,940 KB
testcase_21 AC 652 ms 100,968 KB
testcase_22 AC 1,346 ms 125,788 KB
testcase_23 AC 1,288 ms 124,392 KB
testcase_24 AC 1,975 ms 153,144 KB
testcase_25 AC 2,209 ms 153,772 KB
testcase_26 AC 2,930 ms 181,512 KB
testcase_27 AC 2,650 ms 181,232 KB
testcase_28 AC 7,035 ms 230,912 KB
testcase_29 AC 6,857 ms 233,728 KB
testcase_30 AC 876 ms 180,312 KB
testcase_31 AC 1,033 ms 181,276 KB
testcase_32 AC 3,395 ms 178,596 KB
testcase_33 AC 3,714 ms 211,012 KB
testcase_34 AC 3,413 ms 208,072 KB
testcase_35 AC 5,325 ms 208,976 KB
testcase_36 AC 5,047 ms 201,552 KB
権限があれば一括ダウンロードができます

ソースコード

diff #


MOD = 998244353
# Cache of bit-reversal permutations: ntt_perm[k] is the permutation for size 2**k.
ntt_perm = [[0]]


def ntt(A, g):
    """Transform A in place with an iterative radix-2 number-theoretic transform mod MOD.

    A is first zero-padded (in place) to the next power of two N >= len(A).
    With g a primitive root of MOD (e.g. 3) this is the forward transform;
    passing the modular inverse of g gives the inverse transform WITHOUT the
    1/N scaling, which callers must apply themselves. N must divide MOD - 1.
    Returns None; for len(A) <= 1 the list is only padded.
    """
    logN = 0
    while (1 << logN) < len(A):
        logN += 1
    N = 1 << logN
    A += [0] * (N - len(A))

    # Extend the bit-reversal cache up to size N: perm(2n) = 2*perm(n) ++ 2*perm(n)+1.
    while len(ntt_perm) < logN + 1:
        prev = ntt_perm[-1]
        ntt_perm.append([2 * x for x in prev] + [2 * x + 1 for x in prev])
    X = ntt_perm[logN]
    for i in range(N):
        if i < X[i]:
            A[i], A[X[i]] = A[X[i]], A[i]

    i = 1  # half-length of the butterflies in the current stage
    while i < N:
        # q is a primitive (2i)-th root of unity; its powers are shared by every chunk.
        q = pow(g, (MOD - 1) // i // 2, MOD)
        tw = [1] * i
        for t in range(1, i):
            tw[t] = tw[t - 1] * q % MOD
        for j in range(0, N, i * 2):
            for k in range(j, j + i):
                a = A[k]
                b = A[k + i] * tw[k - j] % MOD  # computed once per butterfly
                A[k] = (a + b) % MOD
                A[k + i] = (a - b) % MOD
        i *= 2


N = int(input())
Q = list(map(int, input().split()))
E = [[] for _ in range(N)]  # adjacency lists, 0-indexed

for _ in range(N - 1):
    u, v = map(int, input().split())
    E[u - 1].append(v - 1)
    E[v - 1].append(u - 1)

# Outputs of centroid_decomposition():
cdep = [0] * N  # depth of each vertex in the centroid tree (first centroid: 0)
cp = [-1] * N   # parent centroid of each vertex (-1 for the first centroid)
# Centroids in the order they are found. Starts empty because the decomposition
# only appends to it; pre-filling with N zeros would put junk in front.
cbfs = []

def centroid_decomposition() :
    """Build the centroid decomposition of the tree given by N and adjacency list E.

    Results are written into module-level arrays; nothing is returned:
      cdep[v] -- depth of v in the centroid tree (the first centroid gets 0)
      cp[v]   -- centroid that split off v's component (untouched, i.e. -1,
                 for the first centroid)
      cbfs    -- each centroid is appended in the order it is found
    """
    # BFS from vertex 0: P[] is the parent of each vertex, I the BFS order.
    P = [-1] * N
    I = [0]
    for p in I :
        for e in E[p] :
            if P[p] == e : continue
            P[e] = p
            I.append(e)
    # Subtree sizes for the tree rooted at 0, accumulated in reverse BFS order.
    Z = [1] * N
    for p in I[-1:0:-1] :
        Z[P[p]] += Z[p]
    # Work list of (some vertex of a not-yet-processed component, centroid that
    # split it off). It grows while being iterated, so components are handled FIFO.
    cdI = [(0,-1)]
    for s,par in cdI :
        # Move toward a neighbour whose side holds more than half of Z[s] until
        # none does; that vertex is the centroid. Z is re-rooted on the way:
        # s keeps the size of its remaining side, nx takes the whole size.
        while True :
            nx = -1
            for e in E[s] :
                if Z[e] * 2 > Z[s] : nx = e
            if nx == -1 : break
            Z[s],Z[nx] = Z[s]-Z[nx],Z[s]
            s = nx
        cbfs.append(s)
        # Z == 0 marks s as removed: it can no longer be chosen as nx above
        # nor start a new component below.
        Z[s] = 0
        if par != -1 : cdep[s],cp[s] = cdep[par]+1,par
        # Each neighbour still present starts a sub-component split off by s.
        for e in E[s] :
            if Z[e] != 0 :
                cdI.append((e,s))

centroid_decomposition()


NTTg = 3
invNTTg = pow(NTTg, MOD - 2, MOD)

# Smallest power of two that is >= N + 6, then doubled once more:
# sigma_tree convolves into buffers twice the size it needs for the distances.
max_ntt_size_log = 0
while (1 << max_ntt_size_log) < N + 6:
    max_ntt_size_log += 1
max_ntt_size_log += 1
max_ntt_size = 1 << max_ntt_size_log

# k0 = (N!)^2 mod MOD
fact_n = 1
for f in range(1, N + 1):
    fact_n = fact_n * f % MOD
k0 = fact_n * fact_n % MOD

# Modular inverses 1..max_ntt_size via inv(v) = -(MOD // v) * inv(MOD % v).
inv_mod = [1] * (max_ntt_size + 1)
for v in range(2, max_ntt_size + 1):
    inv_mod[v] = MOD - MOD // v * inv_mod[MOD % v] % MOD

# C[t] = k0 / (t + 1)^2 mod MOD; the final slot stays 0.
C = [k0 * inv_mod[t + 1] % MOD * inv_mod[t + 1] % MOD for t in range(max_ntt_size)]
C.append(0)

# nttC[d]: forward NTT of C[0 : 2**d], pre-scaled by 1 / 2**d so that one
# inverse NTT of a pointwise product needs no further normalisation.
nttC = []
for d in range(max_ntt_size_log + 1):
    size = 1 << d
    scale = pow(size, MOD - 2, MOD)
    spectrum = [c * scale % MOD for c in C[:size]]
    ntt(spectrum, NTTg)
    nttC.append(spectrum)



bfsbuf_dist = [0] * N    # distance from the source of the latest sigma_tree BFS
bfsbuf_parent = [0] * N  # BFS parent of each visited vertex (-1 for the source)
bfsbuf_I = []            # vertices visited by the latest sigma_tree call, BFS order


def sigma_tree(s, dep):
    """Return the distance-shifted charge sums of the component containing s.

    BFS from s over vertices v with cdep[v] >= dep. With dfreq[k] = sum of Q[v]
    over visited v at distance k from s, the returned list res satisfies
        res[t] = sum_k dfreq[k] * C[k + t]   (mod MOD)
    for every index t, and len(res) >= (largest distance) + 3.

    Side effects: bfsbuf_dist / bfsbuf_parent are overwritten for visited
    vertices and bfsbuf_I is rebound to the visited vertices in BFS order.
    If cdep[s] < dep nothing is visited: bfsbuf_I becomes [] and [0] is returned.
    """
    # Only bfsbuf_I is rebound; the other two buffers are mutated in place.
    global bfsbuf_I
    if cdep[s] < dep:
        bfsbuf_I = []  # don't leave a previous call's vertex list behind
        return [0]

    dist = bfsbuf_dist
    parent = bfsbuf_parent
    dist[s] = 0
    parent[s] = -1
    order = [s]
    dfreq = [0]
    for p in order:
        d = dist[p]
        if len(dfreq) <= d:
            dfreq.append(0)
        dfreq[d] = (dfreq[d] + Q[p]) % MOD
        for e in E[p]:
            # skip the edge back to the BFS parent and vertices of earlier levels
            if parent[p] == e or cdep[e] < dep:
                continue
            parent[e] = p
            order.append(e)
            dist[e] = d + 1
    bfsbuf_I = order

    # Tiny components: evaluate the sums directly.
    if len(dfreq) == 1:
        return [dfreq[0] * C[t] % MOD for t in range(3)]
    if len(dfreq) == 2:
        return [(dfreq[0] * C[t] + dfreq[1] * C[t + 1]) % MOD for t in range(4)]

    # Correlation via NTT: place dfreq reversed around index Z so that entry
    # Z + t of the cyclic convolution with C equals sum_k dfreq[k] * C[k + t].
    Z = 1
    d = 0
    while Z < len(dfreq) + 2:
        Z, d = Z * 2, d + 1
    res = [0] * (Z * 2)
    for k, f in enumerate(dfreq):
        res[Z - k] = f
    ntt(res, NTTg)
    res = [x * y % MOD for x, y in zip(res, nttC[d + 1])]  # nttC is pre-scaled by 1/(2Z)
    ntt(res, invNTTg)
    return res[Z:]



ans = [0] * N

# Every pair (p, q) is counted at the centroid s that is their lowest common
# ancestor in the centroid tree, where dist(p, q) = dist(s, p) + dist(s, q).
for s in range(N):
    dep_s = cdep[s]
    sigma_s = sigma_tree(s, dep_s)
    for nx in E[s]:
        if cdep[nx] <= dep_s:
            continue  # nx was removed at an earlier level: outside s's component
        sigma_nx = sigma_tree(nx, dep_s + 1)
        # sigma_tree just left nx's sub-component in bfsbuf_I, with distances
        # from nx in bfsbuf_dist.
        for p in bfsbuf_I:
            d = bfsbuf_dist[p] + 1  # dist(s, p)
            # All q of s's component, minus the q in p's own sub-component,
            # whose path to p does not actually pass through s.
            ans[p] += sigma_s[d] - sigma_nx[d + 1]
    ans[s] += sigma_s[0]

print("\n".join(str(a % MOD) for a in ans))