結果

問題 No.1796 木上のクーロン
ユーザー 👑 NachiaNachia
提出日時 2021-12-18 16:35:47
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,585 bytes
コンパイル時間 184 ms
コンパイル使用メモリ 81,848 KB
実行使用メモリ 264,224 KB
最終ジャッジ日時 2023-10-20 03:02:42
合計ジャッジ時間 52,792 ms
ジャッジサーバーID
(参考情報)
judge11 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 44 ms
53,648 KB
testcase_01 AC 40 ms
53,448 KB
testcase_02 AC 51 ms
61,288 KB
testcase_03 AC 40 ms
53,448 KB
testcase_04 AC 44 ms
53,448 KB
testcase_05 AC 42 ms
53,448 KB
testcase_06 AC 41 ms
53,448 KB
testcase_07 AC 40 ms
53,448 KB
testcase_08 AC 109 ms
75,896 KB
testcase_09 AC 118 ms
76,068 KB
testcase_10 AC 122 ms
76,624 KB
testcase_11 AC 114 ms
76,076 KB
testcase_12 AC 140 ms
77,048 KB
testcase_13 AC 111 ms
76,000 KB
testcase_14 AC 115 ms
76,044 KB
testcase_15 AC 135 ms
76,924 KB
testcase_16 AC 143 ms
76,992 KB
testcase_17 AC 133 ms
76,656 KB
testcase_18 AC 110 ms
76,068 KB
testcase_19 AC 137 ms
76,916 KB
testcase_20 AC 1,754 ms
106,868 KB
testcase_21 AC 1,834 ms
105,792 KB
testcase_22 AC 3,830 ms
141,396 KB
testcase_23 AC 3,659 ms
139,372 KB
testcase_24 AC 5,798 ms
178,592 KB
testcase_25 AC 5,423 ms
180,776 KB
testcase_26 AC 7,180 ms
214,636 KB
testcase_27 AC 7,469 ms
227,680 KB
testcase_28 TLE -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
testcase_34 -- -
testcase_35 -- -
testcase_36 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #


MOD = 998244353
# ntt_perm[k] is the bit-reversal permutation of range(1 << k), grown lazily.
ntt_perm = [[0]]

def ntt(A, g):
    """In-place iterative number-theoretic transform modulo MOD.

    A is zero-padded in place to the next power of two N (length 1 for an
    empty list). It is then replaced by
        A[k] = sum_j A[j] * w**(j*k) mod MOD,   w = g**((MOD - 1) // N),
    in natural order. For N >= 2 every entry ends up reduced into [0, MOD).

    No 1/N factor is applied. Running the transform again with the modular
    inverse of g, then dividing by N, inverts it. Callers are assumed to pass
    a primitive root of MOD (or its inverse), with N dividing MOD - 1.
    Returns None.
    """
    logN = 0
    while (1 << logN) < len(A):
        logN += 1
    N = 1 << logN
    A += [0] * (N - len(A))
    while len(ntt_perm) < logN + 1:
        prev = ntt_perm[-1]
        ntt_perm.append([2 * x for x in prev] + [2 * x + 1 for x in prev])
    # Bit-reversal reordering lets the butterflies below run in place.
    X = ntt_perm[logN]
    for i in range(N):
        j = X[i]
        if i < j:
            A[i], A[j] = A[j], A[i]
    half = 1
    while half < N:
        step = half * 2
        # Twiddles q**t for t < half, with q = g**((MOD-1)/step). They are
        # built once per level instead of being re-advanced in every group.
        q = pow(g, (MOD - 1) // step, MOD)
        ws = [1] * half
        for t in range(1, half):
            ws[t] = ws[t - 1] * q % MOD
        for j in range(0, N, step):
            for t in range(half):
                k = j + t
                u = A[k]
                # The twiddled value is needed by both outputs; compute it once.
                v = A[k + half] * ws[t] % MOD
                A[k] = (u + v) % MOD
                A[k + half] = (u - v) % MOD
        half = step

# Input: vertex count, one weight per vertex (the charge, presumably),
# then N - 1 edges given as 1-based endpoint pairs.
N = int(input())
Q = [int(tok) for tok in input().split()]
E = [[] for _ in range(N)]

for _ in range(N - 1):
    a, b = (int(tok) - 1 for tok in input().split())
    E[a].append(b)
    E[b].append(a)

# Centroid-tree bookkeeping written by centroid_decomposition():
# depth in the centroid tree, parent centroid, and the centroids it finds.
cdep = [0] * N
cp = [-1] * N
cbfs = [0] * N

def centroid_decomposition():
    """Compute the centroid decomposition of the tree stored in E.

    Writes into module-level arrays (one entry per vertex 0..N-1):
      cdep[v] -- depth of centroid v in the centroid tree. The first
                 centroid keeps the array's initial value.
      cp[v]   -- parent of centroid v in the centroid tree. The first
                 centroid keeps the array's initial value.
      cbfs[i] -- the i-th centroid found. Components are processed FIFO,
                 so this is a BFS order of the centroid tree.
    """
    # BFS from vertex 0 to fix a rooting of the tree.
    P = [-1] * N
    I = [0]
    for p in I:
        for e in E[p]:
            if P[p] == e:
                continue
            P[e] = p
            I.append(e)
    # Z[v] = subtree size of v under that rooting. Vertices are accumulated
    # in reverse BFS order; the root is skipped because it has no parent.
    Z = [1] * N
    for p in I[-1:0:-1]:
        Z[P[p]] += Z[p]
    # Queue of (entry vertex of a component, centroid that split it off).
    # Removed centroids get Z = 0, so they are never crossed again.
    cdI = [(0, -1)]
    for idx, (s, par) in enumerate(cdI):
        # Walk towards the neighbour holding more than half of the current
        # component. Swapping the sizes on each step re-roots the component
        # at the new vertex: Z[s] stays the component size and Z[prev]
        # becomes the size of the piece left behind.
        while True:
            nx = -1
            for e in E[s]:
                if Z[e] * 2 > Z[s]:
                    nx = e
            if nx == -1:
                break
            Z[s], Z[nx] = Z[s] - Z[nx], Z[s]
            s = nx
        # cbfs is pre-sized to N entries. Each queue entry yields exactly
        # one centroid, so store by index; appending would leave N zeros
        # in front.
        cbfs[idx] = s
        Z[s] = 0
        if par != -1:
            cdep[s], cp[s] = cdep[par] + 1, par
        for e in E[s]:
            if Z[e] != 0:
                cdI.append((e, s))

centroid_decomposition()


# Forward / inverse transform bases: NTTg = 3 and its modular inverse.
NTTg = 3
invNTTg = pow(NTTg, MOD - 2, MOD)

# One doubling above the smallest power of two >= N + 6. This covers every
# length-2*Z transform sigma_tree asks for (Z is a power of two whose
# requested minimum never exceeds N + 2).
max_ntt_size_log = 0
while (1 << max_ntt_size_log) < N + 6:
    max_ntt_size_log += 1
max_ntt_size_log += 1
max_ntt_size = 1 << max_ntt_size_log

# k0 = (N!)^2 mod MOD: the common numerator of every kernel coefficient.
k0 = 1
for i in range(1, N + 1):
    k0 = k0 * i % MOD
k0 = k0 * k0 % MOD

# inv_mod[i] = i^{-1} mod MOD for 1 <= i <= max_ntt_size, via the linear
# recurrence inv(i) = -(MOD // i) * inv(MOD % i).
inv_mod = [1] * (max_ntt_size + 1)
for i in range(2, max_ntt_size + 1):
    inv_mod[i] = MOD - MOD // i * inv_mod[MOD % i] % MOD

# Kernel C[i] = k0 / (i + 1)^2 mod MOD for i < max_ntt_size; the final
# slot stays 0.
C = [k0 * inv_mod[j + 1] % MOD * inv_mod[j + 1] % MOD for j in range(max_ntt_size)]
C.append(0)

# nttC[d]: forward transform of C[:2**d], pre-scaled by 1 / 2**d so the
# inverse transform needs no separate normalisation.
nttC = []
for d in range(max_ntt_size_log + 1):
    inv_ntt_size = pow(1 << d, MOD - 2, MOD)
    row = [c * inv_ntt_size % MOD for c in C[:1 << d]]
    ntt(row, NTTg)
    nttC.append(row)



# Scratch buffers shared by every sigma_tree call (indexed by vertex).
bfsbuf_dist = [0] * N
bfsbuf_parent = [0] * N
bfsbuf_I = []
def sigma_tree(s, dep):
    """Evaluate kernel sums over one centroid-decomposition piece, seen from s.

    BFS from s over vertices v with cdep[v] >= dep, then return a list r with
        r[j] = sum over visited v of Q[v] * C[dist(s, v) + j]   (mod MOD)
    for 0 <= j < len(r). C is the kernel whose scaled transforms are in
    nttC. len(r) is at least (max BFS depth) + 3.

    Side effects: bfsbuf_dist[v] holds dist(s, v) and bfsbuf_parent[v] holds
    the BFS parent of every visited v. bfsbuf_I is set to the visit order,
    or to [] when s itself is excluded, in which case [0] is returned.
    """
    global bfsbuf_I
    if cdep[s] < dep:
        # Do not leave a previous call's vertex list behind for the caller.
        bfsbuf_I = []
        return [0]
    bfsbuf_dist[s] = 0
    bfsbuf_parent[s] = -1
    I = [s]
    dfreq = [0]  # dfreq[d] = sum of Q over visited vertices at distance d
    for p in I:
        d = bfsbuf_dist[p]
        if len(dfreq) <= d:
            dfreq.append(0)
        dfreq[d] = (dfreq[d] + Q[p]) % MOD
        for e in E[p]:
            if bfsbuf_parent[p] == e or cdep[e] < dep:
                continue
            bfsbuf_parent[e] = p
            bfsbuf_dist[e] = d + 1
            I.append(e)
    # Only the depth profile is convolved. Z >= len(dfreq) + 2 keeps
    # i + j < 2Z (no cyclic wraparound) for every j < Z, and every index
    # the caller reads is below len(dfreq) + 2. Sizing by the vertex count
    # len(I) made transforms much larger than needed on shallow, wide pieces.
    Z = 1
    logZ = 0
    while Z < len(dfreq) + 2:
        Z *= 2
        logZ += 1
    # Place dfreq reversed around index Z so that output index Z + j
    # collects sum_i dfreq[i] * C[i + j].
    res = [0] * (2 * Z)
    for i, f in enumerate(dfreq):
        res[Z - i] = f
    ntt(res, NTTg)
    kern = nttC[logZ + 1]  # length 2Z, already scaled by 1 / (2Z)
    res = [x * y % MOD for x, y in zip(res, kern)]
    ntt(res, invNTTg)
    bfsbuf_I = I
    return res[Z:]



ans = [0] * N

# Every centroid c contributes to each vertex of its component.
# sigma_tree(c) counts all sources in c's component. For a vertex in the
# branch behind neighbour nb, sources in that same branch are subtracted
# again using sigma_tree(nb); those pairs are handled at deeper centroids.
for c in range(N):
    level = cdep[c]
    around_c = sigma_tree(c, level)
    ans[c] += around_c[0]
    for nb in E[c]:
        if cdep[nb] <= level:
            continue  # nb lies outside c's component
        around_nb = sigma_tree(nb, level + 1)
        for v in bfsbuf_I:
            dist_c = bfsbuf_dist[v] + 1  # dist(c, v): one edge more than dist(nb, v)
            ans[v] += around_c[dist_c] - around_nb[dist_c + 1]

ans = [x % MOD for x in ans]
print("\n".join(map(str, ans)))