結果
問題 | No.2949 Product on Tree |
ユーザー | liveworldlike |
提出日時 | 2024-10-25 23:02:15 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 1,474 ms / 2,000 ms |
コード長 | 2,094 bytes |
コンパイル時間 | 1,477 ms |
コンパイル使用メモリ | 89,604 KB |
実行使用メモリ | 107,264 KB |
最終ジャッジ日時 | 2024-10-25 23:03:28 |
合計ジャッジ時間 | 55,911 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 46 |
ソースコード
#pragma GCC optimize("O3") #include <map> #include <stdio.h> #include <stdlib.h> #include <vector> using namespace std; typedef int64_t ll; const ll MOD = 1000000007; const ll MOD2 = 998244353; const ll mod = MOD2; constexpr ll fastpow(ll x, ll p) { ll res = 1; while (p) { if (p & 1) { res *= x; res %= mod; } x *= x; x %= mod; p >>= 1; } return res; } const ll inv2 = fastpow(2, mod - 2); void solve() { ll n; scanf("%ld", &n); ll *a = (ll *)malloc(n * sizeof(ll)); for (ll i = 0; i < n; i++) { scanf("%ld", &a[i]); } ll *u = (ll *)malloc(n * sizeof(ll)); ll *v = (ll *)malloc(n * sizeof(ll)); for (ll i = 0; i < n - 1; i++) { scanf("%ld %ld", &u[i], &v[i]); u[i]--; v[i]--; } map<ll, vector<ll>> N; for (ll i = 0; i < n - 1; i++) { N[u[i]].push_back(v[i]); N[v[i]].push_back(u[i]); } map<ll, vector<ll>> C; auto child = [&](ll p, ll u, auto self) -> void { C[p].push_back(u); for (ll v : N[u]) { if (v == p) continue; self(u, v, self); } }; child(-1, 0, child); vector<ll> F(n, -1); auto f = [&](ll u, auto self) -> ll { if (F[u] == -1) { F[u] = a[u]; F[u] %= mod; for (ll v : C[u]) { F[u] += a[u] * self(v, self); F[u] %= mod; } } return F[u]; }; vector<ll> G(n, -1); auto g = [&](ll u) -> ll { if (G[u] == -1) { ll S = 0, s = 0; for (ll v : C[u]) { ll value = f(v, f); S += value; S %= mod; s += value * value; s %= mod; } S *= S; S %= mod; G[u] = S; G[u] += mod - s; G[u] %= mod; G[u] *= a[u]; G[u] %= mod; G[u] *= inv2; G[u] %= mod; } return G[u]; }; ll res = 0; auto h = [&](ll u, auto self) -> void { res += g(u); res %= mod; res += f(u, f); res %= mod; for (ll v : C[u]) { self(v, self); } }; h(0, h); for (ll i = 0; i < n; i++) { res += mod - a[i]; res %= mod; } printf("%ld\n", res); } int main() { solve(); }