結果

| 項目 | 内容 |
|---|---|
| 問題 | No.2949 Product on Tree |
| コンテスト | |
| ユーザー | ATM |
| 提出日時 | 2024-10-25 22:35:08 |
| 言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 385 ms / 2,000 ms |
| コード長 | 5,403 bytes |
| コンパイル時間 | 1,921 ms |
| コンパイル使用メモリ | 178,820 KB |
| 実行使用メモリ | 36,372 KB |
| 最終ジャッジ日時 | 2024-10-25 22:35:30 |
| 合計ジャッジ時間 | 19,720 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge1 |

テストケース (要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 46 |
ソースコード
#include <bits/stdc++.h>
using namespace std;
// Preprocessor helpers. CPP_STR / CPP_CAT macro-expand their arguments first
// and then stringize / token-paste them via the raw *_I variants.
#define CPP_STR(x) CPP_STR_I(x)
#define CPP_CAT(x, y) CPP_CAT_I(x, y)
#define CPP_STR_I(args...) #args
#define CPP_CAT_I(x, y) x##y
// Extra parentheses keep an expression with top-level commas (e.g. template
// arguments) as a single argument to assert. Uses the GNU named-variadic form.
#define ASSERT(expr...) assert((expr))
// Fixed-width integer and floating-point aliases.
using i8 = int8_t;
using u8 = uint8_t;
using i16 = int16_t;
using u16 = uint16_t;
using i32 = int32_t;
using u32 = uint32_t;
using i64 = int64_t;
using u64 = uint64_t;
using f32 = float;
using f64 = double;
// }}}
constexpr i64 INF = 1'010'000'000'000'000'017LL;
constexpr i64 MOD = 998244353LL;
constexpr f64 EPS = 1e-12;
constexpr f64 PI = 3.14159265358979323846;
#define M5 100007
#define M9 1000000000
// Shorthand for pair members.
// NOTE(review): these object-like macros rewrite every later identifier named
// F or S (including template parameters in code below them).
#define F first
#define S second
// util {{{
// Half-open loop: i runs over [start, end); `end` is evaluated once into a
// hidden local named <i>xxxx_end.
#define FOR(i, start, end) for (i64 i = (start), CPP_CAT(i, xxxx_end) = (end); i < CPP_CAT(i, xxxx_end); ++i)
#define REP(i, n) FOR(i, 0, n)
#define all(x) (x).begin(), (x).end()
#define ll long long int
#define VI vector<ll>
#define VVI vector<VI>
// Compile-time switch for debug() output.
#define ISD true
// Prints "<expr text>: <value>". The do { } while (0) wrapper makes the macro a
// single statement, so a caller's `else` is not captured by the inner `if`
// (dangling-else). The argument is parenthesized so low-precedence expressions
// are evaluated before being streamed.
#define debug(x)                                         \
    do                                                   \
    {                                                    \
        if (ISD)                                         \
            std::cout << #x << ": " << (x) << std::endl; \
    } while (0)
// Overwrites `xmax` with `x` when `comp(xmax, x)` holds, i.e. when x is strictly
// greater under the default less<>. Returns true exactly when `xmax` changed.
template <typename T, typename U, typename Comp = less<>>
bool chmax(T &xmax, const U &x, Comp comp = {})
{
    const bool improves = comp(xmax, x);
    if (improves)
        xmax = x;
    return improves;
}
// Overwrites `xmin` with `x` when `comp(x, xmin)` holds, i.e. when x is strictly
// smaller under the default less<>. Returns true exactly when `xmin` changed.
template <typename T, typename U, typename Comp = less<>>
bool chmin(T &xmin, const U &x, Comp comp = {})
{
    if (!comp(x, xmin))
        return false; // not strictly better: leave xmin untouched
    xmin = x;
    return true;
}
// Integer modulo the compile-time constant `mod`, stored canonically in [0, mod).
//
// Construction and assignment from long long accept any value (negative or
// >= mod) and reduce it. The constructor is intentionally implicit so plain
// integers can be mixed into arithmetic (e.g. `x * 2`, `3 - x`).
//
// Limits: operator* forms the product of two residues in a long long, so
// (mod - 1)^2 must fit in a signed 64-bit integer (true for 998244353).
// Division and inv() are only meaningful when the divisor is coprime to `mod`.
template <long long mod>
struct modint
{
    modint(long long v = 0) : value(normalize(v)) {}

    // The canonical representative in [0, mod).
    long long val() const { return value; }

    // Re-reduce the stored value in place.
    void normalize() { value = normalize(value); }

    // Reduce an arbitrary v into [0, mod). Cheap add/subtract paths cover values
    // within two moduli of the range; anything further falls back to %.
    // Note that v == mod must map to 0 (it is not a canonical value).
    static long long normalize(long long v)
    {
        if (0 <= v && v < mod)
            return v;
        if (-mod <= v && v < 0)
            return v + mod;
        if (mod <= v && v < 2 * mod)
            return v - mod;
        if (-2 * mod < v && v < -mod)
            return v + 2 * mod;
        v %= mod;
        if (v < 0)
            v += mod;
        return v;
    }

    modint<mod> &operator=(long long v)
    {
        value = normalize(v);
        return *this;
    }

    bool operator==(const modint &o) const { return value == o.val(); }
    bool operator!=(const modint &o) const { return value != o.val(); }

    const modint &operator+() const { return *this; }
    // Returned by value: returning a reference here would dangle, because the
    // result is a freshly built temporary.
    modint operator-() const { return value ? mod - value : 0; }

    modint operator+(const modint &o) const { return value + o.val(); }
    modint &operator+=(const modint &o)
    {
        value += o.val();
        if (value >= mod)
            value -= mod;
        return *this;
    }

    modint operator-(const modint &o) const { return value - o.val(); }
    modint &operator-=(const modint &o)
    {
        value -= o.val();
        if (value < 0)
            value += mod;
        return *this;
    }

    modint operator*(const modint &o) const { return (value * o.val()) % mod; }
    modint &operator*=(const modint &o)
    {
        value *= o.val();
        value %= mod;
        return *this;
    }

    modint operator/(const modint &o) const { return operator*(o.inv()); }
    modint &operator/=(const modint &o) { return operator*=(o.inv()); }

    // Exponentiation by squaring. A negative exponent raises the inverse to |n|
    // (only meaningful when the value is invertible modulo mod).
    modint pow(long long n) const
    {
        modint base = n < 0 ? inv() : *this;
        // Unsigned magnitude, so negating LLONG_MIN cannot overflow.
        unsigned long long e = n < 0 ? 0ULL - static_cast<unsigned long long>(n)
                                     : static_cast<unsigned long long>(n);
        modint ans = 1;
        while (e > 0)
        {
            if (e & 1)
                ans *= base;
            base *= base;
            e >>= 1;
        }
        return ans;
    }

    // Modular inverse via the extended Euclidean algorithm. Requires
    // gcd(value, mod) == 1; for value == 0 it returns 0 (no inverse exists).
    modint inv() const
    {
        long long a = value, b = mod, u = 1, v = 0;
        while (b)
        {
            long long t = a / b;
            a -= t * b;
            std::swap(a, b);
            u -= t * v;
            std::swap(u, v);
        }
        return u; // may be negative; the constructor reduces it
    }

    friend std::ostream &operator<<(std::ostream &os, const modint &x)
    {
        return os << x.val();
    }

    // Mixed arithmetic with a plain scalar on the left-hand side.
    template <typename T>
    friend modint operator+(T t, const modint &o)
    {
        return o + t;
    }
    template <typename T>
    friend modint operator-(T t, const modint &o)
    {
        return -o + t;
    }
    template <typename T>
    friend modint operator*(T t, const modint &o)
    {
        return o * t;
    }
    template <typename T>
    friend modint operator/(T t, const modint &o)
    {
        return o.inv() * t;
    }

private:
    long long value; // invariant: 0 <= value < mod
};
using modint998244353 = modint<998244353>;
int main()
{
int N;
cin >> N;
vector<ll> A(N);
REP(i, N)
cin >> A[i];
vector<vector<int>> G(N);
REP(i, N - 1)
{
int u, v;
cin >> u >> v;
u--;
v--;
G[u].push_back(v);
G[v].push_back(u);
}
using mint = modint998244353;
vector<mint> dp(N);
mint ans = 0;
function<mint(ll, ll)> dfs = [&](ll v, ll p)
{
vector<mint> s;
mint sum = 0;
for (auto g : G[v])
{
if (g == p)
{
continue;
}
mint s0 = dfs(g, v);
s.push_back(s0);
sum += s0;
}
mint sum0 = 0;
dp[v] += A[v];
for (auto s0 : s)
{
dp[v] += s0 * A[v];
sum0 += s0 * (sum - s0) * A[v];
}
ans += sum0 / 2 + dp[v] - A[v];
// cout << v << " " << ans.val() << " " << dp[v] << endl;
return dp[v];
};
dfs(0, -1);
cout << ans.val() << endl;
}
ATM