結果

問題 No.2949 Product on Tree
ユーザー ATMATM
提出日時 2024-10-25 22:35:08
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 385 ms / 2,000 ms
コード長 5,403 bytes
コンパイル時間 1,921 ms
コンパイル使用メモリ 178,820 KB
実行使用メモリ 36,372 KB
最終ジャッジ日時 2024-10-25 22:35:30
合計ジャッジ時間 19,720 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
6,816 KB
testcase_01 AC 2 ms
6,816 KB
testcase_02 AC 2 ms
6,820 KB
testcase_03 AC 347 ms
17,120 KB
testcase_04 AC 318 ms
16,888 KB
testcase_05 AC 328 ms
17,276 KB
testcase_06 AC 318 ms
17,240 KB
testcase_07 AC 328 ms
17,004 KB
testcase_08 AC 337 ms
17,656 KB
testcase_09 AC 330 ms
17,888 KB
testcase_10 AC 353 ms
18,520 KB
testcase_11 AC 322 ms
19,608 KB
testcase_12 AC 333 ms
22,752 KB
testcase_13 AC 319 ms
23,288 KB
testcase_14 AC 343 ms
25,524 KB
testcase_15 AC 334 ms
28,584 KB
testcase_16 AC 341 ms
25,944 KB
testcase_17 AC 366 ms
26,920 KB
testcase_18 AC 345 ms
26,764 KB
testcase_19 AC 342 ms
29,540 KB
testcase_20 AC 342 ms
28,776 KB
testcase_21 AC 370 ms
29,148 KB
testcase_22 AC 338 ms
30,816 KB
testcase_23 AC 326 ms
17,536 KB
testcase_24 AC 359 ms
17,552 KB
testcase_25 AC 326 ms
17,432 KB
testcase_26 AC 328 ms
17,396 KB
testcase_27 AC 326 ms
17,696 KB
testcase_28 AC 356 ms
17,600 KB
testcase_29 AC 339 ms
18,016 KB
testcase_30 AC 329 ms
18,816 KB
testcase_31 AC 362 ms
20,272 KB
testcase_32 AC 342 ms
20,940 KB
testcase_33 AC 342 ms
25,740 KB
testcase_34 AC 351 ms
31,292 KB
testcase_35 AC 370 ms
27,328 KB
testcase_36 AC 356 ms
33,284 KB
testcase_37 AC 361 ms
33,728 KB
testcase_38 AC 385 ms
29,972 KB
testcase_39 AC 351 ms
30,528 KB
testcase_40 AC 352 ms
35,868 KB
testcase_41 AC 370 ms
32,500 KB
testcase_42 AC 347 ms
36,372 KB
testcase_43 AC 152 ms
16,544 KB
testcase_44 AC 154 ms
16,648 KB
testcase_45 AC 198 ms
19,952 KB
testcase_46 AC 177 ms
18,436 KB
testcase_47 AC 157 ms
14,048 KB
testcase_48 AC 185 ms
18,848 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

// Two-level stringize / token-paste helpers. The outer macro (CPP_STR,
// CPP_CAT) lets its arguments be macro-expanded first; the inner *_I macro
// then applies # or ## to the already-expanded tokens.
#define CPP_STR(x) CPP_STR_I(x)
#define CPP_CAT(x, y) CPP_CAT_I(x, y)
// `args...` is the GNU named-variadic form (equivalent to __VA_ARGS__).
#define CPP_STR_I(args...) #args
#define CPP_CAT_I(x, y) x##y

// assert() wrapper: the variadic parameter plus the extra parentheses let an
// expression containing top-level commas (e.g. template arguments) be passed
// to assert as a single argument.
#define ASSERT(expr...) assert((expr))

// Fixed-width integer shorthands.
using i8 = int8_t;
using u8 = uint8_t;
using i16 = int16_t;
using u16 = uint16_t;
using i32 = int32_t;
using u32 = uint32_t;
using i64 = int64_t;
using u64 = uint64_t;

// Floating-point shorthands.
using f32 = float;
using f64 = double;
// }}}

// Sentinel "infinity" (~1.01e18). INF + INF still fits in i64 (< 9.22e18).
constexpr i64 INF = 1'010'000'000'000'000'017LL;

// 998244353 as an i64 constant (the modint alias below spells the literal directly).
constexpr i64 MOD = 998244353LL;

// Presumably a floating-point comparison tolerance; unused in the visible code.
constexpr f64 EPS = 1e-12;

// pi to double precision; unused in the visible code.
constexpr f64 PI = 3.14159265358979323846;

// 100007 (= 1e5 + 7) and 1e9; unused in the visible code.
#define M5 100007
#define M9 1000000000

// pair member shorthands. NOTE: these are object-like macros, so EVERY token
// named `F` or `S` in the rest of the file is rewritten to first/second —
// avoid those identifiers for anything else.
#define F first
#define S second

// util {{{
// FOR(i, start, end): loop an i64 `i` over [start, end). `end` is evaluated
// once and cached in a hidden variable whose name is `i` pasted with xxxx_end.
#define FOR(i, start, end) for (i64 i = (start), CPP_CAT(i, xxxx_end) = (end); i < CPP_CAT(i, xxxx_end); ++i)
// REP(i, n): loop `i` over [0, n).
#define REP(i, n) FOR(i, 0, n)

// Expand a container into its begin/end iterator pair for <algorithm> calls.
#define all(x) (x).begin(), (x).end()
#define ll long long int
#define VI vector<ll>
#define VVI vector<VI>

// Compile-time switch for debug() output.
#define ISD true
// debug(x): print "x: <value>" to stdout when ISD is true.
// Wrapped in do { } while (0) so the macro is exactly one statement: a
// following `else` binds to the caller's `if`, not to a hidden inner `if`.
// The argument is parenthesised so e.g. debug(a & b) prints (a & b) instead
// of mis-parsing against operator<<.
#define debug(x)                                          \
    do                                                    \
    {                                                     \
        if (ISD)                                          \
            std::cout << #x << ": " << (x) << std::endl;  \
    } while (0)

// chmax: raise `xmax` to `x` when `comp(xmax, x)` holds (with the default
// comparator: when x > xmax). Returns true iff `xmax` was overwritten.
template <typename T, typename U, typename Comp = std::less<>>
bool chmax(T &xmax, const U &x, Comp comp = {})
{
    if (!comp(xmax, x))
        return false; // current value already at least as good
    xmax = x;
    return true;
}

// chmin: lower `xmin` to `x` when `comp(x, xmin)` holds (with the default
// comparator: when x < xmin). Returns true iff `xmin` was overwritten.
template <typename T, typename U, typename Comp = std::less<>>
bool chmin(T &xmin, const U &x, Comp comp = {})
{
    const bool improves = comp(x, xmin);
    if (improves)
        xmin = x;
    return improves;
}

// Integer modulo the compile-time constant `mod`.
//
// Invariant: the stored value is always in [0, mod). Every constructor,
// assignment and arithmetic operator re-establishes it.
// Integers mix freely with modint through the implicit converting constructor
// and the friend operator templates at the bottom.
// Products are formed in long long, so (mod - 1)^2 must fit in long long
// (true for 998244353).
template <long long mod>
struct modint
{
    // Implicit on purpose: lets `modint + 5`, `x = a[i]`, etc. compile.
    modint(long long v = 0) : value(normalize(v)) {}
    long long val() const { return value; }
    // Re-apply the [0, mod) reduction to the stored value.
    void normalize() { value = normalize(value); }
    // Reduce any long long into [0, mod). Note v == mod must map to 0, so the
    // in-range fast path is the half-open interval [0, mod).
    static long long normalize(long long v)
    {
        if (v < 0 || v >= mod)
        {
            v %= mod;
            if (v < 0)
                v += mod;
        }
        return v;
    }
    modint &operator=(long long v)
    {
        value = normalize(v);
        return *this;
    }
    bool operator==(const modint &o) const { return value == o.val(); }
    bool operator!=(const modint &o) const { return value != o.val(); }
    const modint &operator+() const { return *this; }
    // Returns a new object by value; a reference here would bind to a
    // temporary and dangle.
    modint operator-() const { return value ? mod - value : 0; }
    modint operator+(const modint &o) const
    {
        modint r = *this;
        r += o;
        return r;
    }
    modint &operator+=(const modint &o)
    {
        value += o.val(); // both operands < mod, so one subtraction suffices
        if (value >= mod)
            value -= mod;
        return *this;
    }
    modint operator-(const modint &o) const
    {
        modint r = *this;
        r -= o;
        return r;
    }
    modint &operator-=(const modint &o)
    {
        value -= o.val();
        if (value < 0)
            value += mod;
        return *this;
    }
    modint operator*(const modint &o) const
    {
        modint r = *this;
        r *= o;
        return r;
    }
    modint &operator*=(const modint &o)
    {
        value = value * o.val() % mod;
        return *this;
    }
    // Division multiplies by inv(); see inv() for its precondition.
    modint operator/(const modint &o) const { return operator*(o.inv()); }
    modint &operator/=(const modint &o) { return operator*=(o.inv()); }
    // value^n by binary exponentiation; n <= 0 yields 1.
    modint pow(long long n) const
    {
        modint ans = 1, x(value);
        while (n > 0)
        {
            if (n & 1)
                ans *= x;
            x *= x;
            n >>= 1;
        }
        return ans;
    }
    // Multiplicative inverse via the extended Euclidean algorithm.
    // Requires gcd(value, mod) == 1 (for a prime mod such as 998244353: any
    // nonzero value). For value == 0 this returns 0.
    modint inv() const
    {
        long long a = value, b = mod, u = 1, v = 0;
        while (b)
        {
            long long t = a / b;
            a -= t * b;
            std::swap(a, b);
            u -= t * v;
            std::swap(u, v);
        }
        return u; // may be negative; the constructor normalizes it
    }
    friend std::ostream &operator<<(std::ostream &os, const modint &x)
    {
        return os << x.val();
    }
    // Integer-on-the-left forms: `3 + m`, `1 - m`, `2 * m`, `1 / m`.
    template <typename T>
    friend modint operator+(T t, const modint &o)
    {
        return o + t;
    }
    template <typename T>
    friend modint operator-(T t, const modint &o)
    {
        return -o + t;
    }
    template <typename T>
    friend modint operator*(T t, const modint &o)
    {
        return o * t;
    }
    template <typename T>
    friend modint operator/(T t, const modint &o)
    {
        return o.inv() * t;
    }

private:
    long long value;
};

using modint998244353 = modint<998244353>;

int main()
{
    int N;
    cin >> N;
    vector<ll> A(N);
    REP(i, N)
    cin >> A[i];
    vector<vector<int>> G(N);
    REP(i, N - 1)
    {
        int u, v;
        cin >> u >> v;
        u--;
        v--;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    using mint = modint998244353;
    vector<mint> dp(N);
    mint ans = 0;
    function<mint(ll, ll)> dfs = [&](ll v, ll p)
    {
        vector<mint> s;
        mint sum = 0;
        for (auto g : G[v])
        {
            if (g == p)
            {
                continue;
            }
            mint s0 = dfs(g, v);
            s.push_back(s0);
            sum += s0;
        }
        mint sum0 = 0;
        dp[v] += A[v];
        for (auto s0 : s)
        {
            dp[v] += s0 * A[v];
            sum0 += s0 * (sum - s0) * A[v];
        }
        ans += sum0 / 2 + dp[v] - A[v];
        // cout << v << " " << ans.val() << " " << dp[v] << endl;
        return dp[v];
    };
    dfs(0, -1);
    cout << ans.val() << endl;
}
0