結果
問題 | No.2587 Random Walk on Tree |
ユーザー | chineristAC |
提出日時 | 2023-12-16 00:33:26 |
言語 | C++17 (gcc 12.3.0 + boost 1.83.0) |
結果 | AC |
実行時間 | 7,376 ms / 10,000 ms |
コード長 | 9,039 bytes |
コンパイル時間 | 6,582 ms |
コンパイル使用メモリ | 245,632 KB |
実行使用メモリ | 70,544 KB |
最終ジャッジ日時 | 2024-09-27 13:31:37 |
合計ジャッジ時間 | 110,588 ms |
ジャッジサーバーID (参考情報) | judge2 / judge5 |
(要ログイン)
テストケース
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 2 ms | 6,812 KB |
testcase_01 | AC | 2 ms | 6,940 KB |
testcase_02 | AC | 2 ms | 6,940 KB |
testcase_03 | AC | 3 ms | 6,940 KB |
testcase_04 | AC | 3 ms | 6,944 KB |
testcase_05 | AC | 5 ms | 6,944 KB |
testcase_06 | AC | 4 ms | 6,940 KB |
testcase_07 | AC | 2 ms | 6,944 KB |
testcase_08 | AC | 4 ms | 6,940 KB |
testcase_09 | AC | 3 ms | 6,944 KB |
testcase_10 | AC | 4 ms | 6,940 KB |
testcase_11 | AC | 11 ms | 6,940 KB |
testcase_12 | AC | 51 ms | 6,940 KB |
testcase_13 | AC | 59 ms | 6,940 KB |
testcase_14 | AC | 21 ms | 6,940 KB |
testcase_15 | AC | 4,437 ms | 40,228 KB |
testcase_16 | AC | 2,288 ms | 27,520 KB |
testcase_17 | AC | 2,668 ms | 28,952 KB |
testcase_18 | AC | 306 ms | 7,552 KB |
testcase_19 | AC | 4,958 ms | 53,088 KB |
testcase_20 | AC | 4,290 ms | 45,432 KB |
testcase_21 | AC | 5,371 ms | 48,924 KB |
testcase_22 | AC | 4,522 ms | 70,544 KB |
testcase_23 | AC | 4,993 ms | 51,932 KB |
testcase_24 | AC | 4,507 ms | 43,300 KB |
testcase_25 | AC | 3,237 ms | 56,596 KB |
testcase_26 | AC | 7,376 ms | 52,452 KB |
testcase_27 | AC | 6,172 ms | 47,208 KB |
testcase_28 | AC | 6,466 ms | 48,892 KB |
testcase_29 | AC | 6,383 ms | 50,020 KB |
testcase_30 | AC | 5,260 ms | 52,632 KB |
testcase_31 | AC | 5,163 ms | 54,200 KB |
testcase_32 | AC | 6,299 ms | 53,564 KB |
testcase_33 | AC | 2 ms | 6,940 KB |
testcase_34 | AC | 1,880 ms | 45,488 KB |
testcase_35 | AC | 1,879 ms | 45,732 KB |
testcase_36 | AC | 1,880 ms | 46,168 KB |
testcase_37 | AC | 2,999 ms | 45,252 KB |
testcase_38 | AC | 4,415 ms | 46,004 KB |
testcase_39 | AC | 4,491 ms | 46,108 KB |
ソースコード
// Submission for yukicoder No.2587 "Random Walk on Tree".
//
// Overview (as implemented below):
//   * The tree is rooted at T and heavy-light decomposed (dfs0 in solve()).
//   * For every vertex v, dp[v] holds a pair (first, second) of polynomials
//     in x over Z/998244353. The recurrences in merge_child / merge_path have
//     the shape of a cofactor expansion of det((1 - x) I - x * Adj) over v's
//     subtree (vertices on the S-T path excluded), with `first` the version
//     where v is deleted and `second` the version where v is kept.
//     NOTE(review): this interpretation is read off the recurrences; confirm
//     it against the problem statement.
//   * The program prints [x^(M - D)] p_det / q_det, where D is the number of
//     edges on the S-T path, p_det is the product of dp[u].first over the
//     path vertices u, and q_det is merge_path(...).second over the path.
//
// Build notes: needs the AtCoder Library (modint, convolution).
// NOTE(review): std::deque and std::array are used without their own
// #include lines; they are only reachable through other headers (e.g.
// <queue>, the AtCoder headers).

// -fsanitize=undefined,
// #define _GLIBCXX_DEBUG
#pragma GCC target("avx2")
#pragma GCC optimize("unroll-loops")
#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <set>
#include <queue>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <random>
#include <stdio.h>
#include <fstream>
#include <functional>
#include <cassert>
#include <unordered_map>
#include <bitset>
#include <chrono>
#include <atcoder/modint>
#include <atcoder/convolution>

using namespace std;
using namespace atcoder;

// Loop shorthands. NOTE(review): `n` is not parenthesized in the expansion.
#define rep(i,n) for (int i=0;i<n;i+=1)
#define rrep(i,n) for (int i=n-1;i>-1;i--)
#define pb push_back
#define all(x) (x).begin(), (x).end()

// Prints "name = value (L<line> )" to stderr.
#define debug(x) cerr << #x << " = " << (x) << " (L" << __LINE__ << " )\n";

template<class T> using vec = vector<T>;
template<class T> using vvec = vec<vec<T>>;
template<class T> using vvvec = vec<vvec<T>>;
using ll = long long;
using pii = pair<int,int>;
using pll = pair<ll,ll>;

// If b < a, assigns a = b and returns true; otherwise returns false.
template<class T>
bool chmin(T &a, T b){
    if (a>b){
        a = b;
        return true;
    }
    return false;
}

// If a < b, assigns a = b and returns true; otherwise returns false.
template<class T>
bool chmax(T &a, T b){
    if (a<b){
        a = b;
        return true;
    }
    return false;
}

// Sum of all elements of x, starting from T(0). x is taken by value.
template<class T>
T sum(vec<T> x){
    T res=0;
    for (auto e:x){
        res += e;
    }
    return res;
}

// Writes the elements of x to stdout, each followed by a space, then endl.
template<class T>
void printv(vec<T> x){
    for (auto e:x){
        cout<<e<<" ";
    }
    cout<<endl;
}

// Debug printer: "(first, second)".
template<class T,class U>
ostream& operator<<(ostream& os, const pair<T,U>& A){
    os << "(" << A.first <<", " << A.second << ")";
    return os;
}

// Debug printer: "set{a, b, ...}". Calls S.find per element to detect the
// last one.
template<class T>
ostream& operator<<(ostream& os, const set<T>& S){
    os << "set{";
    for (auto a:S){
        os << a;
        auto it = S.find(a);
        it++;
        if (it!=S.end()){
            os << ", ";
        }
    }
    os << "}";
    return os;
}

// All polynomial arithmetic below is done modulo 998244353.
using mint = modint998244353;

// Prints the canonical representative of a modint.
ostream& operator<<(ostream& os, const mint& a){
    os << a.val();
    return os;
}

// Debug printer: "[a, b, ...]".
template<class T>
ostream& operator<<(ostream& os, const vec<T>& A){
    os << "[";
    rep(i,A.size()){
        os << A[i];
        if (i!=A.size()-1){
            os << ", ";
        }
    }
    os << "]" ;
    return os;
}

// Element-wise (coefficient-wise) a += b; a is grown to b's size first.
template <typename T>
vector<T>& operator+=(vector<T>& a, const vector<T>& b)
{
    if (a.size() < b.size()) {
        a.resize(b.size());
    }
    for (int i = 0; i < (int)b.size(); i++) {
        a[i] += b[i];
    }
    return a;
}

// Element-wise sum of two vectors; the result has the larger of the two sizes.
template <typename T>
vector<T> operator+(const vector<T>& a, const vector<T>& b) {
    vector<T> c = a;
    return c += b;
}

// Product of all polynomials in `polys` (returns {1} when empty).
// Repeatedly pops the two front polynomials of a queue and pushes their
// convolution to the back until a single polynomial remains.
vector<mint> calc_prod(vector<vector<mint>> polys){
    if (polys.empty()) return {1};
    deque<vec<mint>> deq;
    for (auto f:polys) deq.push_back(f);
    while (deq.size() > 1){
        auto f = deq.front(); deq.pop_front();
        auto g = deq.front(); deq.pop_front();
        deq.push_back(convolution(f,g));
    }
    return deq[0];
}

// Builds the dp pair of a vertex v from the dp pairs (c0, c1) of the children
// passed in (solve() passes the light children only).
//   No children: returns {1, 1 - x}.
//   Otherwise, with
//     B = prod_c c1
//     A = sum_c ( c0 * prod_{d != c} d1 )
//   returns {B, (1 - x) * B - x^2 * A}.
// A and B are accumulated together: merging queue entries (f0, f1) and
// (g0, g1) yields (g0*f1 + f0*g1, f1*g1).
pair<vector<mint>,vector<mint>> merge_child(vector<pair<vec<mint>,vec<mint>>> child_polys){
    if (child_polys.empty()){
        return {{1},{1,-1}};
    }
    deque<pair<vec<mint>,vec<mint>>> deq;
    for (auto e:child_polys) deq.push_back(e);
    while (deq.size() > 1){
        auto [f0,f1] = deq.front(); deq.pop_front();
        auto [g0,g1] = deq.front(); deq.pop_front();
        auto h1 = convolution(f1,g1);
        auto h0 = convolution(g0,f1) + convolution(f0,g1);
        deq.push_back({h0,h1});
    }
    auto [A,B] = deq.front();
    vector<mint> res0 = B;
    // res1 = A * (-x^2) + B * (1 - x)
    //debug(child_polys);
    //debug(A);
    //debug(B);
    vector<mint> res1(max(A.size()+2,B.size()+1),0);
    rep(i,A.size()) res1[i+2] -= A[i];
    rep(i,B.size()) {
        res1[i] += B[i];
        res1[i+1] -= B[i];
    }
    return {res0,res1};
}

// Combines the dp pairs (p0, p1) of a chain of entries, given in order
// (solve() passes them top to bottom), where consecutive entries are joined
// by an edge.
//
// Each entry ends up in one of two states:
//   0 = "open":   contributes p0 and is still waiting to be joined to a
//                 neighbour across a chain edge;
//   1 = "closed": contributes p1, or has already been joined.
// Joining two adjacent open entries across their edge closes both and
// contributes a factor of -x^2.
//
// A segment is stored as array<vector<mint>,4> indexed by
//   (state of its first entry) * 2 + (state of its last entry).
//
// Returns {merged[1], merged[3]}:
//   merged[1]: first entry open (never joined), last entry closed;
//   merged[3]: both end entries closed.
//
// An odd-length chain is padded with the dummy pair {0, 1}. It has p0 = 0, so
// it can only be closed, and then it contributes a factor of 1.
// Precondition: `polys` is non-empty (every call site in solve() ensures this).
pair<vector<mint>,vector<mint>> merge_path(vector<pair<vector<mint>,vector<mint>>> polys){
    int n = polys.size();
    if (n & 1){
        polys.push_back({{0},{1}});
    }
    n = polys.size();
    //debug(polys);
    vector<array<vector<mint>,4>> nxt_polys;
    // Leaf level: pair up entries (2k, 2k+1) into two-entry segments.
    for (int i=0;i<n;i+=2){
        /* 0: open, 1: closed.
           0+0 -> 00, or 11 (times -x^2, joined across the edge)
           0+1 -> 01
           1+0 -> 10
           1+1 -> 11 */
        auto [f0,f1] = polys[i];
        auto [g0,g1] = polys[i+1];
        array<vector<mint>,4> merge;
        merge[1] = convolution(f0,g1);
        merge[2] = convolution(f1,g0);
        merge[3] = convolution(f1,g1);
        auto h = convolution(f0,g0);
        merge[0] = h;
        int p = h.size();
        if (merge[3].size() < p+2){
            merge[3].resize(p+2);
        }
        // 11 also receives the "joined" term: -x^2 * f0 * g0.
        rep(i,p) merge[3][i+2] -= h[i];
        nxt_polys.push_back(merge);
    }
    // Divide and conquer over segments nxt_polys[l, r).
    // The two inner end entries (left's last, right's first) must be resolved:
    //   mid_ocu == 0 (both open)   -> join them across the middle edge, times -x^2;
    //   mid_ocu == 3 (both closed) -> plain product;
    //   otherwise                  -> discarded (an open entry whose neighbour is
    //                                 closed can never be joined).
    // The outer end states (left's first, right's last) give the result index.
    auto calc = [&](auto self,int l,int r)->array<vector<mint>,4>
    {
        if (r-l==1){
            return nxt_polys[l];
        }
        int mid = (l+r)>>1;
        auto left = self(self,l,mid);
        auto right = self(self,mid,r);
        // First pass: compute a common output length for all four states.
        // NOTE(review): nxt_ocu is unused in this pass.
        int merge_n = 0;
        rep(i,4){
            rep(j,4){
                int nxt_ocu = (i & 2) + (j & 1);
                int mid_ocu = (i & 1) + (j & 2);
                int h_deg = left[i].size() + right[j].size() - 2;
                //debug(i);
                //debug(j);
                //debug(h_deg);
                if (mid_ocu == 0){
                    chmax(merge_n,h_deg+2);
                }
                else if (mid_ocu == 3){
                    chmax(merge_n,h_deg);
                }
            }
        }
        merge_n += 1;
        //debug(merge_n);
        array<vector<mint>,4> merge;
        rep(i,4) merge[i] = vector<mint>(merge_n,0);
        // Second pass: accumulate every valid (i, j) combination.
        rep(i,4){
            rep(j,4){
                int nxt_ocu = (i & 2) + (j & 1);
                int mid_ocu = (i & 1) + (j & 2);
                if (mid_ocu!=0 && mid_ocu!=3) continue;
                auto h = convolution(left[i],right[j]);
                //assert ((h.size()-1) == left[i].size() + right[j].size() - 2);
                if (mid_ocu == 0){
                    // Join across the middle edge: add -x^2 * h.
                    //assert (h.size()+1 < merge_n);
                    rep(k,h.size()){
                        merge[nxt_ocu][k+2] -= h[k];
                    }
                }
                else if (mid_ocu == 3){
                    //debug(i);
                    //debug(j);
                    //debug(h.size());
                    //assert (h.size()-1 < merge_n);
                    rep(k,h.size()){
                        merge[nxt_ocu][k] += h[k];
                    }
                }
            }
        }
        return merge;
    };
    auto merged = calc(calc,0,nxt_polys.size());
    vector<mint> res0 = merged[1], res1 = merged[3];
    return {res0,res1};
}

// Returns [x^N] P(x) / Q(x) using the Bostan-Mori halving step.
// Each round:
//   * multiplies P and Q by Q(-x);
//   * keeps the coefficients of P whose index has the parity of N;
//   * keeps the even coefficients of Q;
//   * halves N.
// NOTE: the result is P[0] without dividing by Q[0], so this assumes
// Q[0] == 1. That holds for q_det in solve(): merge_child and merge_path
// only ever produce `second` polynomials with constant term 1.
mint BostanMori2(vec<mint> P,vec<mint> Q, long long N){
    for (; N; N>>= 1){
        vec<mint> Q_neg = {all(Q)};
        for (int i=1;i<Q_neg.size();i+=2) Q_neg[i] *= -1;
        P = convolution(P,Q_neg);
        Q = convolution(Q,Q_neg);
        int deg_P = P.size() - 1, deg_Q = Q.size() - 1;
        vec<mint> nxt_P((deg_P/2+1),0),nxt_Q((deg_Q/2+1),0);
        for (int i = (N & 1);i <= deg_P; i+=2){
            nxt_P[i>>1] = P[i];
        }
        for (int i = 0; i <= deg_Q; i+=2){
            nxt_Q[i>>1] = Q[i];
        }
        swap(P,nxt_P);
        swap(Q,nxt_Q);
    }
    return P[0];
}

// Reads N, M, S, T (1-indexed vertices) and N-1 edges. Prints 0 when M < D,
// otherwise [x^(M - D)] p_det / q_det, where D is the number of edges on the
// S-T path.
void solve(){
    int N,M,S,T;
    cin>>N>>M>>S>>T;
    S--; T--;
    vec<vec<int>> edge(N);
    rep(i,N-1){
        int a,b;
        cin>>a>>b;
        a--; b--;
        edge[a].push_back(b);
        edge[b].push_back(a);
    }
    vec<int> parent(N,-1), sz(N,1);
    vec<int> heavy_child(N,-1);
    int root = T;
    // Roots the tree at T and fills parent, subtree size sz, and heavy_child
    // (the first child with the strictly largest subtree; -1 for leaves).
    // NOTE(review): recursive; depth equals the tree height (up to N).
    auto dfs0 = [&](auto self,int v,int pv)->void {
        int tmp_max_sz = 0;
        for (auto nv:edge[v]){
            if (nv == pv) continue;
            parent[nv] = v;
            self(self,nv,v);
            sz[v] += sz[nv];
            if (sz[nv] > tmp_max_sz){
                tmp_max_sz = sz[nv];
                heavy_child[v] = nv;
            }
        }
    };
    dfs0(dfs0,root,-1);
    // Mark every vertex on the path from S up to the root T.
    vector<int> on_st_path(N,0);
    on_st_path[T] = 1;
    {
        int pos = S;
        while (pos!=T){
            on_st_path[pos] = 1;
            pos = parent[pos];
        }
    }
    // to_parent_is_light[v] == 1 means v heads its own heavy chain. It is set
    // for vertices that are nobody's heavy child, and forced for every S-T
    // path vertex. Chains are additionally cut at path vertices in dfs1.
    vector<int> to_parent_is_light(N,1);
    rep(v,N){
        if (heavy_child[v]!=-1){
            to_parent_is_light[heavy_child[v]] = 0;
        }
    }
    rep(v,N){
        if (on_st_path[v]) to_parent_is_light[v] = 1;
    }
    vec<pair<vec<mint>,vec<mint>>> dp(N);
    // Post-order DP over the subtree hanging below v, never entering v's
    // parent or any S-T path vertex.
    // Light children's pairs are fed to merge_child and then cleared.
    // The heavy child's pair is left in dp so that the chain head can later
    // merge the whole chain with merge_path.
    // NOTE(review): `pv` is unused (parent[] is used instead), and clear()
    // keeps the vectors' capacity.
    auto dfs1 = [&](auto self,int v,int pv)->void {
        vector<pair<vec<mint>,vec<mint>>> child_polys;
        for (auto nv:edge[v]){
            if (nv == parent[v] || on_st_path[nv]) continue;
            self(self,nv,v);
            if (nv!=heavy_child[v]){
                child_polys.push_back(dp[nv]);
                dp[nv].first.clear();
                dp[nv].second.clear();
            }
        }
        dp[v] = merge_child(child_polys);
        if (to_parent_is_light[v]){
            // v heads a chain: collect heavy descendants top to bottom,
            // stopping at a leaf or an S-T path vertex, and merge them.
            vector<pair<vec<mint>,vec<mint>>> path_polys = {dp[v]};
            int pos = heavy_child[v];
            while (pos!=-1 && !on_st_path[pos]){
                path_polys.push_back(dp[pos]);
                pos = heavy_child[pos];
            }
            dp[v] = merge_path(path_polys);
        }
    };
    rep(v,N){
        if (on_st_path[v]) dfs1(dfs1,v,-1);
    }
    // Collect the dp pairs of the S-T path vertices, from S up to T
    // (parent[T] == -1 ends the walk).
    vector<pair<vec<mint>,vec<mint>>> st_path_polys;
    vector<vec<mint>> st_path_polys_non;
    {
        int pos = S;
        while (pos!=-1){
            st_path_polys.push_back(dp[pos]);
            st_path_polys_non.push_back(dp[pos].first);
            pos = parent[pos];
        }
    }
    //debug(st_path_polys);
    //debug(st_path_polys_non);
    // Denominator: the path vertices merged as one chain (`second` = both ends closed).
    vector<mint> q_det = merge_path(st_path_polys).second;
    // Numerator: product of every path vertex's `first` polynomial.
    vector<mint> p_det = calc_prod(st_path_polys_non);
    //debug(p_det);
    //debug(q_det);
    // D = number of edges on the S-T path.
    int D = int(st_path_polys.size()) - 1;
    if (M < D){
        cout << 0 << "\n";
        return ;
    }
    auto ans = BostanMori2(p_det,q_det,M-D);
    cout << ans << "\n";
}

// Single test case with unsynchronized iostreams.
int main(){
    ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int T = 1;
    while (T--){
        solve();
    }
}