import sys sys.setrecursionlimit(2*10**5) def fast(N,A,G): mod=998244353 def dfs(p,prev=-1): sm,sm2,ans=0,0,0 for e in G[p]: if e==prev:continue esm,eans=dfs(e,p) sm=(sm+esm)%mod sm2=(sm2+esm**2)%mod ans=(ans+eans)%mod ans=(ans+A[p]*(sm**2-sm2)*pow(2,mod-2,mod)+A[p]*sm)%mod return ((sm+1)*A[p])%mod,ans return dfs(0)[1] def naive(N,A,G): mod=998244353 def dfs(p,prod,par,prev=-1)->int: ans=0 prod=(prod*A[p])%mod if p