結果

問題 No.2949 Product on Tree
ユーザー miya145592miya145592
提出日時 2024-10-26 18:19:45
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 1,240 bytes
コンパイル時間 359 ms
コンパイル使用メモリ 82,444 KB
実行使用メモリ 143,636 KB
最終ジャッジ日時 2024-10-26 18:20:02
合計ジャッジ時間 14,011 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 35 ms
61,356 KB
testcase_01 AC 36 ms
53,020 KB
testcase_02 AC 36 ms
53,760 KB
testcase_03 TLE -
testcase_04 TLE -
testcase_05 TLE -
testcase_06 TLE -
testcase_07 TLE -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
testcase_34 -- -
testcase_35 -- -
testcase_36 -- -
testcase_37 -- -
testcase_38 -- -
testcase_39 -- -
testcase_40 -- -
testcase_41 -- -
testcase_42 -- -
testcase_43 -- -
testcase_44 -- -
testcase_45 -- -
testcase_46 -- -
testcase_47 -- -
testcase_48 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

def dfs(v, p):
    ret = 0
    for nv in G[v]:
        if nv==p:
            continue
        dp[nv] = dp[v]*A[nv]%MOD
        ret += dfs(nv, v)
        ret %= MOD
    dp2[v] = (dp[v]+ret)%MOD
    return dp2[v]

def dfs2(v, p, tmp=1):
    dp3[v] = (dp2[v]-A[v])%MOD
    for nv in G[v]:
        if nv==p:
            continue
        par = dp2[v]
        cur = dp2[nv]
        ntmp = tmp*A[p] if p!=-1 else tmp
        cur2 = cur * pow(ntmp, MOD-2, MOD) % MOD
        dp2[nv] = (cur2*pow(A[v], MOD-2, MOD)%MOD + (par-cur2)*A[nv]%MOD )%MOD
        dp2[v] = (par-cur2)*A[nv]%MOD
        dfs2(nv, v, ntmp)
        dp2[v] = par
        dp2[nv] = cur

import sys
import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
MOD = 998244353
N = int(input())
A = list(map(int, input().split()))
G = [[] for _ in range(N)]
for _ in range(N-1):
    u, v = map(int, input().split())
    u-=1
    v-=1
    G[u].append(v)
    G[v].append(u)
dp = [0 for _ in range(N)]
dp2 = [0 for _ in range(N)]
dp[0] = A[0]
dp2[0] = dfs(0, -1)
#print(dp)
#print(dp2)

dp3 = [0 for _ in range(N)]
dfs2(0, -1)
#print(dp3)
ans = 0
for d in dp3:
    ans += d
    ans %= MOD
ans = ans*pow(2, MOD-2, MOD)%MOD
print(ans)
0