結果
| 問題 |
No.1124 Earthquake Safety
|
| コンテスト | |
| ユーザー |
pockyny
|
| 提出日時 | 2024-10-20 15:39:17 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,515 bytes |
| コンパイル時間 | 1,077 ms |
| コンパイル使用メモリ | 82,556 KB |
| 最終ジャッジ日時 | 2025-02-24 21:49:40 |
|
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 16 WA * 42 |
ソースコード
#include <iostream>
#include <atcoder/modint>
#include <vector>
using namespace std;
using namespace atcoder;
using mint = modint998244353;
vector<int> G[300010];
mint f[300010],g[300010],h[300010];
int dep[300010];
mint pw(mint a,int x){
mint ret = 1;
while(x){
if(x&1) (ret *= a);
(a *= a); x /= 2;
}
return ret;
}
// (i,j,k)を選んで、連結になる確率を求める
// (i,j,k)の並び替えは全て数えて重複させない形になっている
void dfs(int s,int p){
// cout << "s == " << s << " p == " << p << endl;
// 個数/gを使ったかどうか?
vector<vector<mint>> dpf(4,vector<mint>(2)),ndpf(4,vector<mint>(2));
int i,j;
dpf[0][0] = 1;
for(int v:G[s]){
if(v==p) continue;
dep[v] = dep[s] + 1;
dfs(v,s);
f[s] += f[v]; g[s] += g[v]; h[s] += h[v];
for(i=0;i<=3;i++){
if(i){
for(j=0;j<2;j++) ndpf[i][j] += dpf[i - 1][j]*h[v];
}
if(i>=2){
for(j=0;j<2;j++) ndpf[i][1] += dpf[i - 2][j]*g[v];
}
for(j=0;j<2;j++) ndpf[i][j] += dpf[i][j];
}
for(i=0;i<=3;i++){
for(j=0;j<2;j++) dpf[i][j] = ndpf[i][j], ndpf[i][j] = 0;
}
}
mint pp = pw(2,dep[s]);
mint inv = (mint)1/pp;
h[s] += inv;
g[s] += 2*dpf[2][0]*pp + 2*dpf[1][0];
f[s] += 6*dpf[3][0]*pw(pp,3) + 3*dpf[3][1]*pw(pp,2) + 3*dpf[2][1]*pp + 6*dpf[2][0]*pw(pp,2);
}
int main(){
int i,n; cin >> n;
for(i=0;i<n - 1;i++){
int a,b; cin >> a >> b; a--; b--;
G[a].push_back(b); G[b].push_back(a);
}
dep[0] = 0;
dfs(0,-1);
// (i,i,i)のケース
mint x = n;
// (i,i,j)のケース
mint y = 0;
for(i=0;i<n;i++){
mint p = pw(2,dep[i]);
mint sum = 0,sum2 = 0;
for(int u:G[i]){
if(u==-1 || dep[u]<dep[i]) continue;
sum += h[u];
sum2 += h[u]*h[u];
}
// cout << (sum*4).val() << " " << (sum2*4).val() << " " << p.val() << endl;
y += ((sum*sum - sum2)/2)*pw(p,2);
y += sum*p;
}
y *= 6;
// (i,j,k)のケース
mint z = f[0];
cout << ((x + y + z)*pw(2,n - 1)).val() << "\n";
// for(i=0;i<n;i++){
// mint x = pw(2,n - 1);
// cout << (f[i]*x).val() << " " << (g[i]*x).val() << " " << (h[i]*x).val() << "\n";
// }
// cout << (x*pw(2,n - 1)).val() << " " << (y*pw(2,n - 1)).val() << " " << (z*pw(2,n - 1)).val() << "\n";
}
pockyny