結果

問題 No.2949 Product on Tree
ユーザー miya145592
提出日時 2024-10-26 18:19:45
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 1,240 bytes
コンパイル時間 359 ms
コンパイル使用メモリ 82,444 KB
実行使用メモリ 143,636 KB
最終ジャッジ日時 2024-10-26 18:20:02
合計ジャッジ時間 14,011 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other TLE * 5 -- * 41
権限があれば一括ダウンロードができます

ソースコード

diff #

def dfs(v, p):
    ret = 0
    for nv in G[v]:
        if nv==p:
            continue
        dp[nv] = dp[v]*A[nv]%MOD
        ret += dfs(nv, v)
        ret %= MOD
    dp2[v] = (dp[v]+ret)%MOD
    return dp2[v]

def dfs2(v, p, tmp=1):
    dp3[v] = (dp2[v]-A[v])%MOD
    for nv in G[v]:
        if nv==p:
            continue
        par = dp2[v]
        cur = dp2[nv]
        ntmp = tmp*A[p] if p!=-1 else tmp
        cur2 = cur * pow(ntmp, MOD-2, MOD) % MOD
        dp2[nv] = (cur2*pow(A[v], MOD-2, MOD)%MOD + (par-cur2)*A[nv]%MOD )%MOD
        dp2[v] = (par-cur2)*A[nv]%MOD
        dfs2(nv, v, ntmp)
        dp2[v] = par
        dp2[nv] = cur

import sys
import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
MOD = 998244353
N = int(input())
A = list(map(int, input().split()))
G = [[] for _ in range(N)]
for _ in range(N-1):
    u, v = map(int, input().split())
    u-=1
    v-=1
    G[u].append(v)
    G[v].append(u)
dp = [0 for _ in range(N)]
dp2 = [0 for _ in range(N)]
dp[0] = A[0]
dp2[0] = dfs(0, -1)
#print(dp)
#print(dp2)

dp3 = [0 for _ in range(N)]
dfs2(0, -1)
#print(dp3)
ans = 0
for d in dp3:
    ans += d
    ans %= MOD
ans = ans*pow(2, MOD-2, MOD)%MOD
print(ans)
0