結果

問題 No.2949 Product on Tree
ユーザー miya145592miya145592
提出日時 2024-10-26 18:24:04
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 1,315 bytes
コンパイル時間 467 ms
コンパイル使用メモリ 82,300 KB
実行使用メモリ 349,208 KB
最終ジャッジ日時 2024-10-26 18:25:01
合計ジャッジ時間 55,691 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 34 ms
53,384 KB
testcase_01 AC 35 ms
53,132 KB
testcase_02 AC 34 ms
53,012 KB
testcase_03 AC 812 ms
126,760 KB
testcase_04 AC 750 ms
125,220 KB
testcase_05 AC 793 ms
126,708 KB
testcase_06 AC 790 ms
126,728 KB
testcase_07 AC 751 ms
125,388 KB
testcase_08 AC 829 ms
127,108 KB
testcase_09 AC 818 ms
127,744 KB
testcase_10 AC 783 ms
130,612 KB
testcase_11 AC 1,415 ms
150,012 KB
testcase_12 AC 865 ms
162,032 KB
testcase_13 AC 987 ms
205,616 KB
testcase_14 AC 1,583 ms
287,500 KB
testcase_15 AC 1,806 ms
288,308 KB
testcase_16 AC 1,857 ms
284,920 KB
testcase_17 AC 1,207 ms
290,328 KB
testcase_18 AC 1,732 ms
289,804 KB
testcase_19 AC 1,313 ms
259,664 KB
testcase_20 AC 1,756 ms
284,140 KB
testcase_21 AC 1,770 ms
279,236 KB
testcase_22 AC 1,693 ms
292,780 KB
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 WA -
testcase_30 WA -
testcase_31 WA -
testcase_32 WA -
testcase_33 WA -
testcase_34 WA -
testcase_35 WA -
testcase_36 WA -
testcase_37 WA -
testcase_38 WA -
testcase_39 WA -
testcase_40 WA -
testcase_41 WA -
testcase_42 WA -
testcase_43 AC 343 ms
111,048 KB
testcase_44 AC 333 ms
111,360 KB
testcase_45 AC 414 ms
131,612 KB
testcase_46 AC 383 ms
116,212 KB
testcase_47 AC 284 ms
105,976 KB
testcase_48 AC 377 ms
116,696 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

def dfs(v, p):
    ret = 0
    for nv in G[v]:
        if nv==p:
            continue
        dp[nv] = dp[v]*A[nv]%MOD
        ret += dfs(nv, v)
        ret %= MOD
    dp2[v] = (dp[v]+ret)%MOD
    return dp2[v]

def dfs2(v, p, tmp=1):
    dp3[v] = (dp2[v]-A[v])%MOD
    for nv in G[v]:
        if nv==p:
            continue
        par = dp2[v]
        cur = dp2[nv]
        if p==-1:
            ntmp = tmp
        else:
            ntmp = tmp*invA[p]%MOD
        cur2 = cur * ntmp % MOD
        dp2[nv] = (cur2*invA[v]%MOD + (par-cur2)*A[nv]%MOD )%MOD
        dp2[v] = (par-cur2)*A[nv]%MOD
        dfs2(nv, v, ntmp)
        dp2[v] = par
        dp2[nv] = cur

import sys
import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
MOD = 998244353
N = int(input())
A = list(map(int, input().split()))
G = [[] for _ in range(N)]
for _ in range(N-1):
    u, v = map(int, input().split())
    u-=1
    v-=1
    G[u].append(v)
    G[v].append(u)
invA = []
for a in A:
    invA.append(pow(a, MOD-2, MOD))
dp = [0 for _ in range(N)]
dp2 = [0 for _ in range(N)]
dp[0] = A[0]
dp2[0] = dfs(0, -1)
#print(dp)
#print(dp2)

dp3 = [0 for _ in range(N)]
dfs2(0, -1)
#print(dp3)
ans = 0
for d in dp3:
    ans += d
    ans %= MOD
ans = ans*pow(2, MOD-2, MOD)%MOD
print(ans)
0