結果
| 問題 |
No.2598 Kadomatsu on Tree
|
| ユーザー |
|
| 提出日時 | 2024-01-02 20:18:59 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
AC
|
| 実行時間 | 486 ms / 2,000 ms |
| コード長 | 2,596 bytes |
| コンパイル時間 | 5,760 ms |
| コンパイル使用メモリ | 326,020 KB |
| 実行使用メモリ | 41,252 KB |
| 最終ジャッジ日時 | 2024-09-29 11:34:30 |
| 合計ジャッジ時間 | 24,326 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 60 |
ソースコード
#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using namespace atcoder;
#define rep(i,m,n,k) for (int i = (int)(m); i < (int)(n); i += (int)(k))
#define rrep(i,m,n,k) for (int i = (int)(m); i > (int)(n); i += (int)(k))
#define ll long long
#define list(T,A,N) vector<T> A(N);for(int i=0;i<(int)(N);i++){cin >> A[i];}
using mint = modint998244353;
tuple<vector<long long>,vector<long long>, vector<long long>> sub_par_dist(vector<vector<long long >> e, long long root){
long long N = e.size();
vector<long long> par(N,-1);
vector<long long> sub(N,-1);
vector<long long> dist(N,-1);
queue<long long> v;
dist[root] = 0;
v.push(root);
long long x;
while (!v.empty()){
x = v.front();v.pop();
for (auto ix:e[x]){
if (dist[ix]!=-1) continue;
dist[ix] = dist[x] + 1;
v.push(ix);
}
}
vector<pair<long long,long long>> H;
for (long long i=0;i<N;i++){
H.push_back({-dist[i],i});
}
sort(H.begin(),H.end());
long long tmp;
for (auto [h,i]: H){
tmp = 1;
for (auto ix:e[i]){
if (sub[ix]==-1){
par[i] = ix;
}
else{
tmp += sub[ix];
}
}
sub[i] = tmp;
}
return {sub,par,dist};
}
int main(){
ll N;
cin >> N;
vector<vector<ll>> e(N);
ll u,v;
rep(_,0,N-1,1){
cin >> u >> v;
u -= 1;
v -= 1;
e[u].push_back(v);
e[v].push_back(u);
}
list(ll,A,N);
auto [sub,par,dist] = sub_par_dist(e,0);
mint ans = 0;
mint sx,sx2,sy,sy2;
rep(i,0,N,1){
sx = mint(0);
sy = mint(0);
sx2 = mint(0);
sy2 = mint(0);
for(ll ix:e[i]){
if(par[i]==ix){
if(A[ix]>A[i]){
sx += mint(N-sub[i]);
sx2 += mint((N-sub[i])*(N-sub[i]));
}
else if(A[ix]<A[i]){
sy += mint(N-sub[i]);
sy2 += mint((N-sub[i])*(N-sub[i]));
}
}
else{
if(A[ix]>A[i]){
sx += mint(sub[ix]);
sx2 += mint(sub[ix]*sub[ix]);
}
else if(A[ix]<A[i]){
sy += mint(sub[ix]);
sy2 += mint(sub[ix]*sub[ix]);
}
}
}
ans += (sx*sx - sx2)/mint(2);
ans += (sy*sy - sy2)/mint(2);
}
cout << ans.val() << endl;
return 0;
}