結果
| 問題 |
No.2115 Making Forest Easy
|
| コンテスト | |
| ユーザー |
milanis48663220
|
| 提出日時 | 2022-10-28 23:57:03 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
AC
|
| 実行時間 | 1,263 ms / 2,000 ms |
| コード長 | 3,943 bytes |
| コンパイル時間 | 1,263 ms |
| コンパイル使用メモリ | 121,268 KB |
| 最終ジャッジ日時 | 2025-02-08 15:28:40 |
|
ジャッジサーバーID (参考情報) |
judge1 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 50 |
ソースコード
#include <iostream>
#include <algorithm>
#include <iomanip>
#include <vector>
#include <queue>
#include <deque>
#include <set>
#include <map>
#include <tuple>
#include <cmath>
#include <numeric>
#include <functional>
#include <cassert>
#include <atcoder/modint>
#define debug_value(x) cerr << "line" << __LINE__ << ":<" << __func__ << ">:" << #x << "=" << x << endl;
#define debug(x) cerr << "line" << __LINE__ << ":<" << __func__ << ">:" << x << endl;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }
using namespace std;
typedef long long ll;
template<typename T>
vector<vector<T>> vec2d(int n, int m, T v){
return vector<vector<T>>(n, vector<T>(m, v));
}
template<typename T>
vector<vector<vector<T>>> vec3d(int n, int m, int k, T v){
return vector<vector<vector<T>>>(n, vector<vector<T>>(m, vector<T>(k, v)));
}
template<typename T>
void print_vector(vector<T> v, char delimiter=' '){
if(v.empty()) {
cout << endl;
return;
}
for(int i = 0; i+1 < v.size(); i++) cout << v[i] << delimiter;
cout << v.back() << endl;
}
using mint = atcoder::modint998244353;
ostream& operator<<(ostream& os, const mint& m){
os << m.val();
return os;
}
int n;
vector<int> g[5000];
bool ok[5000];
using Arr = array<mint, 2>;
array<array<mint, 2>, 5000> dp1;
array<array<mint, 2>, 5000> dp2;
Arr op(Arr a, Arr b){
Arr ans = {{mint(0), mint(0)}};
for(int x = 0; x < 2; x++){
for(int y = 0; y < 2; y++){
// 切る場合
ans[x] += a[x]*b[y];
// つなぐ場合
ans[max(x, y)] += a[x]*b[y];
}
}
return ans;
}
void dfs1(int v, int par){
if(ok[v]) {
dp1[v][0] = 0;
dp1[v][1] = 1;
}else{
dp1[v][0] = 1;
dp1[v][1] = 0;
}
for(int to: g[v]){
if(to == par) continue;
dfs1(to, v);
Arr nx = op(dp1[v], dp1[to]);
for(int x = 0; x < 2; x++) dp1[v][x] = nx[x];
}
}
Arr e() {
return {{mint(1), mint(0)}};
}
void dfs2(int v, int par, Arr from_par){
vector<int> children;
for(int to: g[v]){
if(to == par) continue;
children.push_back(to);
}
Arr cur = {{mint(0), mint(0)}};
if(ok[v]) cur[1] = 1;
else cur[0] = 1;
int m = children.size();
vector<Arr> from_left(m+1);
if(v == 0){
from_left[0] = cur;
dp2[v] = dp1[v];
}else{
from_left[0] = op(cur, from_par);
dp2[v] = op(dp1[v], from_par);
}
for(int i = 0; i < m; i++){
int to = children[i];
from_left[i+1] = op(from_left[i], dp1[to]);
}
Arr from_right = e();
for(int i = m-1; i >= 0; i--){
int to = children[i];
Arr propagate = {{mint(0), mint(0)}};
for(int x = 0; x < 2; x++){
for(int y = 0; y < 2; y++){
propagate[max(x, y)] += from_left[i][x]*from_right[y];
}
}
dfs2(to, v, propagate);
from_right = op(from_right, dp1[to]);
// Arr p = from_left[0];
// for(int j = 0; j < m; j++){
// if(i == j) continue;
// p = op(p, dp1[children[j]]);
// }
// dfs2(to, v, p);
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout << setprecision(10) << fixed;
int n; cin >> n;
vector<int> a(n);
for(int i = 0; i < n; i++) cin >> a[i];
for(int i = 0; i < n-1; i++){
int u, v; cin >> u >> v; u--; v--;
g[u].push_back(v);
g[v].push_back(u);
}
mint ans = 0;
for(int x = 1000; x >= 1; x--){
for(int i = 0; i < n; i++){
if(a[i] >= x) ok[i] = true;
}
dfs1(0, -1);
dfs2(0, -1, e());
for(int i = 0; i < n; i++){
ans += dp2[i][1];
}
}
cout << ans << endl;
}
milanis48663220