結果

問題 No.1796 木上のクーロン
ユーザー 👑 NachiaNachia
提出日時 2021-12-18 17:12:51
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 8,327 ms / 10,000 ms
コード長 3,764 bytes
コンパイル時間 261 ms
コンパイル使用メモリ 81,672 KB
実行使用メモリ 232,060 KB
最終ジャッジ日時 2023-10-20 03:01:05
合計ジャッジ時間 66,392 ms
ジャッジサーバーID (参考情報) judge14 / judge15
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 38 ms 53,508 KB
testcase_01 AC 39 ms 53,508 KB
testcase_02 AC 39 ms 53,508 KB
testcase_03 AC 38 ms 53,508 KB
testcase_04 AC 39 ms 53,508 KB
testcase_05 AC 39 ms 53,508 KB
testcase_06 AC 38 ms 53,508 KB
testcase_07 AC 38 ms 53,508 KB
testcase_08 AC 98 ms 75,952 KB
testcase_09 AC 104 ms 76,196 KB
testcase_10 AC 102 ms 76,156 KB
testcase_11 AC 99 ms 76,196 KB
testcase_12 AC 116 ms 76,804 KB
testcase_13 AC 89 ms 75,896 KB
testcase_14 AC 104 ms 76,192 KB
testcase_15 AC 130 ms 77,112 KB
testcase_16 AC 131 ms 77,180 KB
testcase_17 AC 71 ms 72,512 KB
testcase_18 AC 74 ms 72,772 KB
testcase_19 AC 115 ms 76,280 KB
testcase_20 AC 757 ms 100,620 KB
testcase_21 AC 823 ms 100,204 KB
testcase_22 AC 1,784 ms 125,152 KB
testcase_23 AC 1,667 ms 124,224 KB
testcase_24 AC 2,471 ms 151,796 KB
testcase_25 AC 2,804 ms 153,564 KB
testcase_26 AC 3,684 ms 181,208 KB
testcase_27 AC 3,351 ms 180,360 KB
testcase_28 AC 8,327 ms 230,112 KB
testcase_29 AC 8,264 ms 232,060 KB
testcase_30 AC 1,077 ms 179,792 KB
testcase_31 AC 1,249 ms 180,516 KB
testcase_32 AC 4,421 ms 177,852 KB
testcase_33 AC 4,426 ms 209,632 KB
testcase_34 AC 4,084 ms 207,112 KB
testcase_35 AC 6,351 ms 208,612 KB
testcase_36 AC 6,206 ms 200,520 KB
権限があれば一括ダウンロードができます

ソースコード

diff #


MOD = 998244353
# ntt_perm[k] is the bit-reversal permutation of length 2**k; it is grown
# lazily by ntt() and shared between calls.
ntt_perm = [[0]]

def ntt(A,g):
    """Transform the list A in place with an iterative radix-2 NTT mod MOD.

    A is first padded in place with zeros up to the next power of two,
    size = 2**k.  Afterwards A[t] == sum_j a_j * w**(j*t) (mod MOD), where
    a is the padded input and w = g**((MOD - 1) // size).  Passing the
    inverse of g yields the inverse transform up to a factor of size.

    Returns None; the result is left in A.
    """
    logN = 0
    while (1 << logN) < len(A):
        logN += 1
    size = 1 << logN
    A += [0] * (size - len(A))

    # Extend the bit-reversal cache: perm(k+1) = 2*perm(k) ++ 2*perm(k)+1.
    while len(ntt_perm) < logN + 1:
        prev = ntt_perm[-1]
        ntt_perm.append([2 * x for x in prev] + [2 * x + 1 for x in prev])
    rev = ntt_perm[logN]
    for i in range(size):
        if i < rev[i]:
            A[i], A[rev[i]] = A[rev[i]], A[i]

    half = 1
    while half < size:
        # q is a root of unity of order 2*half; precompute its powers once
        # per level instead of once per butterfly group.
        q = pow(g, (MOD - 1) // half // 2, MOD)
        twiddle = [1] * half
        for t in range(1, half):
            twiddle[t] = twiddle[t - 1] * q % MOD
        for j in range(0, size, half * 2):
            for t in range(half):
                k = j + t
                a = A[k]
                # Compute the twiddled term once and reuse it in both outputs.
                b = A[k + half] * twiddle[t] % MOD
                A[k] = (a + b) % MOD
                A[k + half] = (a - b) % MOD
        half *= 2


N = int(input())
Q = [int(tok) for tok in input().split()]
# E[v]: adjacency list of vertex v, 0-indexed (the input edges are 1-indexed).
E = [[] for _ in range(N)]

for _ in range(N - 1):
    u, v = (int(tok) - 1 for tok in input().split())
    E[u].append(v)
    E[v].append(u)

# cdep[v]: depth of v in the centroid tree (0 for the first centroid found).
cdep = [0] * N
# cp[v]: parent of v in the centroid tree (-1 for the first centroid found).
cp = [-1] * N
# cbfs: centroids in the order centroid_decomposition() finds them.  It must
# start empty because the decomposition only appends to it; pre-filling it
# with N zeros put N bogus entries in front of the real order.
cbfs = []

def centroid_decomposition() :
    """Build the centroid decomposition of the tree stored in E.

    Results are written to module-level lists:
      cdep[v] -- depth of centroid v in the centroid tree (top centroid: 0)
      cp[v]   -- centroid-tree parent of v (top centroid keeps its old value)
      cbfs    -- each centroid is appended when found; components are
                 processed first-in first-out, so this is a BFS order of
                 the centroid tree.
    """
    # Root the tree at vertex 0: `order` is a BFS order, `parent` the parents.
    parent = [-1] * N
    order = [0]
    head = 0
    while head < len(order):
        v = order[head]
        head += 1
        for w in E[v]:
            if w != parent[v]:
                parent[w] = v
                order.append(w)

    # size[v] starts as the subtree size of v under the rooting above.  The
    # walk below re-roots entries in place, and a found centroid's entry is
    # set to 0, which marks it as removed from every later component.
    size = [1] * N
    for v in reversed(order[1:]):
        size[parent[v]] += size[v]

    # Each pending item is (some vertex of a component, centroid that split
    # it off, or -1 for the whole tree).
    pending = [(0, -1)]
    idx = 0
    while idx < len(pending):
        c, owner = pending[idx]
        idx += 1
        # Step towards a neighbour holding more than half of the component
        # until there is none; re-root the two sizes on every step.
        while True:
            heavy = -1
            for w in E[c]:
                if size[w] * 2 > size[c]:
                    heavy = w
            if heavy == -1:
                break
            total = size[c]
            size[c] = total - size[heavy]
            size[heavy] = total
            c = heavy
        cbfs.append(c)
        size[c] = 0
        if owner != -1:
            cdep[c] = cdep[owner] + 1
            cp[c] = owner
        for w in E[c]:
            if size[w] != 0:
                pending.append((w, c))

centroid_decomposition()


# The NTT uses 3 as its generator; the inverse generator drives the inverse transform.
NTTg = 3
invNTTg = pow(NTTg, MOD - 2, MOD)

# Smallest exponent L with 2**L >= N + 6, plus one extra level so that
# sigma_tree's doubled-length convolutions (it reads nttC[d + 1]) have a table.
max_ntt_size_log = (N + 5).bit_length() + 1
max_ntt_size = 1 << max_ntt_size_log

# k0 = (N!)^2 mod MOD, the common factor applied to every weight C[i].
k0 = 1
for factor in range(2, N + 1):
    k0 = k0 * factor % MOD
k0 = k0 * k0 % MOD

# inv_mod[i] is the inverse of i mod MOD for 1 <= i <= max_ntt_size, via the
# recurrence inv(i) = -(MOD // i) * inv(MOD % i).
inv_mod = [1] * (max_ntt_size + 1)
for i in range(2, max_ntt_size + 1):
    inv_mod[i] = MOD - MOD // i * inv_mod[MOD % i] % MOD

# C[i] = k0 / (i + 1)^2 mod MOD; the final slot C[max_ntt_size] stays 0.
C = [k0 * inv_mod[i + 1] % MOD * inv_mod[i + 1] % MOD for i in range(max_ntt_size)]
C.append(0)

# nttC[d] is the forward NTT of C[0 : 2**d] pre-scaled by 1/2**d, so a
# product with it followed by an inverse NTT needs no further normalisation.
nttC = []
for d in range(max_ntt_size_log + 1):
    length = 1 << d
    scale = pow(length, MOD - 2, MOD)
    table = [C[i] * scale % MOD for i in range(length)]
    ntt(table, NTTg)
    nttC.append(table)



# Scratch state written by sigma_tree and read back by its caller:
# bfsbuf_I lists the vertices visited by the last call in BFS order,
# bfsbuf_dist[v] is v's distance from the start vertex and
# bfsbuf_parent[v] its BFS parent.
bfsbuf_dist = [0] * N
bfsbuf_parent = [0] * N
bfsbuf_I = []

def sigma_tree(s, dep) :
    """Distance-weighted charge sums around s inside one centroid region.

    BFS from s through vertices v with cdep[v] >= dep.  Let dfreq[i] be the
    sum (mod MOD) of Q[v] over visited vertices at distance i from s.  The
    returned list res satisfies

        res[t] == sum_i dfreq[i] * C[i + t]  (mod MOD)   for 0 <= t < len(res)

    and len(res) >= len(dfreq) + 2.  As a side effect bfsbuf_I,
    bfsbuf_dist and bfsbuf_parent describe the BFS.

    If s itself is excluded (cdep[s] < dep), returns [0] and leaves
    bfsbuf_I empty.
    """
    # Only bfsbuf_I is rebound; the other buffers are just subscripted.
    global bfsbuf_I
    if cdep[s] < dep :
        # Nothing is reachable.  Reset the shared list so that a caller
        # iterating bfsbuf_I does not walk the vertices of an earlier call.
        bfsbuf_I = []
        return [0]
    bfsbuf_dist[s] = 0
    bfsbuf_parent[s] = -1
    I = [s]
    dfreq = [0]
    # I grows while it is iterated, which makes this loop a BFS.
    for p in I :
        d = bfsbuf_dist[p]
        if len(dfreq) <= d : dfreq.append(0)
        dfreq[d] = (dfreq[d] + Q[p]) % MOD
        for e in E[p] :
            if bfsbuf_parent[p] == e : continue
            if cdep[e] < dep : continue
            bfsbuf_parent[e] = p
            I.append(e)
            bfsbuf_dist[e] = d + 1
    bfsbuf_I = I
    # Tiny cases: evaluate the correlation directly.
    if len(dfreq) == 1 :
        return [dfreq[0]*C[t]%MOD for t in range(3)]
    if len(dfreq) == 2 :
        return [(dfreq[0]*C[t]+dfreq[1]*C[t+1])%MOD for t in range(4)]
    # General case: a single cyclic convolution of length 2Z.  dfreq[i] is
    # placed at index Z - i, so output index Z + t collects
    # sum_i dfreq[i] * C[t + i]; since t + i < 2Z nothing wraps around.
    Z = 1
    d = 0
    while Z < len(dfreq) + 2 : Z,d = (Z*2,d+1)
    res = [0] * Z*2
    for i in range(len(dfreq)) : res[Z-i] = dfreq[i]
    ntt(res, NTTg)
    # nttC[d+1] is the transform of C[0:2Z], already scaled by 1/(2Z).
    for i in range(Z*2) : res[i] = res[i] * nttC[d+1][i] % MOD
    ntt(res, invNTTg)
    return res[Z:]



ans = [0] * N

# Every vertex c is a centroid of its own region.  For each vertex v in the
# region, add Q-weighted C[dist(c, v) + dist(c, q)] over the region's
# vertices q, minus the q lying in the same child subregion as v.  The
# subtraction works because dist(c, q) = dist(child, q) + 1 there.
for c in range(N) :
    level = cdep[c]
    whole = sigma_tree(c, level)
    ans[c] += whole[0]
    for child in E[c] :
        if cdep[child] > level :
            part = sigma_tree(child, level + 1)
            # bfsbuf_I / bfsbuf_dist now describe the BFS from `child`.
            for v in bfsbuf_I :
                dist = bfsbuf_dist[v] + 1
                ans[v] += whole[dist] - part[dist + 1]

ans = [x % MOD for x in ans]
print("\n".join(str(x) for x in ans))
0