結果
問題 | No.2949 Product on Tree |
ユーザー | sibasyun |
提出日時 | 2024-11-03 22:36:05 |
言語 | C++23(gcc13) (gcc 13.2.0 + boost 1.83.0) |
結果 |
AC
|
実行時間 | 451 ms / 2,000 ms |
コード長 | 5,148 bytes |
コンパイル時間 | 7,635 ms |
コンパイル使用メモリ | 338,580 KB |
実行使用メモリ | 32,924 KB |
最終ジャッジ日時 | 2024-11-03 22:36:38 |
合計ジャッジ時間 | 31,587 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 2 ms
6,820 KB |
testcase_01 | AC | 4 ms
6,820 KB |
testcase_02 | AC | 3 ms
6,816 KB |
testcase_03 | AC | 428 ms
16,344 KB |
testcase_04 | AC | 357 ms
15,968 KB |
testcase_05 | AC | 366 ms
16,448 KB |
testcase_06 | AC | 405 ms
16,464 KB |
testcase_07 | AC | 406 ms
16,152 KB |
testcase_08 | AC | 374 ms
16,804 KB |
testcase_09 | AC | 361 ms
17,160 KB |
testcase_10 | AC | 395 ms
17,568 KB |
testcase_11 | AC | 369 ms
18,376 KB |
testcase_12 | AC | 365 ms
21,104 KB |
testcase_13 | AC | 391 ms
21,448 KB |
testcase_14 | AC | 360 ms
23,416 KB |
testcase_15 | AC | 375 ms
25,900 KB |
testcase_16 | AC | 403 ms
23,888 KB |
testcase_17 | AC | 375 ms
24,800 KB |
testcase_18 | AC | 389 ms
24,476 KB |
testcase_19 | AC | 403 ms
27,028 KB |
testcase_20 | AC | 376 ms
26,372 KB |
testcase_21 | AC | 387 ms
26,812 KB |
testcase_22 | AC | 375 ms
27,848 KB |
testcase_23 | AC | 356 ms
16,656 KB |
testcase_24 | AC | 358 ms
16,692 KB |
testcase_25 | AC | 365 ms
16,588 KB |
testcase_26 | AC | 393 ms
16,652 KB |
testcase_27 | AC | 363 ms
16,688 KB |
testcase_28 | AC | 363 ms
16,880 KB |
testcase_29 | AC | 402 ms
17,080 KB |
testcase_30 | AC | 379 ms
17,804 KB |
testcase_31 | AC | 379 ms
19,120 KB |
testcase_32 | AC | 421 ms
19,376 KB |
testcase_33 | AC | 405 ms
23,420 KB |
testcase_34 | AC | 399 ms
28,612 KB |
testcase_35 | AC | 427 ms
25,172 KB |
testcase_36 | AC | 399 ms
30,264 KB |
testcase_37 | AC | 404 ms
30,604 KB |
testcase_38 | AC | 425 ms
27,252 KB |
testcase_39 | AC | 402 ms
27,928 KB |
testcase_40 | AC | 451 ms
32,516 KB |
testcase_41 | AC | 404 ms
29,476 KB |
testcase_42 | AC | 400 ms
32,924 KB |
testcase_43 | AC | 151 ms
13,784 KB |
testcase_44 | AC | 154 ms
14,040 KB |
testcase_45 | AC | 232 ms
16,936 KB |
testcase_46 | AC | 177 ms
15,656 KB |
testcase_47 | AC | 128 ms
12,268 KB |
testcase_48 | AC | 181 ms
16,088 KB |
ソースコード
#include <bits/stdc++.h> #include <atcoder/all> #include <algorithm> #include <iostream> #include <iomanip> #include <math.h> #include <random> #include <chrono> #include <cstdint> using namespace std; using namespace atcoder; using mint = modint998244353; // using mint = modint1000000007; using vi = vector<int>; using vvi = vector<vector<int>>; using vvvi = vector<vector<vector<int>>>; using vl = vector<long long>; using vvl = vector<vector<long long>>; using vvvl = vector<vector<vector<long long>>>; using vm = vector<mint>; using vvm = vector<vector<mint>>; using vvvm = vector<vector<vector<mint>>>; using ll = long long; template <class T> using max_heap = priority_queue<T>; template <class T> using min_heap = priority_queue<T, vector<T>, greater<>>; #define rep(i, n) for (int i = 0; i < (int)(n); i++) #define rep2(i, f, n) for (int i = (int) f; i < (int)(n); i++) #define repd(i, n, l) for (int i = (int) n; i >= (int) l; i--) #define all(p) p.begin(),p.end() vector<pair<int, int>> dydx{{-1, 0}, {1, 0}, {0, -1 }, {0, 1}}; const ll inf = 1LL << 60; void print(){ putchar(' '); } void print(bool a){ printf("%d", a); } void print(int a){ printf("%d", a); } void print(unsigned a){ printf("%u", a); } void print(long a){ printf("%ld", a); } void print(long long a){ printf("%lld", a); } void print(unsigned long long a){ printf("%llu", a); } void print(char a){ printf("%c", a); } void print(char a[]){ printf("%s", a); } void print(const char a[]){ printf("%s", a); } void print(float a){ printf("%.15f", a); } void print(double a){ printf("%.15f", a); } void print(long double a){ printf("%.15Lf", a); } void print(const string& a){ for(auto&& i : a) print(i); } template<class T> void print(const complex<T>& a){ if(a.real() >= 0) print('+'); print(a.real()); if(a.imag() >= 0) print('+'); print(a.imag()); print('i'); } template<class T> void print(const vector<T>&); template<class T, size_t size> void print(const array<T, size>&); template<class T, class L> void print(const pair<T, L>& p); template<class T, size_t size> void print(const T (&)[size]); template<class T> void print(const vector<T>& a){ if(a.empty()) return; print(a[0]); for(auto i = a.begin(); ++i != a.end(); ){ putchar(' '); print(*i); } } template<class T> void print(const deque<T>& a){ if(a.empty()) return; print(a[0]); for(auto i = a.begin(); ++i != a.end(); ){ putchar(' '); print(*i); } } template<class T, size_t size> void print(const array<T, size>& a){ print(a[0]); for(auto i = a.begin(); ++i != a.end(); ){ putchar(' '); print(*i); } } template<class T, class L> void print(const pair<T, L>& p){ print(p.first); putchar(' '); print(p.second); } template<class T, size_t size> void print(const T (&a)[size]){ print(a[0]); for(auto i = a; ++i != end(a); ){ putchar(' '); print(*i); } } template<class T> void print(const T& a){ cout << a; } constexpr ll ten(int n){ return n==0?1:ten(n-1)*10; } vector<vector<int>> unweighted_graph(int n, int m){ vector<vector<int>> ret(n); while (m--){ int a, b; cin >> a >> b; a--; b--; ret[a].push_back(b); ret[b].push_back(a); } return ret; } vector<vector<pair<int, long long>>> weighted_graph(int n, int m){ vector<vector<pair<int, long long>>> ret(n); while (m--){ int a, b; long long c; cin >> a >> b >> c; a--, b--; ret[a].push_back({b, c}); ret[b].push_back({a, c}); } return ret; } template<typename T> int argmin(vector<T> &a){ T mi = *min_element(all(a)); for (int i = 0; i < a.size(); i++){ if (a[i] == mi) return i; } } template<typename T> int argmax(vector<T> &a){ T ma = *max_element(all(a)); for (int i = 0; i < a.size(); i++){ if (a[i] == ma) return i; } } mint ans = 0; void dfs(int n, int p, vvi &G, vm &dp, vm &A){ dp[n] = A[n]; // mint sm = 0; // int cnt = 0; for(int nxt : G[n]){ if (nxt == p) continue; // cnt++; dfs(nxt, n, G, dp, A); dp[n] += A[n] * dp[nxt]; } // if (cnt) dp[n] = sm * dp[n]; // else dp[n] = A[n]; // dp[n] += sm * A[n]; // cout << n << ' ' << dp[n].val() << endl; } void dfs2(int n, int p, vvi &G, vm &dp, vm &A){ ans += dp[n] - A[n]; // cout << n << ' ' << dp[n].val() << endl; for(int nxt : G[n]){ if (nxt == p) continue; mint pre = dp[n]; mint pre2 = dp[nxt]; dp[n] -= dp[nxt] * A[n]; dp[nxt] += dp[n] * A[nxt]; // if (G[n].size() > 1) dp[n] -= dp[nxt] * A[n]; // else dp[n] = A[n]; // if (G[nxt].size() > 1) dp[nxt] += dp[n] * A[nxt]; // else dp[nxt] = A[nxt] * dp[n]; // dp[nxt] += dp[n] * A[nxt]; dfs2(nxt, n, G, dp, A); dp[nxt] = pre2; dp[n] = pre; } } int main() { int N; cin >> N; vi a(N); rep(i, N) cin >> a[i]; vm A(N); rep(i, N) A[i] = a[i]; auto G = unweighted_graph(N, N-1); vm dp(N, 0); dfs(0, -1, G, dp, A); dfs2(0, -1, G, dp, A); ans *= mint(2).inv(); cout << ans.val() << endl; // rep(i, N) cout << dp[i].val() << ' '; return 0; }