結果
| 問題 | No.2587 Random Walk on Tree |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2025-06-05 18:36:02 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 4,141 ms / 10,000 ms |
| コード長 | 12,019 bytes |
| コンパイル時間 | 3,046 ms |
| コンパイル使用メモリ | 223,540 KB |
| 実行使用メモリ | 87,104 KB |
| 最終ジャッジ日時 | 2025-06-05 18:37:36 |
| 合計ジャッジ時間 | 84,351 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 37 |
コンパイルメッセージ
main.cpp: In function ‘void main_solve()’:
main.cpp:17:23: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
17 | #define inint(e) scanf("%d",&e)
| ~~~~~^~~~~~~~~
main.cpp:389:9: note: in expansion of macro ‘inint’
389 | inint(n); inll(step);
| ^~~~~
main.cpp:18:22: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
18 | #define inll(e) scanf("%lld",&e)
| ~~~~~^~~~~~~~~~~
main.cpp:389:19: note: in expansion of macro ‘inll’
389 | inint(n); inll(step);
| ^~~~
main.cpp:19:26: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
19 | #define inpr(e1,e2) scanf("%d%d",&e1,&e2)
| ~~~~~^~~~~~~~~~~~~~~~
main.cpp:390:9: note: in expansion of macro ‘inpr’
390 | inpr(stt,ed);
| ^~~~
main.cpp:19:26: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
19 | #define inpr(e1,e2) scanf("%d%d",&e1,&e2)
| ~~~~~^~~~~~~~~~~~~~~~
main.cpp:393:26: note: in expansion of macro ‘inpr’
393 | int x,y; inpr(x,y);
| ^~~~
ソースコード
#include <bits/stdc++.h>
// #include <windows.h>
// #include <bits/extc++.h>
// using namespace __gnu_pbds;
using namespace std;
// Optional optimization pragma (disabled).
//#pragma GCC optimize(3)
// ---- Short type aliases ----
#define DB double
#define LL long long
#define ULL unsigned long long
// GCC/Clang extension type.
#define in128 __int128
#define cint const int
#define cLL const LL
// ---- Loops ----
// For:  z = e1, e1+1, ..., e2 (inclusive).   Rof:  z = e2, e2-1, ..., e1 (inclusive).
#define For(z,e1,e2) for(int z=(e1);z<=(e2);z++)
#define Rof(z,e1,e2) for(int z=(e2);z>=(e1);z--)
// For_: half-open [e1, e2).   Rof_: z = e2 down to e1+1 (e1 excluded).
#define For_(z,e1,e2) for(int z=(e1);z<(e2);z++)
#define Rof_(z,e1,e2) for(int z=(e2);z>(e1);z--)
// ---- Input via scanf; each expands to a scanf call whose return value is its value ----
#define inint(e) scanf("%d",&e)
#define inll(e) scanf("%lld",&e)
#define inpr(e1,e2) scanf("%d%d",&e1,&e2)
#define in3(e1,e2,e3) scanf("%d%d%d",&e1,&e2,&e3)
// ---- Output via printf ----
// NOTE: outint_ / outll_ silently read the caller's `i` and `n` to choose ' ' or '\n'.
#define outint(e) printf("%d\n",e)
#define outint_(e) printf("%d%c",e," \n"[i==n])
#define outint2_(e,e1,e2) printf("%d%c",e," \n"[(e1)==(e2)])
#define outll(e) printf("%lld\n",e)
#define outll_(e) printf("%lld%c",e," \n"[i==n])
#define outll2_(e,e1,e2) printf("%lld%c",e," \n"[(e1)==(e2)])
// ---- Conditional control flow ----
#define exc(e) if(e) continue
#define stop(e) if(e) break
#define ret(e) if(e) return
// ---- Container shorthands ----
#define pb push_back
#define ft first
#define sc second
#define pii pair<int,int>
#define pli pair<long long,int>
#define vct vector
// Pops until empty; works for any container adapter with empty()/pop().
#define clean(e) while(!e.empty()) e.pop()
#define all(ev) ev.begin(),ev.end()
// Size as int (the argument is not parenthesized inside the macro).
#define sz(ev) ((int)ev.size())
// Prints "name=value" with %d, so only suitable for int-like values.
#define debug(x) printf("%s=%d\n",#x,x)
// Renames x0/y0/y1 — presumably to avoid clashing with the POSIX Bessel functions
// y0/y1 that <cmath> may declare.
#define x0 __xx00__
#define y0 __yy00__
#define y1 __yy11__
#define ffo fflush(stdout)
// mod: modulus for all polynomial arithmetic; 998244353 = 119 * 2^23 + 1 supports NTT
// lengths up to 2^23.
// G: base whose powers G^((mod-1)/2^i) become the NTT roots in main_init.
// NOTE(review): G = 404 instead of the usual primitive root 3. The tables only need
// G^((mod-1)/2^i) to have order exactly 2^i, which holds when G is a quadratic
// non-residue mod `mod` — TODO confirm.
cLL mod=998244353ll,G=404ll;
// Alternative modulus / double-hash parameters (unused).
// cLL mod=1000000007ll;
// cLL mod[2]={1686688681ll,1666888681ll},base[2]={166686661ll,188868881ll};
/// In-place minimum: replaces w1 with w2 only when w2 < w1 (uses operator<).
template <typename Type> void get_min(Type &w1,const Type w2)
{
    if(!(w2<w1)) return;
    w1=w2;
}
/// In-place maximum: replaces w1 with w2 only when w2 > w1 (uses operator>).
template <typename Type> void get_max(Type &w1,const Type w2)
{
    if(!(w2>w1)) return;
    w1=w2;
}
/// Ceiling division: the smallest integer q with q >= w1 / w2 (exact rational quotient).
/// Built-in '/' truncates toward zero. That already rounds up when the true quotient is
/// negative, so one is added only when the division is inexact AND the operands share
/// a sign. (The previous version always added one on a remainder, so e.g. -3/2 gave 0.)
/// Requires w2 != 0.
template <typename Type> Type up_div(Type w1,Type w2)
{
    Type q=w1/w2;
    if(w1%w2!=0&&((w1<0)==(w2<0))) ++q;
    return q;
}
/// Greatest common divisor by Euclid's algorithm.
/// gcd(x, 0) == x and gcd(0, 0) == 0; the previous version evaluated X_ % 0 when Y_ == 0.
template <typename Type> Type gcd(Type X_,Type Y_)
{
    while(Y_)
    {
        Type R_=X_%Y_;
        X_=Y_;
        Y_=R_;
    }
    return X_;
}
/// Least common multiple; divides before multiplying to limit overflow.
/// Returns 0 when either argument is 0 (previously lcm(0, 0) divided by zero).
template <typename Type> Type lcm(Type X_,Type Y_)
{
    if(X_==0||Y_==0) return 0;
    return (X_/gcd(X_,Y_)*Y_);
}
/// Reduces w1 into [0, w2) (w2 > 0 expected), also for negative w1.
template <typename Type> Type md(Type w1,const Type w2=mod)
{
    const Type r=w1%w2;
    return r<0 ? r+w2 : r;
}
/// Like md, but maps into (0, w2]: multiples of w2 become w2 rather than 0.
template <typename Type> Type md_(Type w1,const Type w2=mod)
{
    const Type r=w1%w2;
    return r<=0 ? r+w2 : r;
}
// ex_gcd: extended Euclid for coprime A_, B_. On return X_*A_ + Y_*B_ == 1, and when
// B_ > 0, X_ has been reduced into [0, B_) by md(). The recursive call solves for
// (B_, A_%B_) with the roles of X_/Y_ swapped; from that X_ is congruent to A_^-1 mod B_,
// and Y_ is then solved exactly.
// NOTE(review): Y_ is derived from 1, not from gcd(A_,B_), so for non-coprime inputs
// the pair is not a valid Bezout pair.
// inv: modular inverse of A_ modulo B_ (default `mod`) in [0, B_); assumes gcd(A_,B_) == 1.
void ex_gcd(LL &X_,LL &Y_,LL A_,LL B_) { if(!B_) { X_=1ll; Y_=0ll; return ; } ex_gcd(Y_,X_,B_,A_%B_); X_=md(X_,B_); Y_=(1ll-X_*A_)/B_; } LL inv(LL A_,LL B_=mod) { LL X_=0ll,Y_=0ll; ex_gcd(X_,Y_,A_,B_); return X_; }
/// w1 = (w1 + w2) mod M_, normalized into [0, M_). w1 + w2 must not overflow Type.
template <typename Type> void add(Type &w1,const Type w2,const Type M_=mod) { w1=md(w1+w2,M_); }
/// w1 = (w1 * w2) mod M_. Only w2 is reduced first, so the caller must keep w1 small
/// (normally in [0, M_)) or the product can overflow long long.
void mul(LL &w1,cLL w2,cLL M_=mod) { w1=md(w1*md(w2,M_),M_); }
/// Binary exponentiation: X_^Y_ mod M_ for Y_ >= 0.
/// - The base is reduced into [0, M_) first. mul() does not reduce its left operand, so
///   squaring an unreduced large base could overflow in the previous version.
/// - The accumulator starts at 1 mod M_, so M_ == 1 yields 0.
/// - A negative exponent is treated as 0 (previously Y_ >>= 1 never reached 0 and the loop
///   never ended).
template <typename Type> Type pw(Type X_,Type Y_,Type M_=mod)
{
    X_=md(X_,M_);
    Type S_=md((Type)1,M_);
    while(Y_>0)
    {
        if(Y_&1) mul(S_,X_,M_);
        Y_>>=1;
        mul(X_,X_,M_);
    }
    return S_;
}
// template <typename Type> Type bk(vector <Type> &V_) { auto T_=V_.back(); V_.pop_back(); return T_; } template <typename Type> Type tp(stack <Type> &V_) { auto T_=V_.top(); V_.pop(); return T_; } template <typename Type> Type frt(queue <Type> &V_) { auto T_=V_.front(); V_.pop(); return T_; }
// template <typename Type> Type bg(set <Type> &V_) { auto T_=*V_.begin(); V_.erase(V_.begin()); return T_; } template <typename Type> Type bk(set <Type> &V_) { auto T_=*prev(V_.end()); V_.erase(*prev(V_.end())); return T_; }
/// Global PRNG, seeded from the wall clock (a different stream on every run).
std::mt19937 gen(static_cast<std::mt19937::result_type>(std::time(nullptr)));
/// Random non-negative int in [0, 2^31 - 1].
/// mt19937 yields 32-bit values. The previous abs((int)gen()) was undefined behaviour
/// whenever the cast produced INT_MIN; dropping the low bit can never overflow.
int rd() { return static_cast<int>(gen()>>1); }
/// Random int in [l, r] (requires l <= r). The span is computed in long long so that
/// r - l + 1 cannot overflow int; spans wider than 2^31 are not fully covered.
int rnd(int l,int r)
{
    const long long span=static_cast<long long>(r)-l+1;
    return static_cast<int>(l+rd()%span);
}
/// In-place modular addition for operands already in [0, mod). A single conditional
/// subtraction replaces the % that add() performs.
void Add(LL &w1,cLL w2)
{
    w1+=w2;
    if(w1>=mod) w1-=mod;
}
// H:  log2 of the largest NTT length the root tables support (NTT reads z[1..d], d <= H).
// L:  2^H + 11; not referenced in the visible code (presumably a leftover buffer size).
// BL: operator* falls back to schoolbook multiplication when the shorter factor has
//     at most BL coefficients.
cint H=21,L=1<<H|11,BL=20;
// z[i][j] = t_i^j with t_i = G^((mod-1) >> i); filled by main_init().
vct <LL> z[H+1];
void main_init()
{
For(i,0,H)
{
z[i].resize(1<<i);
LL s=1ll,t=pw(G,(mod-1ll)>>i);
For_(j,0,1<<i) z[i][j]=s,mul(s,t);
}
}
/// Smallest k with 2^k >= w, i.e. ceil(log2(w)); expects w >= 1.
int minsz(int w)
{
    int k=0;
    while((1ll<<k)<w) ++k;
    return k;
}
// In-place number-theoretic transform of v over Z/mod.
// v is first zero-padded to l = 2^minsz(size), so its size is a power of two afterwards.
// op == false: forward transform. A bit-reversal permutation is followed by iterative
//   radix-2 butterflies whose twiddles come from the root tables z[].
// op == true: the same forward pass, then every entry is scaled by 1/l and v[1..l-1]
//   is reversed, which turns it into the inverse transform.
// Entries are expected in [0, mod) and stay there.
void NTT(vct <LL> &v,bool op)
{
int d=minsz(sz(v)),l=1<<d;
v.resize(l,0ll);
// rev[k] = k with its d low bits reversed; each pair is swapped once (when rev[k] < k).
vct <int> rev(l,0);
For_(k,1,l)
{
rev[k]=(rev[k>>1]>>1)|((k&1)<<d-1);
if(rev[k]<k) swap(v[k],v[rev[k]]);
}
// Level k merges adjacent blocks of length 2^k into blocks of length 2^(k+1).
For_(k,0,d)
{
for(int i=0;i<l;i+=(1<<k+1)) For_(j,i,i|(1<<k))
{
int j2=(j|(1<<k));
// i is a multiple of 2^(k+1) and j < i + 2^k, so i^j == j-i: the twiddle is z[k+1][j-i].
LL t=v[j2]*z[k+1][i^j]%mod;
v[j2]=v[j]-t; if(v[j2]<0ll) v[j2]+=mod;
v[j]+=t; if(v[j]>=mod) v[j]-=mod;
}
}
// The forward transform is complete here.
ret(!op);
LL t=inv(1ll*l);
for(auto &i:v) mul(i,t);
reverse(v.begin()+1,v.end());
}
/// Polynomial product over Z/mod; coefficients are stored low degree first, in [0, mod).
/// Returns {} if either factor is empty; otherwise exactly len1 + len2 - 1 coefficients.
/// If the shorter factor has at most BL coefficients the quadratic schoolbook method is
/// used, otherwise the product goes through NTT. Both parameters are taken by value and
/// reused as scratch buffers.
vct <LL> operator * (vct <LL> w1,vct <LL> w2)
{
    if(w1.empty()||w2.empty()) return {};
    const int na=sz(w1),nb=sz(w2),out_len=na+nb-1;
    if(min(na,nb)<=BL)
    {
        vct <LL> res(out_len,0ll);
        for(int ia=0;ia<na;++ia)
            for(int ib=0;ib<nb;++ib)
                add(res[ia+ib],w1[ia]*w2[ib]);
        return res;
    }
    const int full=1<<minsz(out_len);
    w1.resize(full,0ll);
    w2.resize(full,0ll);
    NTT(w1,false);
    NTT(w2,false);
    for(int ia=0;ia<full;++ia) mul(w1[ia],w2[ia]);
    NTT(w1,true);
    w1.resize(out_len);
    return w1;
}
/// w1 = w1 * w2 (polynomial product). Both operands are moved into operator*'s by-value
/// parameters, which it uses as scratch anyway. Previously w1 and w2 were copied again
/// on every call, an extra pair of allocations in the hot DP loops.
void operator *= (vct <LL> &w1,vct <LL> w2)
{
    w1=std::move(w1)*std::move(w2);
}
/// Coefficient-wise sum of two polynomials whose entries are in [0, mod).
/// The result is as long as the longer operand; any extra tail of w2 is copied unchanged.
vct <LL> operator + (vct <LL> w1,vct <LL> w2)
{
    const int n1=sz(w1),n2=sz(w2);
    const int common=min(n1,n2);
    for(int k=0;k<common;++k) Add(w1[k],w2[k]);
    if(n2>n1) w1.insert(w1.end(),w2.begin()+n1,w2.end());
    return w1;
}
/// w1 = w1 + w2 (coefficient-wise). Both operands are moved into operator+'s by-value
/// parameters instead of being copied a second time.
void operator += (vct <LL> &w1,vct <LL> w2)
{
    w1=std::move(w1)+std::move(w2);
}
/// Pads w with t zero coefficients at the high-degree end (a negative t truncates from
/// that end). The polynomial's value is unchanged; multiplying by x^t is operator>>=.
void operator <<= (std::vector<long long> &w,int t)
{
    const int target=static_cast<int>(w.size())+t;
    w.resize(target);
}
/// Multiplies the polynomial by x^t by prepending t zero coefficients (t >= 0).
/// A negative t divides by x^|t|: the |t| lowest coefficients are dropped, clamped to the
/// size. The previous reverse/resize/reverse version threw std::length_error when |t|
/// exceeded the size. A single O(n) insert also replaces the two full reversals.
void operator >>= (std::vector<long long> &w,int t)
{
    if(t>=0)
    {
        w.insert(w.begin(),static_cast<std::size_t>(t),0ll);
        return;
    }
    const std::size_t drop=std::min(static_cast<std::size_t>(-static_cast<long long>(t)),w.size());
    w.erase(w.begin(),w.begin()+static_cast<std::ptrdiff_t>(drop));
}
/// Returns every second coefficient of w starting at index op (op = 0: even positions,
/// op = 1: odd positions). This is the even/odd split gt_div uses.
std::vector<long long> keep(std::vector<long long> &w,int op)
{
    std::vector<long long> picked;
    for(std::size_t idx=op;idx<w.size();idx+=2) picked.push_back(w[idx]);
    return picked;
}
/// Bostan–Mori: returns [x^t] w1(x) / w2(x) mod `mod`, the t-th coefficient of the power
/// series of the rational function (coefficients low degree first).
/// Low-order zeros of w2 are stripped first, each one shifting t by one, so that w2[0] is
/// invertible. Each round multiplies numerator and denominator by w2(-x), which leaves
/// only even powers in the denominator. It then keeps the half of the coefficients whose
/// parity matches t and halves t.
/// Returns 0 in three cases:
/// - w2 is empty or identically zero (the quotient is undefined; the previous version
///   called back() on an emptied vector);
/// - the target index is negative (previously `while(t)` never terminated);
/// - the numerator runs out of coefficients.
LL gt_div(LL t,vct <LL> w1,vct <LL> w2)
{
    // w2 = x^lead * q with q(0) != 0  =>  [x^t] w1/w2 == [x^(t+lead)] w1/q.
    int lead=0;
    while(lead<sz(w2)&&w2[lead]==0ll) lead++;
    if(lead==sz(w2)) return 0ll;
    w2.erase(w2.begin(),w2.begin()+lead);
    t+=lead;
    // w1/q is a power series: it has no coefficients at negative powers.
    if(t<0) return 0ll;
    while(t)
    {
        // w = w2(-x): negate the odd-degree coefficients.
        auto w=w2;
        const int len=sz(w);
        for(int i=1;i<len;i+=2) w[i]=md(-w[i]);
        // w2(x)*w2(-x) has only even powers. Keep the part of w1*w2(-x) whose parity
        // matches t, then continue with t/2.
        w1*=w;
        w2*=std::move(w);
        w2=keep(w2,0);
        w1=keep(w1,(int)(t&1ll));
        t>>=1;
    }
    if(w1.empty()||w2.empty()) return 0ll;
    return (w1[0]*inv(w2[0])%mod);
}
// LL cal_M_th(LL t,vct <LL> w1,vct <LL> w2)
// {
// for(auto &i:w1) i=md(i);
// if(t<sz(w1)) return w1[t];
// for(auto &i:w2) i=md(-i);
// w2[0]=1ll; w1*=w2; w1.resize(sz(w2)-1,0ll);
// return gt_div(t,w1,w2);
// }
/* vct <LL> BM(vct <LL> a)
{
LL lst_del=0ll,del=0ll;
vct <LL> lst,f;
int len=sz(a),tot=0,k=-1;
For_(i,0,len)
{
LL w=0ll;
For_(j,0,min(tot,i-1)) add(w,a[i-j-1]*f[j]);
exc(w==a[i]);
LL del=md(a[i]-w);
if(!~k)
{
k=i,lst_del=del;
f.resize(i+1,0ll);
continue;
}
auto f2=f;
int tot2=sz(lst);
if(tot<i-k+tot2)
{
tot=i-k+tot2;
f.resize(tot,0ll);
}
LL t=del*inv(lst_del)%mod;
add(f[i-k-1],t);
For_(j,0,tot2) add(f[i-k+j],-lst[j]*t);
if(i-sz(f2)>k-tot2)
k=i,lst_del=del,lst=f2;
}
return f;
} */
// N: capacity of the per-vertex arrays; vertices are numbered 1..n, so n must be < N.
cint N=1.02e5;
// n: number of vertices. step: read from input and passed to gt_div as the coefficient index.
int n; LL step;
// stt, ed: the two vertices whose connecting path main_solve collects.
int stt,ed;
// v: adjacency lists, turned into child lists by dfs01 (parents erased, heavy child first).
// v_: holds a copy of the tree's adjacency lists around the first work_dp (see main_solve).
vct <int> v[N],v_[N];
// up: parent pointers (written by dfs0 and dfs01). b: scratch marks for path/chain vertices.
int up[N]; bool b[N];
// fac / ifac: factorials and inverse factorials mod `mod`, filled in main_solve.
LL fac[N],ifac[N];
/// Binomial coefficient C(n_, m_) mod `mod` from the factorial tables; 0 unless 0 <= m_ <= n_.
LL C(int n_,int m_)
{
    const bool invalid=(m_<0)||(m_>n_);
    if(invalid) return 0ll;
    const LL denom=ifac[n_-m_]*ifac[m_]%mod;
    return fac[n_]*denom%mod;
}
void dfs0(int p,int fa)
{
up[p]=fa;
for(auto i:v[p]) if(i!=fa)
dfs0(i,p);
}
// siz[p]: size of p's subtree as computed by dfs01.
// siz2[p]: set in work_dp to 1 + the total siz[] of p's children not on p's heavy chain.
int siz[N],siz2[N];
// pas: vertices of the current component in the order dfs01 first enters them (pre-order).
vct <int> pas;
// Roots the component containing p at p (fa = parent, 0 for none) and prepares
// heavy-light data for work_dp:
//  - appends p to pas and records up[p] = fa;
//  - erases fa from v[p], turning v[p] into p's child list. This mutates v[]; main_solve
//    copies v into v_ before the whole-tree call and swaps it back afterwards;
//  - computes siz[p];
//  - moves the first largest child to v[p][0] (the heavy child used by dfs02/work_dp).
// Recursive: the depth equals the height of the rooted component.
void dfs01(int p,int fa)
{
pas.pb(p),up[p]=fa;
if(fa) v[p].erase(find(all(v[p]),fa));
siz[p]=1;
// Leaf: nothing to recurse into and no heavy child to choose.
ret(v[p].empty());
int len=sz(v[p]),k=0;
For_(i_,0,len)
{
int i=v[p][i_];
// printf("%d->%d\n",p,i); ffo; //E
dfs01(i,p);
// k = index of the largest child seen so far (the first one on ties).
if(siz[i]>siz[v[p][k]]) k=i_;
siz[p]+=siz[i];
}
swap(v[p][0],v[p][k]);
}
// tp[p]: top vertex of the heavy chain containing p.
int tp[N];
/// Assigns chain tops after dfs01. When f is true (the root and every light child), p
/// starts its own chain; otherwise it inherits its parent's top. v[p][0] is the heavy child.
void dfs02(int p,bool f)
{
    if(f) tp[p]=p;
    else tp[p]=tp[up[p]];
    const int heavy=v[p].empty()?0:v[p][0];
    for(int c:v[p]) dfs02(c,c!=heavy);
}
// dp[p][0] / dp[p][1]: polynomials over Z/mod for the subtree of p, with p unmatched / p
// matched; the coefficient of x^k presumably counts matchings with k edges.
// NOTE(review): this reading comes from the commented-out recursive dfs/work_dp further
// down, which the heavy-light version below appears to reproduce — TODO confirm.
array <vct<LL>,2> dp[N];
// Returns dp[rt][0] + dp[rt][1] for the component containing rt. The tree is rooted at
// rt, and avd (the neighbour to stay away from; 0 for none) is treated as its parent.
// Steps:
//  - dfs01/dfs02 build heavy chains;
//  - chains are processed in reverse pre-order, so every chain top is handled after all
//    of its descendants;
//  - a vertex's light children are combined by a size-balanced divide and conquer;
//  - each chain is combined by a divide and conquer over 2x2 end states.
vct <LL> work_dp(int avd,int rt)
{
// printf("work_dp %d(%d)\n",rt,avd); ffo; // E
pas.clear();
dfs01(rt,avd);
// printf("dfs01 end\n"); ffo; //E
dfs02(rt,true);
// printf("dfs02 end\n"); ffo; //E
reverse(all(pas));
for(auto p:pas) if(tp[p]==p)
{
// printf("p=%d\n",p); ffo; // E
// h = the heavy chain starting at p (following v[i][0]); its vertices are marked in b[].
vct <int> h;
for(int i=p;;i=v[i][0])
{
h.pb(i),b[i]=true;
stop(v[i].empty());
}
// printf("line\n> "); // E
// for(auto i:h) printf("%d ",i); printf("\n"); // E
// For every chain vertex, fold in its light children (the children not marked in b[]).
for(auto i:h)
{
vct <int> v2;
for(auto j:v[i]) if(!b[j]) v2.pb(j);
// printf("init of i=%d\n",i); ffo; // E
siz2[i]=1;
// No light children: the only configuration is "i unmatched", with the single term 1.
if(v2.empty())
dp[i][0]={1ll},dp[i][1]={};
else
{
// dp[j][1] becomes j's total over both states; siz2[i] accumulates the light sizes.
for(auto j:v2) dp[j][1]+=dp[j][0],siz2[i]+=siz[j];
// work(l,r) = { product of the totals of v2[l..r],
//               sum over one chosen child j of dp[j][0] * totals of the others }.
// The split point mid is chosen so that the siz[] sums of both halves are balanced.
auto work=[&](int l,int r,auto &self)->array <vct<LL>,2>
{
if(l==r) return {dp[v2[l]][1],dp[v2[l]][0]};
int mid=l,sum=0,now=siz[v2[l]];
For(k,l,r) sum+=siz[v2[k]];
while(mid+1<r&&abs((now<<1)-sum)>abs((now+siz[v2[mid+1]]<<1)-sum)) now+=siz[v2[++mid]];
auto [l0,l1]=self(l,mid,self);
auto [r0,r1]=self(mid+1,r,self);
return {l0*r0,(l0*r1)+(l1*r0)};
};
// >>=1 multiplies by x: one more edge (i to the chosen light child).
dp[i]=work(0,sz(v2)-1,work); dp[i][1]>>=1;
}
// printf("* "); for(auto w:dp[i][0]) printf("%lld ",w); printf("\n"); // E
// printf("* "); for(auto w:dp[i][1]) printf("%lld ",w); printf("\n"); // E
// printf("finish\n",i); ffo; //E
}
// work(l,r)[a][b]: configurations of chain segment h[l..r], where a / b says whether the
// top end h[l] / the bottom end h[r] is matched. The split is balanced by siz2[].
auto work=[&](int l,int r,auto &self)->array <array<vct<LL>,2>,2>
{
// A single vertex is both ends: unmatched -> [0][0], matched (to a light child) -> [1][1].
if(l==r) return (array<array<vct<LL>,2>,2>){(array<vct<LL>,2>){dp[h[l]][0],(vct<LL>){}},(array<vct<LL>,2>){(vct<LL>){},dp[h[l]][1]}};
int mid=l,sum=0,now=siz2[h[l]];
For(k,l,r) sum+=siz2[h[k]];
while(mid+1<r&&abs((now<<1)-sum)>abs((now+siz2[h[mid+1]]<<1)-sum)) now+=siz2[h[++mid]];
auto wl=self(l,mid,self);
auto wr=self(mid+1,r,self);
array <array<vct<LL>,2>,2> w;
// Use edge h[mid]-h[mid+1]: both inner ends must be unmatched. A one-vertex half has the
// same vertex at both ends, so its outer state becomes "matched" as well.
For(z1,0,1) For(z2,0,1) w[z1|(l==mid)][z2|(mid+1==r)]+=wl[z1][0]*wr[0][z2];
// Multiply by x for the chosen edge (empty entries stay empty).
For(z1,0,1) For(z2,0,1) if(!w[z1][z2].empty()) w[z1][z2]>>=1;
// Skip the edge: the inner ends' states no longer matter, so fold them (matched + unmatched).
For(z,0,1) wl[z][1]+=wl[z][0],wr[1][z]+=wr[0][z];
For(z1,0,1) For(z2,0,1) w[z1][z2]+=wl[z1][1]*wr[1][z2];
return w;
};
// printf("start work\n"); ffo; //E
auto w=work(0,sz(h)-1,work);
// printf("finish work\n"); ffo; //E
// printf("p=%d\n",p); // E
// Chain top p unmatched / matched, summing over the bottom end's state.
dp[p][0]=w[0][0]+w[0][1];
dp[p][1]=w[1][0]+w[1][1];
// printf("* "); for(auto w:dp[p][0]) printf("%lld ",w); printf("\n"); // E
// printf("* "); for(auto w:dp[p][1]) printf("%lld ",w); printf("\n"); // E
// Clear this chain's marks so later chains and callers see b[] unchanged.
for(auto i:h) b[i]=false;
}
auto t=dp[rt][0]+dp[rt][1];
// printf("t: "); for(auto i:t) printf("%lld ",i); printf("\n"); // E
return t;
}
// void dfs(int p,int fa)
// {
// dp[p][0]={1ll},dp[p][1]={};
// for(auto i:v[p]) if(i!=fa)
// {
// dfs(i,p);
// dp[i][1]+=dp[i][0];
// dp[p][1]=(dp[p][0]*dp[i][0]*(vct<LL>){0ll,1ll})+(dp[p][1]*dp[i][1]);
// dp[p][0]*=dp[i][1];
// }
// // printf("dfs %d(%d):\n",p,fa);
// // for(auto i:dp[p][0]) printf("%lld ",i); printf("\n");
// // for(auto i:dp[p][1]) printf("%lld ",i); printf("\n");
// }
// vct <LL> work_dp(int avd,int rt)
// {
// dfs(rt,avd);
// return (dp[rt][0]+dp[rt][1]);
// }
/// Coefficients of (1 - x)^m mod `mod`: entry i is (-1)^i * C(m, i), for i = 0..m.
vct <LL> pw_1sx(int m)
{
    vct <LL> coef;
    coef.reserve(m+1);
    for(int i=0;i<=m;++i)
    {
        const LL c=C(m,i);
        coef.pb((i&1)?mod-c:c);
    }
    return coef;
}
/// Given t = (t_0, ..., t_{len-1}), returns
///     sum_i (-1)^i * t_i * x^(2i) * (1 - x)^(m - 2i)   (mod `mod`),
/// computed by divide and conquer instead of one product per term.
/// Requires m >= 2*(len-1), so that the final (1 - x) exponent is non-negative.
/// An empty t now yields {}: previously work(0,-1) picked mid = -1 and recursed forever.
vct <LL> work_again(vct <LL> t,int m)
{
    const int len=sz(t);
    if(len==0) return {};
    // Fold the (-1)^i sign into the odd-index coefficients.
    for(int i=1;i<len;i+=2) t[i]=(t[i]?(mod-t[i]):0ll);
    // work(l, r) = sum_{i=l..r} t_i * x^(2(i-l)) * (1-x)^(2(r-i)).
    auto work=[&](int l,int r,auto &self)->vct<LL>
    {
        if(l==r) return {t[l]};
        const int mid=(l+r)>>1;
        auto lo=self(l,mid,self),hi=self(mid+1,r,self);
        lo*=pw_1sx((r-mid)<<1);  // raise the left half's (1-x) exponents up to r
        hi>>=((mid-l+1)<<1);     // shift the right half's x exponents to start at l
        return (lo+hi);
    };
    auto res=work(0,len-1,work);
    res*=pw_1sx(m-((len-1)<<1));
    return res;
}
void main_solve()
{
inint(n); inll(step);
inpr(stt,ed);
For_(i,1,n)
{
int x,y; inpr(x,y);
v[x].pb(y),v[y].pb(x);
}
fac[0]=1ll; For(i,1,n) fac[i]=fac[i-1]*i%mod;
ifac[n]=inv(fac[n]); Rof(i,1,n) ifac[i-1]=ifac[i]*i%mod;
For(i,1,n) v_[i]=v[i];
vct <LL> d1=work_dp(0,1);
For(i,1,n) v[i].swap(v_[i]);
d1=work_again(d1,n);
dfs0(stt,0);
vct <int> pass;
for(int i=ed;i!=stt;i=up[i]) pass.pb(i); pass.pb(stt);
for(auto i:pass) b[i]=true;
vct <LL> d2={1ll};
for(auto i:pass) for(auto j:v[i]) if(!b[j]) d2*=work_dp(i,j);
d2=work_again(d2,n-sz(pass)); d2>>=(sz(pass)-1);
// printf("d2: "); for(auto i:d2) printf("%lld ",i); printf("\n");
// printf("d1: "); for(auto i:d1) printf("%lld ",i); printf("\n");
// return ;
outll(gt_div(step,d2,d1));
// For(i,0,10) printf("%lld ",gt_div(i,d2,d1));
}
int main()
{
    // Single test case: build the NTT root tables once, then solve.
    // (The template's fast-I/O, freopen and multi-test toggles were commented out and are
    // omitted here.)
    main_init();
    main_solve();
    return 0;
}
/*
d2: 0 998244352
d1: 1 998244351 2
-x
-2x+1
1-x -x
-x 1-x
1-2x
*/