結果
問題 | No.2587 Random Walk on Tree |
ユーザー |
|
提出日時 | 2023-12-16 00:33:26 |
言語 | C++17(gcc12) (gcc 12.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 7,263 ms / 10,000 ms |
コード長 | 9,039 bytes |
コンパイル時間 | 23,253 ms |
コンパイル使用メモリ | 291,048 KB |
最終ジャッジ日時 | 2025-02-18 11:43:57 |
ジャッジサーバーID (参考情報) |
judge4 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 37 |
ソースコード
// -fsanitize=undefined,// #define _GLIBCXX_DEBUG#pragma GCC target("avx2")#pragma GCC optimize("unroll-loops")#include <iostream>#include <vector>#include <string>#include <map>#include <set>#include <queue>#include <algorithm>#include <cmath>#include <iomanip>#include <random>#include <stdio.h>#include <fstream>#include <functional>#include <cassert>#include <unordered_map>#include <bitset>#include <chrono>#include <atcoder/modint>#include <atcoder/convolution>using namespace std;using namespace atcoder;#define rep(i,n) for (int i=0;i<n;i+=1)#define rrep(i,n) for (int i=n-1;i>-1;i--)#define pb push_back#define all(x) (x).begin(), (x).end()#define debug(x) cerr << #x << " = " << (x) << " (L" << __LINE__ << " )\n";template<class T>using vec = vector<T>;template<class T>using vvec = vec<vec<T>>;template<class T>using vvvec = vec<vvec<T>>;using ll = long long;using pii = pair<int,int>;using pll = pair<ll,ll>;template<class T>bool chmin(T &a, T b){if (a>b){a = b;return true;}return false;}template<class T>bool chmax(T &a, T b){if (a<b){a = b;return true;}return false;}template<class T>T sum(vec<T> x){T res=0;for (auto e:x){res += e;}return res;}template<class T>void printv(vec<T> x){for (auto e:x){cout<<e<<" ";}cout<<endl;}template<class T,class U>ostream& operator<<(ostream& os, const pair<T,U>& A){os << "(" << A.first <<", " << A.second << ")";return os;}template<class T>ostream& operator<<(ostream& os, const set<T>& S){os << "set{";for (auto a:S){os << a;auto it = S.find(a);it++;if (it!=S.end()){os << ", ";}}os << "}";return os;}using mint = modint998244353;ostream& operator<<(ostream& os, const mint& a){os << a.val();return os;}template<class T>ostream& operator<<(ostream& os, const vec<T>& A){os << "[";rep(i,A.size()){os << A[i];if (i!=A.size()-1){os << ", ";}}os << "]" ;return os;}template <typename T> vector<T>& operator+=(vector<T>& a, const vector<T>& b) {if (a.size() < b.size()) {a.resize(b.size());}for (int i = 0; i < (int)b.size(); i++) {a[i] += b[i];}return a;}template <typename T> vector<T> operator+(const vector<T>& a, const vector<T>& b) {vector<T> c = a;return c += b;}vector<mint> calc_prod(vector<vector<mint>> polys){if (polys.empty()) return {1};deque<vec<mint>> deq;for (auto f:polys) deq.push_back(f);while (deq.size() > 1){auto f = deq.front(); deq.pop_front();auto g = deq.front(); deq.pop_front();deq.push_back(convolution(f,g));}return deq[0];}pair<vector<mint>,vector<mint>> merge_child(vector<pair<vec<mint>,vec<mint>>> child_polys){if (child_polys.empty()){return {{1},{1,-1}};}deque<pair<vec<mint>,vec<mint>>> deq;for (auto e:child_polys) deq.push_back(e);while (deq.size() > 1){auto [f0,f1] = deq.front(); deq.pop_front();auto [g0,g1] = deq.front(); deq.pop_front();auto h1 = convolution(f1,g1);auto h0 = convolution(g0,f1) + convolution(f0,g1);deq.push_back({h0,h1});}auto [A,B] = deq.front();vector<mint> res0 = B;/*res1 = A * (-x^2) + B * (1-x) *///debug(child_polys);//debug(A);//debug(B);vector<mint> res1(max(A.size()+2,B.size()+1),0);rep(i,A.size()) res1[i+2] -= A[i];rep(i,B.size()) {res1[i] += B[i];res1[i+1] -= B[i];}return {res0,res1};}pair<vector<mint>,vector<mint>> merge_path(vector<pair<vector<mint>,vector<mint>>> polys){int n = polys.size();if (n & 1){polys.push_back({{0},{1}});}n = polys.size();//debug(polys);vector<array<vector<mint>,4>> nxt_polys;for (int i=0;i<n;i+=2){/*0:空いてる 1:空いてない0+0->00 or 11(-x^2倍)0+1->011+0->101+1->11*/auto [f0,f1] = polys[i];auto [g0,g1] = polys[i+1];array<vector<mint>,4> merge;merge[1] = convolution(f0,g1);merge[2] = convolution(f1,g0);merge[3] = convolution(f1,g1);auto h = convolution(f0,g0);merge[0] = h;int p = h.size();if (merge[3].size() < p+2){merge[3].resize(p+2);}rep(i,p) merge[3][i+2] -= h[i];nxt_polys.push_back(merge);}auto calc = [&](auto self,int l,int r)->array<vector<mint>,4> {if (r-l==1){return nxt_polys[l];}int mid = (l+r)>>1;auto left = self(self,l,mid);auto right = self(self,mid,r);int merge_n = 0;rep(i,4){rep(j,4){int nxt_ocu = (i & 2) + (j & 1);int mid_ocu = (i & 1) + (j & 2);int h_deg = left[i].size() + right[j].size() - 2;//debug(i);//debug(j);//debug(h_deg);if (mid_ocu == 0){chmax(merge_n,h_deg+2);}else if (mid_ocu == 3){chmax(merge_n,h_deg);}}}merge_n += 1;//debug(merge_n);array<vector<mint>,4> merge;rep(i,4) merge[i] = vector<mint>(merge_n,0);rep(i,4){rep(j,4){int nxt_ocu = (i & 2) + (j & 1);int mid_ocu = (i & 1) + (j & 2);if (mid_ocu!=0 && mid_ocu!=3) continue;auto h = convolution(left[i],right[j]);//assert ((h.size()-1) == left[i].size() + right[j].size() - 2);if (mid_ocu == 0){//assert (h.size()+1 < merge_n);rep(k,h.size()){merge[nxt_ocu][k+2] -= h[k];}}else if (mid_ocu == 3){//debug(i);//debug(j);//debug(h.size());//assert (h.size()-1 < merge_n);rep(k,h.size()){merge[nxt_ocu][k] += h[k];}}}}return merge;};auto merged = calc(calc,0,nxt_polys.size());vector<mint> res0 = merged[1], res1 = merged[3];return {res0,res1};}mint BostanMori2(vec<mint> P,vec<mint> Q, long long N){for (; N; N>>= 1){vec<mint> Q_neg = {all(Q)};for (int i=1;i<Q_neg.size();i+=2) Q_neg[i] *= -1;P = convolution(P,Q_neg);Q = convolution(Q,Q_neg);int deg_P = P.size() - 1, deg_Q = Q.size() - 1;vec<mint> nxt_P((deg_P/2+1),0),nxt_Q((deg_Q/2+1),0);for (int i = (N & 1);i <= deg_P; i+=2){nxt_P[i>>1] = P[i];}for (int i = 0; i <= deg_Q; i+=2){nxt_Q[i>>1] = Q[i];}swap(P,nxt_P);swap(Q,nxt_Q);}return P[0];}void solve(){int N,M,S,T;cin>>N>>M>>S>>T;S--; T--;vec<vec<int>> edge(N);rep(i,N-1){int a,b;cin>>a>>b;a--; b--;edge[a].push_back(b);edge[b].push_back(a);}vec<int> parent(N,-1), sz(N,1);vec<int> heavy_child(N,-1);int root = T;auto dfs0 = [&](auto self,int v,int pv)->void {int tmp_max_sz = 0;for (auto nv:edge[v]){if (nv == pv) continue;parent[nv] = v;self(self,nv,v);sz[v] += sz[nv];if (sz[nv] > tmp_max_sz){tmp_max_sz = sz[nv];heavy_child[v] = nv;}}};dfs0(dfs0,root,-1);vector<int> on_st_path(N,0);on_st_path[T] = 1;{int pos = S;while (pos!=T){on_st_path[pos] = 1;pos = parent[pos];}}vector<int> to_parent_is_light(N,1);rep(v,N){if (heavy_child[v]!=-1){to_parent_is_light[heavy_child[v]] = 0;}}rep(v,N){if (on_st_path[v]) to_parent_is_light[v] = 1;}vec<pair<vec<mint>,vec<mint>>> dp(N);auto dfs1 = [&](auto self,int v,int pv)->void {vector<pair<vec<mint>,vec<mint>>> child_polys;for (auto nv:edge[v]){if (nv == parent[v] || on_st_path[nv]) continue;self(self,nv,v);if (nv!=heavy_child[v]){child_polys.push_back(dp[nv]);dp[nv].first.clear();dp[nv].second.clear();}}dp[v] = merge_child(child_polys);if (to_parent_is_light[v]){vector<pair<vec<mint>,vec<mint>>> path_polys = {dp[v]};int pos = heavy_child[v];while (pos!=-1 && !on_st_path[pos]){path_polys.push_back(dp[pos]);pos = heavy_child[pos];}dp[v] = merge_path(path_polys);}};rep(v,N){if (on_st_path[v]) dfs1(dfs1,v,-1);}vector<pair<vec<mint>,vec<mint>>> st_path_polys;vector<vec<mint>> st_path_polys_non;{int pos = S;while (pos!=-1){st_path_polys.push_back(dp[pos]);st_path_polys_non.push_back(dp[pos].first);pos = parent[pos];}}//debug(st_path_polys);//debug(st_path_polys_non);vector<mint> q_det = merge_path(st_path_polys).second;vector<mint> p_det = calc_prod(st_path_polys_non);//debug(p_det);//debug(q_det);int D = int(st_path_polys.size()) - 1;if (M < D){cout << 0 << "\n";return ;}auto ans = BostanMori2(p_det,q_det,M-D);cout << ans << "\n";}int main(){ios::sync_with_stdio(false);std::cin.tie(nullptr);int T = 1;while (T--){solve();}}