結果

問題 No.1796 木上のクーロン
ユーザー 👑 NachiaNachia
提出日時 2021-12-18 16:35:47
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 9,797 ms / 10,000 ms
コード長 3,585 bytes
コンパイル時間 278 ms
コンパイル使用メモリ 81,920 KB
実行使用メモリ 265,564 KB
最終ジャッジ日時 2024-09-19 22:48:56
合計ジャッジ時間 51,179 ms
ジャッジサーバーID (参考情報) judge3 / judge2
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 35 ms 54,288 KB
testcase_01 AC 37 ms 53,848 KB
testcase_02 AC 44 ms 61,464 KB
testcase_03 AC 38 ms 54,152 KB
testcase_04 AC 37 ms 54,704 KB
testcase_05 AC 38 ms 53,836 KB
testcase_06 AC 37 ms 53,796 KB
testcase_07 AC 36 ms 53,512 KB
testcase_08 AC 98 ms 76,720 KB
testcase_09 AC 109 ms 76,604 KB
testcase_10 AC 113 ms 77,124 KB
testcase_11 AC 102 ms 76,600 KB
testcase_12 AC 122 ms 77,460 KB
testcase_13 AC 96 ms 76,532 KB
testcase_14 AC 102 ms 76,824 KB
testcase_15 AC 124 ms 77,560 KB
testcase_16 AC 127 ms 77,384 KB
testcase_17 AC 117 ms 77,012 KB
testcase_18 AC 100 ms 76,752 KB
testcase_19 AC 121 ms 77,360 KB
testcase_20 AC 1,499 ms 107,500 KB
testcase_21 AC 1,584 ms 106,128 KB
testcase_22 AC 3,350 ms 142,132 KB
testcase_23 AC 3,179 ms 138,984 KB
testcase_24 AC 5,158 ms 179,024 KB
testcase_25 AC 4,679 ms 181,336 KB
testcase_26 AC 6,266 ms 214,824 KB
testcase_27 AC 6,527 ms 228,140 KB
testcase_28 AC 9,797 ms 265,564 KB
testcase_29 AC 8,651 ms 251,324 KB
testcase_30 AC 2,901 ms 183,452 KB
testcase_31 AC 2,938 ms 189,244 KB
testcase_32 AC 9,545 ms 255,304 KB
testcase_33 AC 6,204 ms 232,564 KB
testcase_34 AC 5,647 ms 224,388 KB
testcase_35 AC 9,373 ms 258,720 KB
testcase_36 AC 9,120 ms 259,268 KB
権限があれば一括ダウンロードができます

ソースコード

diff #


MOD = 998244353

# Cache of bit-reversal permutations: ntt_perm[k] is the permutation used for
# transforms of length 2**k.  Grown lazily by ntt() and shared across calls.
ntt_perm = [[0]]


def ntt(A, g):
    """Transform the list A in place with an iterative radix-2 NTT modulo MOD.

    A is first zero-padded in place to N, the smallest power of two that is
    >= len(A) (N is at least 1).  Afterwards
        A[k] == sum(a[j] * w**(j*k) for j in range(N)) % MOD,
    where a is the padded input and w = g**((MOD - 1) // N) mod MOD.
    No 1/N normalisation is applied: calling it with the modular inverse of g
    yields the inverse transform multiplied by N.

    Assumes N divides MOD - 1 and g generates the multiplicative group mod
    MOD (as 3 does for 998244353); otherwise the result is not a true DFT.
    Returns None.
    """
    logN = 0
    while (1 << logN) < len(A):
        logN += 1
    N = 1 << logN
    A += [0] * (N - len(A))

    # Extend the permutation cache up to length N, then bit-reverse A.
    while len(ntt_perm) < logN + 1:
        prev = ntt_perm[-1]
        ntt_perm.append([2 * x for x in prev] + [2 * x + 1 for x in prev])
    X = ntt_perm[logN]
    for i in range(N):
        j = X[i]
        if i < j:
            A[i], A[j] = A[j], A[i]

    half = 1
    while half < N:
        # q is a primitive (2*half)-th root of unity.  Its first `half` powers
        # are computed once per stage rather than once per butterfly group.
        q = pow(g, (MOD - 1) // half // 2, MOD)
        tw = [1] * half
        for k in range(1, half):
            tw[k] = tw[k - 1] * q % MOD
        for j in range(0, N, 2 * half):
            for k in range(half):
                lo = j + k
                hi = lo + half
                # One modular product per butterfly (the original computed it twice).
                t = A[hi] * tw[k] % MOD
                u = A[lo]
                A[lo] = (u + t) % MOD
                A[hi] = (u - t) % MOD
        half *= 2

# Problem input: vertex count, one value per vertex, then N - 1 tree edges
# given as 1-indexed vertex pairs.
N = int(input())
Q = list(map(int, input().split()))

# Undirected adjacency lists, converted to 0-indexed vertices.
E = [[] for _ in range(N)]
for _ in range(N - 1):
    src, dst = (int(tok) - 1 for tok in input().split())
    E[src].append(dst)
    E[dst].append(src)

# Per-vertex results written by centroid_decomposition():
#   cdep[v] - depth of v in the centroid tree (0 for the first centroid)
#   cp[v]   - parent centroid of v (stays -1 for the first centroid)
#   cbfs    - centroids in the order centroid_decomposition() finds them
cdep = [0] * N
cp = [-1] * N
# Must start empty: centroid_decomposition() appends to it.  Preallocating
# [0] * N would leave N bogus zeros in front of the real centroid order.
cbfs = []

def centroid_decomposition() :
    """Build the centroid decomposition of the input tree E (vertices 0..N-1).

    Writes module-level state only: cdep[c] becomes the depth of centroid c in
    the centroid tree (the first centroid keeps 0), cp[c] its parent centroid
    (the first centroid's entry is left as initialised), and every centroid is
    appended to cbfs in discovery order.  Discovery is breadth-first over the
    centroid tree, because cdI is consumed as a growing FIFO list.
    """
    # BFS from vertex 0: P[v] is v's parent in that rooting, I the visit order.
    P = [-1] * N
    I = [0]
    for p in I :
        for e in E[p] :
            if P[p] == e : continue
            P[e] = p
            I.append(e)
    # Z[v] = size of v's subtree with the tree rooted at 0 (children before parents).
    Z = [1] * N
    for p in I[-1:0:-1] :
        Z[P[p]] += Z[p]
    # Work list of (start vertex of a component, centroid that split it off).
    # Invariant: the start vertex is the root of its component with respect to Z.
    # So Z[start] is the component size, each live neighbour's Z is the size of
    # that neighbour's side, and already-removed centroids have Z == 0.
    cdI = [(0,-1)]
    for s,par in cdI :
        # Walk towards a neighbour holding more than half of the component.
        # Each step re-roots the sizes at nx: nx takes the whole component size
        # and s keeps only the complement, so the invariant above still holds.
        while True :
            nx = -1
            for e in E[s] :
                if Z[e] * 2 > Z[s] : nx = e
            if nx == -1 : break
            Z[s],Z[nx] = Z[s]-Z[nx],Z[s]
            s = nx
        # No neighbour exceeds half of the component, so s is its centroid.
        cbfs.append(s)
        Z[s] = 0  # mark s as removed; it is never stepped into or re-queued
        if par != -1 : cdep[s],cp[s] = cdep[par]+1,par
        # Each surviving neighbour roots one sub-component left after removing s.
        for e in E[s] :
            if Z[e] != 0 :
                cdI.append((e,s))

centroid_decomposition()


# Generators used for the forward and inverse transforms.
NTTg = 3
invNTTg = pow(NTTg, MOD - 2, MOD)

# Largest transform length: twice the smallest power of two that is >= N + 6.
max_ntt_size_log = (N + 5).bit_length() + 1
max_ntt_size = 1 << max_ntt_size_log

# k0 = (N!)^2 mod MOD, a scale factor folded into every kernel entry.
k0 = 1
for f in range(2, N + 1):
    k0 = k0 * f % MOD
k0 = k0 * k0 % MOD

# inv_mod[v] = v^{-1} mod MOD for 1 <= v <= max_ntt_size, built with the
# linear recurrence inv(v) = -(MOD // v) * inv(MOD % v).
inv_mod = [1] * (max_ntt_size + 1)
for v in range(2, max_ntt_size + 1):
    q, r = divmod(MOD, v)
    inv_mod[v] = MOD - q * inv_mod[r] % MOD

# Kernel C[j] = k0 / (j + 1)^2 mod MOD; the extra final slot stays 0.
C = [k0 * inv_mod[j] % MOD * inv_mod[j] % MOD for j in range(1, max_ntt_size + 1)]
C.append(0)

# nttC[lg] = forward transform of C[:2**lg], pre-multiplied by 2**-lg so the
# unnormalised inverse transform in sigma_tree needs no extra scaling.
nttC = []
for lg in range(max_ntt_size_log + 1):
    width = 1 << lg
    inv_width = pow(width, MOD - 2, MOD)
    kernel = [c * inv_width % MOD for c in C[:width]]
    ntt(kernel, NTTg)
    nttC.append(kernel)



# Scratch buffers shared with the main loop.  After a successful sigma_tree
# call, bfsbuf_I holds the visit order and bfsbuf_dist[v] the distance from
# the start vertex for every visited v.
bfsbuf_dist = [0] * N
bfsbuf_parent = [0] * N
bfsbuf_I = []


def sigma_tree(s, dep):
    """Correlate the per-distance sums of Q around s with the kernel C.

    Explores from s the vertices with cdep >= dep; it never steps onto a vertex
    of smaller centroid depth.  It lets f[i] be the sum of Q over explored
    vertices at distance i from s (mod MOD).  It lets Z be the smallest power of
    two >= (explored count) + 2.  The result r has length Z, with
        r[j] = sum_i f[i] * C[i + j]  (mod MOD),
    where C is the kernel whose scaled transforms are cached in nttC.

    Side effects: overwrites bfsbuf_dist / bfsbuf_parent for the explored
    vertices and rebinds bfsbuf_I to the visit order.  Returns [0] and touches
    nothing when cdep[s] < dep.
    """
    global bfsbuf_I
    if cdep[s] < dep:
        return [0]

    dist, parent = bfsbuf_dist, bfsbuf_parent
    dist[s] = 0
    parent[s] = -1
    order = [s]
    layer_charge = [0]
    for v in order:
        dv = dist[v]
        if dv >= len(layer_charge):
            layer_charge.append(0)
        layer_charge[dv] = (layer_charge[dv] + Q[v]) % MOD
        for w in E[v]:
            if w == parent[v] or cdep[w] < dep:
                continue
            parent[w] = v
            dist[w] = dv + 1
            order.append(w)

    # half >= len(order) + 2 keeps every needed index i + j below 2 * half,
    # so the cyclic convolution never wraps around.
    log_half = (len(order) + 1).bit_length()
    half = 1 << log_half
    buf = [0] * (2 * half)
    for k, charge in enumerate(layer_charge):
        # Reversed placement turns the convolution into a correlation with C.
        buf[half - k] = charge
    ntt(buf, NTTg)
    buf = [x * y % MOD for x, y in zip(buf, nttC[log_half + 1])]
    ntt(buf, invNTTg)
    bfsbuf_I = order
    return buf[half:]



ans = [0] * N

# For each centroid s, add the correlated values seen at distance d from s to
# every vertex of each child component.  The part that the same child component
# contributes to itself (via sigma_tree on the child, shifted by one more
# edge) is subtracted again.
for s in range(N):
    level = cdep[s]
    through_s = sigma_tree(s, level)
    for child in E[s]:
        if cdep[child] > level:
            same_side = sigma_tree(child, level + 1)
            # The call above refreshed bfsbuf_I / bfsbuf_dist for child's component.
            for p in bfsbuf_I:
                dist_s = bfsbuf_dist[p] + 1
                ans[p] += through_s[dist_s] - same_side[dist_s + 1]
    ans[s] += through_s[0]

ans = [x % MOD for x in ans]
print("\n".join(map(str, ans)))
0