結果

問題 No.2587 Random Walk on Tree
ユーザー chineristAC
提出日時 2023-12-16 00:33:26
言語 C++17(gcc12)
(gcc 12.3.0 + boost 1.87.0)
結果
AC  
実行時間 7,263 ms / 10,000 ms
コード長 9,039 bytes
コンパイル時間 23,253 ms
コンパイル使用メモリ 291,048 KB
最終ジャッジ日時 2025-02-18 11:43:57
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 37
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

// -fsanitize=undefined,
// #define _GLIBCXX_DEBUG
#pragma GCC target("avx2")
#pragma GCC optimize("unroll-loops")
#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <set>
#include <queue>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <random>
#include <stdio.h>
#include <fstream>
#include <functional>
#include <cassert>
#include <unordered_map>
#include <bitset>
#include <chrono>
#include <atcoder/modint>
#include <atcoder/convolution>
using namespace std;
using namespace atcoder;
// Shorthand macros used throughout this file.
// rep(i,n): declares an int i and counts it upward from 0 while i<n.
// NOTE: n is pasted in without parentheses and is re-evaluated on every iteration.
#define rep(i,n) for (int i=0;i<n;i+=1)
// rrep(i,n): declares an int i and counts it downward from n-1 to 0.
#define rrep(i,n) for (int i=n-1;i>-1;i--)
// pb: abbreviation for push_back.
#define pb push_back
// all(x): expands to the two arguments "x.begin(), x.end()".
#define all(x) (x).begin(), (x).end()
// debug(x): writes "<expression text> = <value> (L<line> )" to cerr.
// The expansion already ends in ';', so a call site needs no extra semicolon.
#define debug(x) cerr << #x << " = " << (x) << " (L" << __LINE__ << " )\n";
// vec<T>: short name for vector<T>.
template<class T>
using vec = vector<T>;
// vvec<T>: vector of vectors (two levels of nesting).
template<class T>
using vvec = vec<vec<T>>;
// vvvec<T>: three levels of nesting.
template<class T>
using vvvec = vec<vvec<T>>;
// ll: long long.
using ll = long long;
// pii / pll: pairs of int / of long long.
using pii = pair<int,int>;
using pll = pair<ll,ll>;
// Lowers `a` to `b` when `a > b`.
// Returns true when `a` was overwritten, false when it was left alone.
template<class T>
bool chmin(T &a, T b){
    if (!(a > b)) return false;
    a = b;
    return true;
}
// Raises `a` to `b` when `a < b`.
// Returns true when `a` was overwritten, false when it was left alone.
template<class T>
bool chmax(T &a, T b){
    bool raised = false;
    if (a < b){
        a = b;
        raised = true;
    }
    return raised;
}
// Adds up every element of `x`, starting from T initialised with 0.
// Returns that total (0 for an empty vector).
// The vector is taken by const reference and elements are visited by
// reference, so neither the container nor its elements are copied.
template<class T>
T sum(const std::vector<T>& x){
    T res = 0;
    for (const auto& e : x){
        res += e;
    }
    return res;
}
// Writes every element of `x` to cout, each followed by a single space,
// then ends the line with endl (which also flushes the stream).
// The vector is taken by const reference and elements are visited by
// reference, so nothing is copied just to be printed.
template<class T>
void printv(const std::vector<T>& x){
    for (const auto& e : x){
        std::cout << e << " ";
    }
    std::cout << std::endl;
}
// Streams a pair as "(first, second)" and returns the stream.
template<class T,class U>
ostream& operator<<(ostream& os, const pair<T,U>& A){
    os << "(" << A.first;
    os << ", " << A.second;
    os << ")";
    return os;
}
// Streams a set as "set{e1, e2, ...}" in the set's iteration order and
// returns the stream. An empty set prints as "set{}".
// The separator is written before every element except the first, so the
// elements are visited once by reference; no per-element lookup or copy is
// needed to detect the last element.
template<class T>
std::ostream& operator<<(std::ostream& os, const std::set<T>& S){
    os << "set{";
    bool first = true;
    for (const auto& a : S){
        if (!first){
            os << ", ";
        }
        first = false;
        os << a;
    }
    os << "}";
    return os;
}
// Modular integer type used by every polynomial in this file.
using mint = modint998244353;
// Streams a mint as the integer returned by its val() and returns the stream.
ostream& operator<<(ostream& os, const mint& a){
    return os << a.val();
}
// Streams a vector as "[e0, e1, ...]" and returns the stream.
// An empty vector prints as "[]".
template<class T>
ostream& operator<<(ostream& os, const vec<T>& A){
    os << "[";
    bool need_separator = false;
    for (auto it = A.begin(); it != A.end(); ++it){
        if (need_separator){
            os << ", ";
        }
        os << *it;
        need_separator = true;
    }
    os << "]" ;
    return os;
}
// Element-wise in-place addition of `b` into `a`.
// `a` is first grown (new slots value-initialised) so that it is at least as
// long as `b`; entries of `a` beyond b.size() are left untouched.
// Returns a reference to `a`.
template <typename T> vector<T>& operator+=(vector<T>& a, const vector<T>& b) {
    if (b.size() > a.size()) {
        a.resize(b.size());
    }
    auto dst = a.begin();
    for (auto src = b.begin(); src != b.end(); ++src, ++dst) {
        *dst += *src;
    }
    return a;
}
// Element-wise sum of two vectors; the result is as long as the longer operand.
// Implemented as a copy of `a` followed by the vector operator+= of this file.
template <typename T> vector<T> operator+(const vector<T>& a, const vector<T>& b) {
    vector<T> total(a);
    total += b;
    return total;
}
// Multiplies all polynomials in `polys` (coefficient vectors) together.
// Returns {1} when `polys` is empty.
// The factors are kept in a queue: the two front entries are multiplied with
// convolution() and the product is appended at the back, until one is left.
vector<mint> calc_prod(vector<vector<mint>> polys){
    if (polys.empty()) return {1};
    deque<vec<mint>> pending(polys.begin(), polys.end());
    while (pending.size() > 1){
        const vec<mint> lhs = pending.front();
        pending.pop_front();
        const vec<mint> rhs = pending.front();
        pending.pop_front();
        pending.push_back(convolution(lhs, rhs));
    }
    return pending.front();
}
// Folds the (first, second) polynomial pairs of a vertex's children into the
// pair for the vertex itself.
// With no children the result is {{1}, {1,-1}}.
// Otherwise the pairs are reduced two at a time from the front of a queue:
//   second = f.second * g.second
//   first  = g.first * f.second + f.first * g.second
// giving a final pair (A, B). The returned pair is
//   { B,  B * (1 - x) - A * x^2 }.
pair<vector<mint>,vector<mint>> merge_child(vector<pair<vec<mint>,vec<mint>>> child_polys){
    if (child_polys.empty()){
        return {{1},{1,-1}};
    }
    deque<pair<vec<mint>,vec<mint>>> work(child_polys.begin(), child_polys.end());
    while (work.size() > 1){
        const auto lhs = work.front();
        work.pop_front();
        const auto rhs = work.front();
        work.pop_front();
        auto second = convolution(lhs.second, rhs.second);
        auto first = convolution(rhs.first, lhs.second) + convolution(lhs.first, rhs.second);
        work.push_back({first, second});
    }
    const auto& [A, B] = work.front();
    // Room for A shifted by two places and for B shifted by one place.
    vector<mint> res1(max(A.size()+2, B.size()+1), 0);
    for (size_t k = 0; k < A.size(); ++k){
        res1[k+2] -= A[k];
    }
    for (size_t k = 0; k < B.size(); ++k){
        res1[k] += B[k];
        res1[k+1] -= B[k];
    }
    return {B, res1};
}
// Combines an ordered sequence of (first, second) polynomial pairs into a single
// pair by divide and conquer over the sequence.
//   polys  : one pair of coefficient vectors per sequence element, in order.
//            Taken by value because a padding entry may be appended below.
//   returns: {merged[1], merged[3]} of the four-polynomial result built below.
// Every intermediate result is an array of four polynomials indexed by a 2-bit
// state. From the way the bits are used below, bit 1 (value 2) records whether
// the leftmost element of the segment contributed its `second` polynomial and
// bit 0 (value 1) records the same for the rightmost element.
// NOTE(review): the combinatorial meaning of these states is not stated in the
// code -- confirm against the problem editorial before relying on it.
// NOTE(review): an empty `polys` would leave nxt_polys empty and calc(0,0) would
// never reach its base case; both calls visible in solve() pass at least one entry.
pair<vector<mint>,vector<mint>> merge_path(vector<pair<vector<mint>,vector<mint>>> polys){
int n = polys.size();
// Pad to an even count. Multiplying by the padding's first polynomial {0} yields
// zeros and multiplying by its second polynomial {1} changes nothing.
if (n & 1){
polys.push_back({{0},{1}});
}
n = polys.size();
//debug(polys);
// Base level: fuse each adjacent pair (f, g) into one four-state entry.
vector<array<vector<mint>,4>> nxt_polys;
for (int i=0;i<n;i+=2){
/*0: 1:
0+0->00 or 11(-x^2)
0+1->01
1+0->10
1+1->11
*/
auto [f0,f1] = polys[i];
auto [g0,g1] = polys[i+1];
array<vector<mint>,4> merge;
merge[1] = convolution(f0,g1);
merge[2] = convolution(f1,g0);
merge[3] = convolution(f1,g1);
auto h = convolution(f0,g0);
merge[0] = h;
// merge[3] = f1*g1 - x^2 * f0*g0 : grow it if the shifted term needs more room.
int p = h.size();
if (merge[3].size() < p+2){
merge[3].resize(p+2);
}
// This inner i (declared by rep) shadows the loop index of the enclosing for.
rep(i,p) merge[3][i+2] -= h[i];
nxt_polys.push_back(merge);
}
// calc(l, r): four-state result for nxt_polys[l, r), obtained by splitting the
// range in half and joining the two halves.
auto calc = [&](auto self,int l,int r)->array<vector<mint>,4> {
if (r-l==1){
return nxt_polys[l];
}
int mid = (l+r)>>1;
auto left = self(self,l,mid);
auto right = self(self,mid,r);
// First pass: find the length every output polynomial needs.
int merge_n = 0;
rep(i,4){
rep(j,4){
// nxt_ocu: outer bits of the joined segment (left end of `left`, right end of
// `right`). It is not used in this sizing pass.
int nxt_ocu = (i & 2) + (j & 1);
// mid_ocu: the two bits that meet at the seam (right end of `left`, left end of
// `right`). Only 0 (both clear) and 3 (both set) contribute.
int mid_ocu = (i & 1) + (j & 2);
// Degree of left[i]*right[j], assuming both operands are non-empty.
int h_deg = left[i].size() + right[j].size() - 2;
//debug(i);
//debug(j);
//debug(h_deg);
if (mid_ocu == 0){
// This case is multiplied by x^2 below, hence two extra degrees.
chmax(merge_n,h_deg+2);
}
else if (mid_ocu == 3){
chmax(merge_n,h_deg);
}
}
}
// Convert the maximum degree into a coefficient count.
merge_n += 1;
//debug(merge_n);
array<vector<mint>,4> merge;
rep(i,4) merge[i] = vector<mint>(merge_n,0);
// Second pass: accumulate the contributing products.
rep(i,4){
rep(j,4){
int nxt_ocu = (i & 2) + (j & 1);
int mid_ocu = (i & 1) + (j & 2);
if (mid_ocu!=0 && mid_ocu!=3) continue;
auto h = convolution(left[i],right[j]);
//assert ((h.size()-1) == left[i].size() + right[j].size() - 2);
if (mid_ocu == 0){
// Both seam bits clear: subtract x^2 * h.
//assert (h.size()+1 < merge_n);
rep(k,h.size()){
merge[nxt_ocu][k+2] -= h[k];
}
}
else if (mid_ocu == 3){
// Both seam bits set: add h unchanged.
//debug(i);
//debug(j);
//debug(h.size());
//assert (h.size()-1 < merge_n);
rep(k,h.size()){
merge[nxt_ocu][k] += h[k];
}
}
}
}
return merge;
};
auto merged = calc(calc,0,nxt_polys.size());
// State 1 (only the right-end bit set) becomes `first`; state 3 (both end bits
// set) becomes `second`.
vector<mint> res0 = merged[1], res1 = merged[3];
return {res0,res1};
}
// Evaluates one term of the power series P(x)/Q(x) by repeated halving of N.
// Each round multiplies numerator and denominator by Q(-x), then keeps only the
// coefficients of the numerator whose index has the parity of N and the
// even-index coefficients of the denominator, and halves N.
//   P, Q : coefficient vectors, taken by value and overwritten as work buffers.
//   N    : index of the requested term.
// Returns P[0] once N reaches 0.
// NOTE(review): no division by Q[0] is performed, so this presumably expects
// Q[0] == 1 -- confirm at the call site.
mint BostanMori2(vec<mint> P,vec<mint> Q, long long N){
    while (N != 0){
        // Q(-x): flip the sign of every odd-index coefficient.
        vec<mint> Q_neg(Q.begin(), Q.end());
        for (int idx = 1; idx < (int)Q_neg.size(); idx += 2){
            Q_neg[idx] *= -1;
        }
        P = convolution(P,Q_neg);
        Q = convolution(Q,Q_neg);
        const int deg_P = (int)P.size() - 1;
        const int deg_Q = (int)Q.size() - 1;
        vec<mint> half_P(deg_P/2 + 1, 0);
        vec<mint> half_Q(deg_Q/2 + 1, 0);
        const int parity = (int)(N & 1);
        for (int idx = parity; idx <= deg_P; idx += 2){
            half_P[idx/2] = P[idx];
        }
        for (int idx = 0; idx <= deg_Q; idx += 2){
            half_Q[idx/2] = Q[idx];
        }
        P.swap(half_P);
        Q.swap(half_Q);
        N >>= 1;
    }
    return P[0];
}
// Reads one test case from stdin and prints its answer to stdout.
// Input as consumed below: "N M S T", then N-1 lines "a b" with 1-based vertex
// ids, stored as an undirected adjacency list (presumably a tree, given N-1 edges).
// Output: 0 when M is smaller than the number of edges D on the S-T path,
// otherwise BostanMori2(p_det, q_det, M - D).
void solve(){
int N,M,S,T;
cin>>N>>M>>S>>T;
// Switch to 0-based vertex ids.
S--; T--;
vec<vec<int>> edge(N);
rep(i,N-1){
int a,b;
cin>>a>>b;
a--; b--;
edge[a].push_back(b);
edge[b].push_back(a);
}
// parent / subtree size / heavy child of every vertex, with the tree rooted at T.
vec<int> parent(N,-1), sz(N,1);
vec<int> heavy_child(N,-1);
int root = T;
// dfs0: fills parent, sz and heavy_child. The strict '>' keeps the first child
// seen among children with equally large subtrees.
// NOTE(review): recursive, so the call depth equals the depth of the rooted tree.
auto dfs0 = [&](auto self,int v,int pv)->void {
int tmp_max_sz = 0;
for (auto nv:edge[v]){
if (nv == pv) continue;
parent[nv] = v;
self(self,nv,v);
sz[v] += sz[nv];
if (sz[nv] > tmp_max_sz){
tmp_max_sz = sz[nv];
heavy_child[v] = nv;
}
}
};
dfs0(dfs0,root,-1);
// Mark the vertices on the path from S up to the root T.
vector<int> on_st_path(N,0);
on_st_path[T] = 1;
{
int pos = S;
while (pos!=T){
on_st_path[pos] = 1;
pos = parent[pos];
}
}
// to_parent_is_light[v] == 0 exactly when v is the heavy child of some vertex and
// v is not on the S-T path; S-T path vertices are always forced to 1.
vector<int> to_parent_is_light(N,1);
rep(v,N){
if (heavy_child[v]!=-1){
to_parent_is_light[heavy_child[v]] = 0;
}
}
rep(v,N){
if (on_st_path[v]) to_parent_is_light[v] = 1;
}
// dp[v]: pair of polynomials for vertex v, filled by dfs1.
vec<pair<vec<mint>,vec<mint>>> dp(N);
// dfs1: processes v after its children, never stepping onto an S-T path vertex.
// The pv parameter is not read; parent[] is consulted instead.
auto dfs1 = [&](auto self,int v,int pv)->void {
vector<pair<vec<mint>,vec<mint>>> child_polys;
for (auto nv:edge[v]){
if (nv == parent[v] || on_st_path[nv]) continue;
self(self,nv,v);
// Only non-heavy children are folded into v here; their dp is cleared once it
// has been copied. The heavy child's dp is kept for the chain merge below.
if (nv!=heavy_child[v]){
child_polys.push_back(dp[nv]);
dp[nv].first.clear();
dp[nv].second.clear();
}
}
dp[v] = merge_child(child_polys);
// At a vertex flagged to_parent_is_light, collect dp along the heavy-child chain
// starting at v (stopping at the chain end or at an S-T path vertex) and replace
// dp[v] by the merge_path result of that chain.
if (to_parent_is_light[v]){
vector<pair<vec<mint>,vec<mint>>> path_polys = {dp[v]};
int pos = heavy_child[v];
while (pos!=-1 && !on_st_path[pos]){
path_polys.push_back(dp[pos]);
pos = heavy_child[pos];
}
dp[v] = merge_path(path_polys);
}
};
// Run dfs1 once from every S-T path vertex; each call covers the part of the
// tree hanging off that vertex.
rep(v,N){
if (on_st_path[v]) dfs1(dfs1,v,-1);
}
// Collect dp along the path S -> ... -> T (the loop ends after the root, whose
// parent is -1), plus the `first` components on their own.
vector<pair<vec<mint>,vec<mint>>> st_path_polys;
vector<vec<mint>> st_path_polys_non;
{
int pos = S;
while (pos!=-1){
st_path_polys.push_back(dp[pos]);
st_path_polys_non.push_back(dp[pos].first);
pos = parent[pos];
}
}
//debug(st_path_polys);
//debug(st_path_polys_non);
// Denominator: `second` of the merged S-T path. Numerator: product of all the
// `first` components on the path.
vector<mint> q_det = merge_path(st_path_polys).second;
vector<mint> p_det = calc_prod(st_path_polys_non);
//debug(p_det);
//debug(q_det);
// D: number of edges on the S-T path.
int D = int(st_path_polys.size()) - 1;
if (M < D){
cout << 0 << "\n";
return ;
}
auto ans = BostanMori2(p_det,q_det,M-D);
cout << ans << "\n";
}
// Entry point: switches iostreams to unsynchronised, untied mode and runs
// solve() for the single test case.
int main(){
    ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    for (int remaining = 1; remaining > 0; --remaining){
        solve();
    }
}
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0