結果

問題 No.2949 Product on Tree
ユーザー nouka28
提出日時 2024-06-21 15:19:08
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,996 ms / 2,000 ms
コード長 1,001 bytes
コンパイル時間 519 ms
コンパイル使用メモリ 82,544 KB
実行使用メモリ 401,676 KB
最終ジャッジ日時 2024-09-23 07:20:53
合計ジャッジ時間 51,190 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 46
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
sys.setrecursionlimit(2*10**5)

def fast(N,A,G):
    mod=998244353

    def dfs(p,prev=-1):
        sm,sm2,ans=0,0,0
        for e in G[p]:
            if e==prev:continue
            esm,eans=dfs(e,p)
            sm=(sm+esm)%mod
            sm2=(sm2+esm**2)%mod
            ans=(ans+eans)%mod
        ans=(ans+A[p]*(sm**2-sm2)*pow(2,mod-2,mod)+A[p]*sm)%mod
        return ((sm+1)*A[p])%mod,ans

    return dfs(0)[1]

def naive(N,A,G):
    
    mod=998244353
    def dfs(p,prod,par,prev=-1)->int:
        ans=0
        prod=(prod*A[p])%mod
        if p<par:
            ans=(ans+prod)%mod
        for e in G[p]:
            if e==prev:continue
            ans=(ans+dfs(e,prod,par,p))%mod
        return ans
    ans=0
    for i in range(N):
        ans=(ans+dfs(i,1,i))%mod
    return ans

N=int(input())
A=list(map(int,input().split()))
G=[[]for i in range(N)]
for i in range(N-1):
    u,v=map(int,input().split())
    u-=1
    v-=1
    G[u].append(v)
    G[v].append(u)

print(fast(N,A,G))
0