結果
問題 |
No.2598 Kadomatsu on Tree
|
ユーザー |
|
提出日時 | 2024-01-02 20:18:59 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 486 ms / 2,000 ms |
コード長 | 2,596 bytes |
コンパイル時間 | 5,760 ms |
コンパイル使用メモリ | 326,020 KB |
実行使用メモリ | 41,252 KB |
最終ジャッジ日時 | 2024-09-29 11:34:30 |
合計ジャッジ時間 | 24,326 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 60 |
ソースコード
#include <bits/stdc++.h> #include <atcoder/all> using namespace std; using namespace atcoder; #define rep(i,m,n,k) for (int i = (int)(m); i < (int)(n); i += (int)(k)) #define rrep(i,m,n,k) for (int i = (int)(m); i > (int)(n); i += (int)(k)) #define ll long long #define list(T,A,N) vector<T> A(N);for(int i=0;i<(int)(N);i++){cin >> A[i];} using mint = modint998244353; tuple<vector<long long>,vector<long long>, vector<long long>> sub_par_dist(vector<vector<long long >> e, long long root){ long long N = e.size(); vector<long long> par(N,-1); vector<long long> sub(N,-1); vector<long long> dist(N,-1); queue<long long> v; dist[root] = 0; v.push(root); long long x; while (!v.empty()){ x = v.front();v.pop(); for (auto ix:e[x]){ if (dist[ix]!=-1) continue; dist[ix] = dist[x] + 1; v.push(ix); } } vector<pair<long long,long long>> H; for (long long i=0;i<N;i++){ H.push_back({-dist[i],i}); } sort(H.begin(),H.end()); long long tmp; for (auto [h,i]: H){ tmp = 1; for (auto ix:e[i]){ if (sub[ix]==-1){ par[i] = ix; } else{ tmp += sub[ix]; } } sub[i] = tmp; } return {sub,par,dist}; } int main(){ ll N; cin >> N; vector<vector<ll>> e(N); ll u,v; rep(_,0,N-1,1){ cin >> u >> v; u -= 1; v -= 1; e[u].push_back(v); e[v].push_back(u); } list(ll,A,N); auto [sub,par,dist] = sub_par_dist(e,0); mint ans = 0; mint sx,sx2,sy,sy2; rep(i,0,N,1){ sx = mint(0); sy = mint(0); sx2 = mint(0); sy2 = mint(0); for(ll ix:e[i]){ if(par[i]==ix){ if(A[ix]>A[i]){ sx += mint(N-sub[i]); sx2 += mint((N-sub[i])*(N-sub[i])); } else if(A[ix]<A[i]){ sy += mint(N-sub[i]); sy2 += mint((N-sub[i])*(N-sub[i])); } } else{ if(A[ix]>A[i]){ sx += mint(sub[ix]); sx2 += mint(sub[ix]*sub[ix]); } else if(A[ix]<A[i]){ sy += mint(sub[ix]); sy2 += mint(sub[ix]*sub[ix]); } } } ans += (sx*sx - sx2)/mint(2); ans += (sy*sy - sy2)/mint(2); } cout << ans.val() << endl; return 0; }