結果

問題 No.2587 Random Walk on Tree
ユーザー XY-Eleven
提出日時 2025-06-05 18:36:02
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 4,141 ms / 10,000 ms
コード長 12,019 bytes
コンパイル時間 3,046 ms
コンパイル使用メモリ 223,540 KB
実行使用メモリ 87,104 KB
最終ジャッジ日時 2025-06-05 18:37:36
合計ジャッジ時間 84,351 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 37
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp: In function ‘void main_solve()’:
main.cpp:17:23: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
   17 | #define inint(e) scanf("%d",&e)
      |                  ~~~~~^~~~~~~~~
main.cpp:389:9: note: in expansion of macro ‘inint’
  389 |         inint(n); inll(step);
      |         ^~~~~
main.cpp:18:22: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
   18 | #define inll(e) scanf("%lld",&e)
      |                 ~~~~~^~~~~~~~~~~
main.cpp:389:19: note: in expansion of macro ‘inll’
  389 |         inint(n); inll(step);
      |                   ^~~~
main.cpp:19:26: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
   19 | #define inpr(e1,e2) scanf("%d%d",&e1,&e2)
      |                     ~~~~~^~~~~~~~~~~~~~~~
main.cpp:390:9: note: in expansion of macro ‘inpr’
  390 |         inpr(stt,ed);
      |         ^~~~
main.cpp:19:26: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
   19 | #define inpr(e1,e2) scanf("%d%d",&e1,&e2)
      |                     ~~~~~^~~~~~~~~~~~~~~~
main.cpp:393:26: note: in expansion of macro ‘inpr’
  393 |                 int x,y; inpr(x,y);
      |                          ^~~~

ソースコード

diff #

#include <bits/stdc++.h>
// #include <windows.h>
// #include <bits/extc++.h>
// using namespace __gnu_pbds;
using namespace std;
//#pragma GCC optimize(3)
// ---- Short-hand type names ----
#define DB double
#define LL long long
#define ULL unsigned long long
#define in128 __int128
#define cint const int
#define cLL const LL
// ---- Loop helpers ----
// For(z,a,b): z = a..b ascending (inclusive); Rof(z,a,b): z = b..a descending (inclusive).
#define For(z,e1,e2) for(int z=(e1);z<=(e2);z++)
#define Rof(z,e1,e2) for(int z=(e2);z>=(e1);z--)
// For_(z,a,b): z in [a,b) ascending; Rof_(z,a,b): z in (a,b] descending.
#define For_(z,e1,e2) for(int z=(e1);z<(e2);z++)
#define Rof_(z,e1,e2) for(int z=(e2);z>(e1);z--)
// ---- scanf-based input ----
// Each macro expands to the bare scanf call, so its result (number of items
// read) can be tested, e.g. `if(inint(n)!=1) ...`. Ignoring it triggers
// -Wunused-result.
#define inint(e) scanf("%d",&e)
#define inll(e) scanf("%lld",&e)
#define inpr(e1,e2) scanf("%d%d",&e1,&e2)
#define in3(e1,e2,e3) scanf("%d%d%d",&e1,&e2,&e3)
// ---- printf-based output ----
// NOTE: outint_ / outll_ silently read variables named `i` and `n` from the
// caller's scope: they print a space after the value, or a newline when i==n.
// outint2_ / outll2_ take the two values to compare explicitly.
#define outint(e) printf("%d\n",e)
#define outint_(e) printf("%d%c",e," \n"[i==n])
#define outint2_(e,e1,e2) printf("%d%c",e," \n"[(e1)==(e2)])
#define outll(e) printf("%lld\n",e)
#define outll_(e) printf("%lld%c",e," \n"[i==n])
#define outll2_(e,e1,e2) printf("%lld%c",e," \n"[(e1)==(e2)])
// ---- Guarded control flow ----
// Each expands to a bare `if` with no braces; beware of a following `else`
// binding to it.
#define exc(e) if(e) continue
#define stop(e) if(e) break
#define ret(e) if(e) return
// ---- Container / pair shorthands ----
#define pb push_back
#define ft first
#define sc second
#define pii pair<int,int> 
#define pli pair<long long,int> 
#define vct vector 
// clean(e): pops until empty (for stack/queue/priority_queue, which lack clear()).
#define clean(e) while(!e.empty()) e.pop()
#define all(ev) ev.begin(),ev.end()
// sz(ev): container size as int. The argument is not parenthesised, so pass a
// plain name.
#define sz(ev) ((int)ev.size())
// debug(x): prints "x=<value>" for an int expression.
#define debug(x) printf("%s=%d\n",#x,x)
// Rename x0/y0/y1 to private identifiers. Presumably this avoids clashes with
// the POSIX Bessel functions y0/y1 pulled in by <cmath> — confirm.
#define x0 __xx00__
#define y0 __yy00__
#define y1 __yy11__
// Flush stdout (for interactive debugging).
#define ffo fflush(stdout)
// Prime modulus shared by every modular helper and by the NTT.
cLL mod=998244353ll;
// Base element for the NTT root tables: main_init uses G^((mod-1)>>i) as the
// 2^i-th root of unity. NOTE(review): presumably G has full 2-power order
// modulo mod so that those roots are primitive — confirm.
cLL G=404ll;
// Alternative moduli kept from the author's template (unused):
// cLL mod=1000000007ll;
// cLL mod[2]={1686688681ll,1666888681ll},base[2]={166686661ll,188868881ll};
// Lowers w1 to w2 when w2 compares strictly smaller; otherwise leaves w1 alone.
template <typename Type> void get_min(Type &w1,const Type w2)
{
	if(!(w2<w1)) return;
	w1=w2;
}
// Raises w1 to w2 when w2 compares strictly larger; otherwise leaves w1 alone.
template <typename Type> void get_max(Type &w1,const Type w2)
{
	if(!(w2>w1)) return;
	w1=w2;
}
// Ceiling division: the smallest integer q with q >= w1 / w2 (exact quotient).
// Correct for any combination of signs; w2 must be nonzero.
template <typename Type> Type up_div(Type w1,Type w2)
{
	Type q=w1/w2;
	// Built-in division truncates toward zero, which already rounds a negative
	// quotient up. Only a positive, inexact quotient needs the +1. The old code
	// added 1 whenever there was a remainder, giving e.g. up_div(-3,2) == 0.
	if(w1%w2!=0&&((w1<0)==(w2<0))) ++q;
	return q;
}
// Greatest common divisor by Euclid's algorithm.
// gcd(x, 0) == x and gcd(0, 0) == 0. The previous version evaluated x % 0
// (undefined behaviour) whenever the second argument was zero.
template <typename Type> Type gcd(Type X_,Type Y_)
{
	while(Y_)
	{
		Type R_=X_%Y_;
		X_=Y_;
		Y_=R_;
	}
	return X_;
}
// Least common multiple; 0 if either argument is 0.
// Divides before multiplying to limit overflow.
template <typename Type> Type lcm(Type X_,Type Y_)
{
	if(X_==0||Y_==0) return 0;
	return (X_/gcd(X_,Y_)*Y_);
}
// Reduces w1 modulo w2 into the range [0, w2) (w2 > 0 assumed; default: mod).
template <typename Type> Type md(Type w1,const Type w2=mod)
{
	const Type r=w1%w2;
	return r<0?r+w2:r;
}
// Like md, but maps into (0, w2]: exact multiples of w2 become w2, not 0.
template <typename Type> Type md_(Type w1,const Type w2=mod)
{
	const Type r=w1%w2;
	return r<=0?r+w2:r;
}
// ex_gcd(X_, Y_, A_, B_): extended Euclid. For B_ > 0 and gcd(A_, B_) == 1 it
// leaves A_*X_ + B_*Y_ == 1 with X_ normalised into [0, B_) via md.
// Only meaningful when gcd(A_, B_) == 1: Y_ is recomputed as (1 - X_*A_) / B_,
// which is an exact division only in that case.
// inv(A_, B_): the X_ above, i.e. the inverse of A_ modulo B_ (default: mod).
// inv(0, B_) evaluates to 0.
// NOTE(review): X_*A_ is a plain LL product; assumes A_ and B_ are small
// enough (e.g. already reduced below ~3e9) that it cannot overflow.
void ex_gcd(LL &X_,LL &Y_,LL A_,LL B_) { if(!B_) { X_=1ll; Y_=0ll; return ; } ex_gcd(Y_,X_,B_,A_%B_); X_=md(X_,B_); Y_=(1ll-X_*A_)/B_; } LL inv(LL A_,LL B_=mod) { LL X_=0ll,Y_=0ll; ex_gcd(X_,Y_,A_,B_); return X_; }
// add: w1 <- (w1 + w2) mod M_, normalised into [0, M_).
template <typename Type> void add(Type &w1,const Type w2,const Type M_=mod) { w1=md(w1+w2,M_); }
// mul: w1 <- (w1 * w2) mod M_, normalised into [0, M_).
// Both factors are reduced before multiplying. Previously only w2 was, so an
// unreduced w1 (e.g. a large base handed to pw) could overflow the LL product.
void mul(LL &w1,cLL w2,cLL M_=mod) { w1=md(md(w1,M_)*md(w2,M_),M_); }
// pw: X_^Y_ mod M_ by binary exponentiation, for Y_ >= 0.
// Only instantiable with Type = LL, because it forwards to mul(LL&, ...).
// Starts from (1 mod M_) so that M_ == 1 correctly yields 0.
template <typename Type> Type pw(Type X_,Type Y_,Type M_=mod)
{
	Type S_=md(Type(1),M_);
	while(Y_)
	{
		if(Y_&1) mul(S_,X_,M_);
		Y_>>=1;
		mul(X_,X_,M_);
	}
	return S_;
}
// template <typename Type> Type bk(vector <Type> &V_) { auto T_=V_.back(); V_.pop_back(); return T_; } template <typename Type> Type tp(stack <Type> &V_) { auto T_=V_.top(); V_.pop(); return T_; } template <typename Type> Type frt(queue <Type> &V_) { auto T_=V_.front(); V_.pop(); return T_; }
// template <typename Type> Type bg(set <Type> &V_) { auto T_=*V_.begin(); V_.erase(V_.begin()); return T_; } template <typename Type> Type bk(set <Type> &V_) { auto T_=*prev(V_.end()); V_.erase(*prev(V_.end())); return T_; }
// Process-wide Mersenne Twister, seeded from the wall clock.
std::mt19937 gen(std::time(nullptr));
// Non-negative pseudo-random int in [0, 2^31): the top 31 bits of a 32-bit draw.
// The previous abs((int)gen()) was undefined when the draw was 0x80000000
// (abs(INT_MIN)).
int rd() { return (int)(gen()>>1); }
// Pseudo-random int in [l, r]; requires r >= l and r-l+1 to fit in int
// (slight modulo bias).
int rnd(int l,int r) { return rd()%(r-l+1)+l; }

// Fast in-place modular addition for operands already in [0, mod):
// w1 <- w1 + w2, folded back below mod with one conditional subtraction.
void Add(LL &w1,cLL w2)
{
	w1+=w2;
	if(w1>=mod) w1-=mod;
}

cint H=21,L=1<<H|11,BL=20;
vct <LL> z[H+1];
void main_init()
{
	For(i,0,H)
	{
		z[i].resize(1<<i);
		LL s=1ll,t=pw(G,(mod-1ll)>>i);
		For_(j,0,1<<i) z[i][j]=s,mul(s,t);
	}
}
// Smallest d >= 0 with (1 << d) >= w, i.e. ceil(log2(w)). Returns 0 for w <= 1.
// Portable replacement for the libstdc++-internal __lg. It is also defined for
// w == 0, where __lg(0) was undefined.
int minsz(int w)
{
	int d=0;
	while((1ll<<d)<w) ++d;   // 64-bit shift: d can reach 31 without overflow
	return d;
}
// In-place iterative radix-2 number-theoretic transform modulo `mod`
// (bit-reversal permutation, then butterflies level by level).
// v is zero-padded up to the next power of two l = 2^d; requires l <= 2^H
// (root tables z[1..H] from main_init) and every entry of v in [0, mod).
// op == false: forward transform.
// op == true: the same forward butterflies, then scaling by 1/l and reversing
//   v[1..l-1]. Evaluating at w^(l-k) instead of w^k turns the forward
//   transform into the inverse one.
// NOTE: the parameter v shadows the global adjacency array v[].
void NTT(vct <LL> &v,bool op)
{
	int d=minsz(sz(v)),l=1<<d;
	v.resize(l,0ll);
	// rev[k] = k with its d low bits reversed; swap each pair once.
	// (For l == 1 the loop is empty, so the shift by d-1 never sees d == 0.)
	vct <int> rev(l,0);
	For_(k,1,l)
	{
		rev[k]=(rev[k>>1]>>1)|((k&1)<<d-1);
		if(rev[k]<k) swap(v[k],v[rev[k]]);
	}
	// Level k merges blocks of size 2^k into blocks of size 2^(k+1).
	For_(k,0,d)
	{
		for(int i=0;i<l;i+=(1<<k+1)) For_(j,i,i|(1<<k))
		{
			// i is a multiple of 2^(k+1) and j lies in [i, i+2^k), so
			// i^j == j-i is the offset inside the block; z[k+1][j-i] is the twiddle.
			int j2=(j|(1<<k));
			LL t=v[j2]*z[k+1][i^j]%mod;
			v[j2]=v[j]-t; if(v[j2]<0ll) v[j2]+=mod;
			v[j]+=t; if(v[j]>=mod) v[j]-=mod;
		}
	}
	ret(!op);
	// Inverse: divide by l and reverse indices 1..l-1.
	LL t=inv(1ll*l);
	for(auto &i:v) mul(i,t);
	reverse(v.begin()+1,v.end());
}
// Polynomial product modulo `mod` (coefficients assumed in [0, mod)).
// Returns an empty vector if either factor is empty. When one factor has at
// most BL coefficients a schoolbook loop is used; otherwise NTT convolution.
// Both operands are taken by value because the NTT path transforms them in place.
vct <LL> operator * (vct <LL> w1,vct <LL> w2)
{
	if(w1.empty()||w2.empty()) return {};
	const int n1=sz(w1),n2=sz(w2),out_len=n1+n2-1;
	if(min(n1,n2)<=BL)
	{
		vct <LL> res(out_len,0ll);
		for(int a=0;a<n1;++a)
			for(int c=0;c<n2;++c)
				add(res[a+c],w1[a]*w2[c]);
		return res;
	}
	const int fft_len=1<<minsz(out_len);
	w1.resize(fft_len,0ll);
	w2.resize(fft_len,0ll);
	NTT(w1,false);
	NTT(w2,false);
	for(int k=0;k<fft_len;++k) mul(w1[k],w2[k]);
	NTT(w1,true);
	w1.resize(out_len);
	return w1;
}
// In-place polynomial product: w1 <- w1 * w2.
// Both operands are moved into operator* (which takes them by value and reuses
// w1's buffer). This avoids copying w1 and w2 again on every call.
void operator *= (vct <LL> &w1,vct <LL> w2)
{
	w1=std::move(w1)*std::move(w2);
}
// Coefficient-wise sum of two polynomials modulo `mod` (entries assumed in
// [0, mod)). The result has the length of the longer operand.
vct <LL> operator + (vct <LL> w1,vct <LL> w2)
{
	const int n1=sz(w1),n2=sz(w2);
	// Copy w2's extra high-degree tail first; the overlap is summed below.
	if(n1<n2) w1.insert(w1.end(),w2.begin()+n1,w2.end());
	const int common=min(n1,n2);
	for(int k=0;k<common;++k) Add(w1[k],w2[k]);
	return w1;
}
// In-place coefficient-wise sum: w1 <- w1 + w2.
// Both operands are moved into operator+ so neither is copied again.
void operator += (vct <LL> &w1,vct <LL> w2)
{
	w1=std::move(w1)+std::move(w2);
}

// Resizes w by t coefficients at the high-degree end. A positive t appends
// zeros and leaves the polynomial unchanged; a negative t truncates.
void operator <<= (std::vector<long long> &w,int t)
{
	const int target=static_cast<int>(w.size())+t;
	w.resize(target,0ll);
}
// Multiplies the polynomial by x^t: prepends t zero coefficients.
// For negative t, drops the lowest -t coefficients instead. The drop is clamped
// to the size, so the result becomes empty instead of the old out-of-range
// resize. One insert/erase replaces the previous reverse-resize-reverse.
void operator >>= (std::vector<long long> &w,int t)
{
	if(t>=0)
	{
		w.insert(w.begin(),static_cast<std::size_t>(t),0ll);
		return;
	}
	const std::size_t drop=std::min(w.size(),static_cast<std::size_t>(-static_cast<long long>(t)));
	w.erase(w.begin(),w.begin()+drop);
}
// Returns the coefficients of w at indices op, op+2, op+4, ...
// (op = 0: even-indexed part, op = 1: odd-indexed part).
// w is only read, so it is taken by const reference. This is backward
// compatible and also accepts temporaries. The output is reserved up front.
std::vector<long long> keep(const std::vector<long long> &w,int op)
{
	std::vector<long long> w_;
	const int len=static_cast<int>(w.size());
	if(op<len) w_.reserve((len-op+1)/2);
	for(int i=op;i<len;i+=2) w_.push_back(w[i]);
	return w_;
}
// Returns [x^t] (w1(x) / w2(x)) modulo `mod`, where w1 and w2 are coefficient
// lists with entries in [0, mod). It uses the Bostan–Mori halving step:
// multiplying numerator and denominator by w2(-x) makes the denominator even;
// then keep the even part of the denominator and the part of the numerator
// whose parity matches t, halving t each round.
// Returns 0 when w2 is empty or identically zero (no quotient exists), and
// when the requested index lies below the start of the series.
LL gt_div(LL t,vct <LL> w1,vct <LL> w2)
{
	// Divide out low-degree zero coefficients: if w2 = x^k * w2' then
	// [x^t](w1/w2) = [x^(t+k)](w1/w2'). The bound check replaces the old
	// unchecked back() call, which was undefined for an all-zero or empty w2.
	int k=0;
	while(k<sz(w2)&&w2[k]==0ll) k++;
	if(k==sz(w2)) return 0ll;
	w2.erase(w2.begin(),w2.begin()+k);
	t+=k;
	// w1/w2' is now an ordinary power series, so negative indices read 0.
	// A negative t would also never reach 0 in the halving loop below.
	if(t<0) return 0ll;
	while(t)
	{
		auto w=w2; int len=sz(w);
		for(int i=1;i<len;i+=2) w[i]=md(-w[i]);   // w = w2(-x)
		w1*=w,w2*=w;
		w2=keep(w2,0),w1=keep(w1,(int)(t&1ll));
		t>>=1;
	}
	if(w1.empty()||w2.empty()) return 0ll;
	return (w1[0]*inv(w2[0])%mod);
}
// LL cal_M_th(LL t,vct <LL> w1,vct <LL> w2)
// {
// 	for(auto &i:w1) i=md(i);
// 	if(t<sz(w1)) return w1[t];
// 	for(auto &i:w2) i=md(-i);
// 	w2[0]=1ll; w1*=w2; w1.resize(sz(w2)-1,0ll);
// 	return gt_div(t,w1,w2);
// }
/* vct <LL> BM(vct <LL> a)
{
	LL lst_del=0ll,del=0ll;
	vct <LL> lst,f;
	int len=sz(a),tot=0,k=-1;
	For_(i,0,len)
	{
		LL w=0ll;
		For_(j,0,min(tot,i-1)) add(w,a[i-j-1]*f[j]);
		exc(w==a[i]);
		LL del=md(a[i]-w);
		if(!~k)
		{
			k=i,lst_del=del;
			f.resize(i+1,0ll);
			continue;
		}
		auto f2=f;
		int tot2=sz(lst);
		if(tot<i-k+tot2)
		{
			tot=i-k+tot2;
			f.resize(tot,0ll);
		}
		LL t=del*inv(lst_del)%mod;
		add(f[i-k-1],t);
		For_(j,0,tot2) add(f[i-k+j],-lst[j]*t);
		if(i-sz(f2)>k-tot2)
			k=i,lst_del=del,lst=f2;
	}
	return f;
} */
cint N=1.02e5;          // bound for vertex-indexed tables (vertices are 1..n)
int n; LL step;         // vertex count, and the index passed to gt_div in main_solve
int stt,ed;             // the two endpoints read in main_solve
vct <int> v[N],v_[N];   // adjacency lists; v_ holds a backup copy (see main_solve)
int up[N]; bool b[N];   // parent pointers / temporary vertex marks
LL fac[N],ifac[N];      // factorials and inverse factorials modulo `mod`
// Binomial coefficient C(n_, m_) modulo `mod`; 0 unless 0 <= m_ <= n_.
// fac/ifac must already be filled up to n_.
LL C(int n_,int m_)
{
	if(m_<0||m_>n_) return 0ll;
	const LL denom=ifac[n_-m_]*ifac[m_]%mod;
	return fac[n_]*denom%mod;
}
void dfs0(int p,int fa)
{
	up[p]=fa;
	for(auto i:v[p]) if(i!=fa)
		dfs0(i,p);
}
// siz[p]: size of p's subtree (set by dfs01).
// siz2[p]: 1 + total size of p's light-child subtrees (set in work_dp).
int siz[N],siz2[N];
// Vertices of the current subtree in preorder, as visited by dfs01.
vct <int> pas;
// Roots the subtree at p (parent fa; 0 = none) for heavy-light decomposition:
//  - appends p to pas and sets up[p] = fa;
//  - ERASES fa from v[p], so v[p] afterwards lists only children. This
//    permanently mutates the adjacency lists, and callers must restore them.
//    Assumes fa really is in v[p] when fa != 0 (otherwise erase(end()) is UB);
//  - computes siz[p];
//  - moves the child with the largest subtree (heavy child) to v[p][0].
// Recursive: depth equals the subtree height.
void dfs01(int p,int fa)
{
	pas.pb(p),up[p]=fa;
	if(fa) v[p].erase(find(all(v[p]),fa));
	siz[p]=1;
	ret(v[p].empty());
	int len=sz(v[p]),k=0;
	For_(i_,0,len)
	{
		int i=v[p][i_];
		// printf("%d->%d\n",p,i); ffo; //E
		dfs01(i,p);
		// k tracks the index of the largest child subtree seen so far.
		if(siz[i]>siz[v[p][k]]) k=i_;
		siz[p]+=siz[i];
	}
	swap(v[p][0],v[p][k]);
}
int tp[N];
void dfs02(int p,bool f)
{
	// printf("p=%d\n",p); ffo; //E
	tp[p]=(f?p:tp[up[p]]);
	for(auto i:v[p]) dfs02(i,i!=v[p][0]);
}
// dp[p][0] / dp[p][1]: polynomials (coefficient lists mod `mod`) for p's
// subtree with p left unused / used by an edge to one of its children. The
// exponent of x counts chosen edges. This mirrors the commented-out reference
// dfs further below: dp[p][1] = dp[p][0]*dp[i][0]*x + dp[p][1]*total_i,
// dp[p][0] *= total_i.
// NOTE: once a light child j has been consumed, dp[j][1] is overwritten with
// dp[j][0] + dp[j][1] (the child's total).
array <vct<LL>,2> dp[N];
// Returns dp[rt][0] + dp[rt][1] for the subtree rooted at rt, where avd is
// rt's neighbour to exclude (0 = none, i.e. the whole tree).
// From the recurrence, coefficient k appears to count the k-edge matchings of
// that subtree (matching generating polynomial). This is derived from the
// recurrence only — confirm against the intended model.
//
// It computes the same result as the reference dfs, using heavy-light
// decomposition plus size-balanced divide and conquer:
//  1. dfs01 roots the subtree. This erases parent entries from v[] (the caller
//     must restore v[] if it is needed again), fills up/siz/pas, and puts the
//     heavy child first. dfs02 then fills tp[] (chain tops).
//  2. Chain tops are processed in reverse preorder, so every light child's
//     chain is finished before the chain of its parent.
//  3. For each chain vertex, its light children are combined by D&C; then the
//     whole chain is combined by D&C, tracking whether each segment's top and
//     bottom vertex are already used.
// Side effects: overwrites up, siz, siz2, pas, tp and dp. It marks b[] for the
// current chain's vertices and clears those marks again.
vct <LL> work_dp(int avd,int rt)
{
	// printf("work_dp %d(%d)\n",rt,avd); ffo; // E
	pas.clear();
	dfs01(rt,avd);
	// printf("dfs01 end\n"); ffo; //E
	dfs02(rt,true);
	// printf("dfs02 end\n"); ffo; //E
	reverse(all(pas));
	// Handle each heavy chain, identified by its top vertex p.
	for(auto p:pas) if(tp[p]==p)
	{
		// printf("p=%d\n",p); ffo; // E
		// h = the chain p, v[p][0], v[v[p][0]][0], ... down to a leaf; mark it in b.
		vct <int> h;
		for(int i=p;;i=v[i][0])
		{
			h.pb(i),b[i]=true;
			stop(v[i].empty());
		}
		// printf("line\n> "); // E
		// for(auto i:h) printf("%d ",i); printf("\n"); // E
		for(auto i:h)
		{
			// v2 = light children of i (children that are not on this chain).
			vct <int> v2;
			for(auto j:v[i]) if(!b[j]) v2.pb(j);
			// printf("init of i=%d\n",i); ffo; // E
			// siz2[i] = 1 + sizes of i's light subtrees: i's weight in the chain D&C.
			siz2[i]=1;
			if(v2.empty())
				dp[i][0]={1ll},dp[i][1]={};
			else
			{
				// Turn dp[j][1] into child j's total (j unused or used).
				for(auto j:v2) dp[j][1]+=dp[j][0],siz2[i]+=siz[j];
				// work(l,r) over light children v2[l..r] returns
				//   {product of totals, sum_j dp[j][0] * product of the other totals},
				// i.e. "i takes no light edge" / "i takes exactly one" (before the x factor).
				auto work=[&](int l,int r,auto &self)->array <vct<LL>,2>
				{
					if(l==r) return {dp[v2[l]][1],dp[v2[l]][0]};
					// Split so the left part's size weight is as close as possible to
					// half the range total; mid < r keeps the right part nonempty.
					int mid=l,sum=0,now=siz[v2[l]];
					For(k,l,r) sum+=siz[v2[k]];
					while(mid+1<r&&abs((now<<1)-sum)>abs((now+siz[v2[mid+1]]<<1)-sum)) now+=siz[v2[++mid]];
					auto [l0,l1]=self(l,mid,self);
					auto [r0,r1]=self(mid+1,r,self);
					return {l0*r0,(l0*r1)+(l1*r0)};
				};
				// Multiply the "used" part by x for the chosen edge i - (light child).
				dp[i]=work(0,sz(v2)-1,work); dp[i][1]>>=1;
			}
			// printf("* "); for(auto w:dp[i][0]) printf("%lld ",w); printf("\n"); // E
			// printf("* "); for(auto w:dp[i][1]) printf("%lld ",w); printf("\n"); // E
			// printf("finish\n",i); ffo; //E
		}
		// work(l,r) over chain segment h[l..r] returns w[a][b], where a / b says
		// whether the segment's top / bottom vertex is already used (1) or free (0).
		auto work=[&](int l,int r,auto &self)->array <array<vct<LL>,2>,2>
		{
			// Single vertex: top == bottom, so only the diagonal states exist.
			if(l==r) return (array<array<vct<LL>,2>,2>){(array<vct<LL>,2>){dp[h[l]][0],(vct<LL>){}},(array<vct<LL>,2>){(vct<LL>){},dp[h[l]][1]}};
			int mid=l,sum=0,now=siz2[h[l]];
			For(k,l,r) sum+=siz2[h[k]];
			while(mid+1<r&&abs((now<<1)-sum)>abs((now+siz2[h[mid+1]]<<1)-sum)) now+=siz2[h[++mid]];
			auto wl=self(l,mid,self);
			auto wr=self(mid+1,r,self);
			array <array<vct<LL>,2>,2> w;
			// Case 1: take the chain edge h[mid] - h[mid+1]; both ends must be free.
			// A single-vertex side has top == bottom, so its outer state becomes used.
			For(z1,0,1) For(z2,0,1) w[z1|(l==mid)][z2|(mid+1==r)]+=wl[z1][0]*wr[0][z2];
			// Factor x for that edge (empty polynomials stay empty).
			For(z1,0,1) For(z2,0,1) if(!w[z1][z2].empty()) w[z1][z2]>>=1;
			// Case 2: leave that edge unused. The inner endpoints' states no longer
			// matter, so fold state 0 into state 1 on both inner sides and multiply.
			For(z,0,1) wl[z][1]+=wl[z][0],wr[1][z]+=wr[0][z];
			For(z1,0,1) For(z2,0,1) w[z1][z2]+=wl[z1][1]*wr[1][z2];
			return w;
		};
		// printf("start work\n"); ffo; //E
		auto w=work(0,sz(h)-1,work);
		// printf("finish work\n"); ffo; //E
		// printf("p=%d\n",p); // E
		// Sum over the bottom state: chain top p free / used.
		dp[p][0]=w[0][0]+w[0][1];
		dp[p][1]=w[1][0]+w[1][1];
		// printf("* "); for(auto w:dp[p][0]) printf("%lld ",w); printf("\n"); // E
		// printf("* "); for(auto w:dp[p][1]) printf("%lld ",w); printf("\n"); // E
		for(auto i:h) b[i]=false;
	}
	auto t=dp[rt][0]+dp[rt][1];
	// printf("t: "); for(auto i:t) printf("%lld ",i); printf("\n"); // E
	return t;
}
// void dfs(int p,int fa)
// {
// 	dp[p][0]={1ll},dp[p][1]={};
// 	for(auto i:v[p]) if(i!=fa)
// 	{
// 		dfs(i,p);
// 		dp[i][1]+=dp[i][0];
// 		dp[p][1]=(dp[p][0]*dp[i][0]*(vct<LL>){0ll,1ll})+(dp[p][1]*dp[i][1]);
// 		dp[p][0]*=dp[i][1];
// 	}
// 	// printf("dfs %d(%d):\n",p,fa);
// 	// for(auto i:dp[p][0]) printf("%lld ",i); printf("\n");
// 	// for(auto i:dp[p][1]) printf("%lld ",i); printf("\n");
// }
// vct <LL> work_dp(int avd,int rt)
// {
// 	dfs(rt,avd);
// 	return (dp[rt][0]+dp[rt][1]);
// }
// Coefficients of (1 - x)^m modulo `mod`: entry i is (-1)^i * C(m, i), i = 0..m.
// Requires fac/ifac to be filled up to m.
vct <LL> pw_1sx(int m)
{
	vct <LL> coef(m+1,0ll);
	for(int i=0;i<=m;++i)
	{
		const LL c=C(m,i);
		coef[i]=(i&1)?mod-c:c;
	}
	return coef;
}
// Given t(x) = sum_i t_i x^i, returns (coefficients mod `mod`)
//     sum_i (-1)^i * t_i * x^(2i) * (1 - x)^(m - 2i).
// This equals the naive loop res += [(-1)^i t_i] * (1-x)^(m-2i) * x^(2i).
// Instead, it is evaluated by divide and conquer: work(l, r) returns
//     sum_{i=l..r} t'_i * x^(2(i-l)) * (1 - x)^(2(r-i)),
// so the (1 - x) powers are multiplied in along a balanced recursion tree.
// Assumes m >= 2*(sz(t)-1). An empty t (the zero polynomial) gives an empty
// result; previously work(0,-1) recursed forever on it.
vct <LL> work_again(vct <LL> t,int m)
{
	int len=sz(t);
	if(len==0) return {};
	// Apply the (-1)^i sign to the odd coefficients.
	for(int i=1;i<len;i+=2) t[i]=(t[i]?(mod-t[i]):0ll);
	auto work=[&](int l,int r,auto &self)->vct<LL>
	{
		if(l==r) return {t[l]};
		int mid=(l+r)>>1;
		auto t1=self(l,mid,self),t2=self(mid+1,r,self);
		// Left half needs 2(r-mid) more factors of (1-x); right half needs x^(2(mid-l+1)).
		t1*=pw_1sx((r-mid)<<1); t2>>=((mid-l+1)<<1);
		return (t1+t2);
	};
	auto res=work(0,len-1,work);
	res*=pw_1sx(m-((len-1)<<1));
	return res;
}
// Reads one test case and prints [x^step] d2(x) / d1(x) modulo `mod`, where
//   d1 = work_again(work_dp polynomial of the whole tree, n), and
//   d2 = work_again(product of work_dp polynomials of the subtrees hanging off
//        the stt-ed path, n - |path|) * x^(|path| - 1).
// NOTE(review): the structure resembles a cofactor / determinant ratio for
// counting walks from stt to ed; confirm against the problem statement.
void main_solve()
{
	// Input: n and step, the endpoints stt/ed, then the n-1 tree edges.
	// scanf results are checked, so malformed or truncated input stops the
	// solver instead of running on uninitialised values (this also silences
	// the -Wunused-result warnings the judge reported).
	if(inint(n)!=1||inll(step)!=1) return;
	if(inpr(stt,ed)!=2) return;
	For_(i,1,n)
	{
		int x,y;
		if(inpr(x,y)!=2) return;
		v[x].pb(y),v[y].pb(x);
	}
	// Factorials and inverse factorials up to n (used by C / pw_1sx).
	fac[0]=1ll; For(i,1,n) fac[i]=fac[i-1]*i%mod;
	ifac[n]=inv(fac[n]); Rof(i,1,n) ifac[i-1]=ifac[i]*i%mod;
	// work_dp (via dfs01) erases parent entries from v[], so back v[] up and
	// restore it afterwards.
	For(i,1,n) v_[i]=v[i];
	vct <LL> d1=work_dp(0,1);
	For(i,1,n) v[i].swap(v_[i]);
	d1=work_again(d1,n);
	// Root at stt and walk parent pointers from ed to collect the path ed ... stt.
	dfs0(stt,0);
	vct <int> pass;
	for(int i=ed;i!=stt;i=up[i]) pass.pb(i);
	pass.pb(stt);
	for(auto i:pass) b[i]=true;
	// Multiply the polynomials of every subtree hanging off the path.
	vct <LL> d2={1ll};
	for(auto i:pass) for(auto j:v[i]) if(!b[j]) d2*=work_dp(i,j);
	d2=work_again(d2,n-sz(pass)); d2>>=(sz(pass)-1);
	outll(gt_div(step,d2,d1));
}
// Entry point: build the NTT root-of-unity tables, then solve a single test case.
int main()
{
	main_init();
	main_solve();
	return 0;
}
/*
d2: 0 998244352 
d1: 1 998244351 2 
-x
-2x+1

1-x  -x
 -x 1-x
1-2x
*/
0