結果
問題 | No.2949 Product on Tree |
ユーザー |
|
提出日時 | 2024-10-26 18:24:04 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 1,315 bytes |
コンパイル時間 | 467 ms |
コンパイル使用メモリ | 82,300 KB |
実行使用メモリ | 349,208 KB |
最終ジャッジ日時 | 2024-10-26 18:25:01 |
合計ジャッジ時間 | 55,691 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 26 WA * 20 |
ソースコード
def dfs(v, p):ret = 0for nv in G[v]:if nv==p:continuedp[nv] = dp[v]*A[nv]%MODret += dfs(nv, v)ret %= MODdp2[v] = (dp[v]+ret)%MODreturn dp2[v]def dfs2(v, p, tmp=1):dp3[v] = (dp2[v]-A[v])%MODfor nv in G[v]:if nv==p:continuepar = dp2[v]cur = dp2[nv]if p==-1:ntmp = tmpelse:ntmp = tmp*invA[p]%MODcur2 = cur * ntmp % MODdp2[nv] = (cur2*invA[v]%MOD + (par-cur2)*A[nv]%MOD )%MODdp2[v] = (par-cur2)*A[nv]%MODdfs2(nv, v, ntmp)dp2[v] = pardp2[nv] = curimport sysimport pypyjitpypyjit.set_param('max_unroll_recursion=-1')sys.setrecursionlimit(10**6)input = sys.stdin.readlineMOD = 998244353N = int(input())A = list(map(int, input().split()))G = [[] for _ in range(N)]for _ in range(N-1):u, v = map(int, input().split())u-=1v-=1G[u].append(v)G[v].append(u)invA = []for a in A:invA.append(pow(a, MOD-2, MOD))dp = [0 for _ in range(N)]dp2 = [0 for _ in range(N)]dp[0] = A[0]dp2[0] = dfs(0, -1)#print(dp)#print(dp2)dp3 = [0 for _ in range(N)]dfs2(0, -1)#print(dp3)ans = 0for d in dp3:ans += dans %= MODans = ans*pow(2, MOD-2, MOD)%MODprint(ans)