結果
| 問題 | No.2587 Random Walk on Tree | 
| コンテスト | |
| ユーザー |  | 
| 提出日時 | 2023-12-16 00:33:26 | 
| 言語 | C++17(gcc12) (gcc 12.3.0 + boost 1.87.0) | 
| 結果 | AC |
| 実行時間 | 7,263 ms / 10,000 ms | 
| コード長 | 9,039 bytes | 
| コンパイル時間 | 23,253 ms | 
| コンパイル使用メモリ | 291,048 KB | 
| 最終ジャッジ日時 | 2025-02-18 11:43:57 | 
| ジャッジサーバーID (参考情報) | judge4 / judge5 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 37 | 
ソースコード
// -fsanitize=undefined,
// #define _GLIBCXX_DEBUG
#pragma GCC target("avx2")
#pragma GCC optimize("unroll-loops")
#include <stdio.h>

#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <chrono>
#include <cmath>
#include <deque>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <queue>
#include <random>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <atcoder/modint>
#include <atcoder/convolution>
using namespace std;
using namespace atcoder;
// Loop shorthands. `rep` runs an int `i` upward over 0, 1, ..., n-1; `rrep` runs it
// downward over n-1, ..., 0. `n` is pasted into the expansion without parentheses.
#define rep(i,n) for (int i=0;i<n;i+=1)
#define rrep(i,n) for (int i=n-1;i>-1;i--)
// Abbreviation for the member name push_back.
#define pb push_back
// Expands to the two arguments `(x).begin(), (x).end()`.
#define all(x) (x).begin(), (x).end()
// Writes the expression text, its value and the current source line to cerr.
// The expansion already ends with a semicolon.
#define debug(x) cerr << #x << " = " << (x) << " (L" << __LINE__ << " )\n";
// vec / vvec / vvvec: one-, two- and three-level nested std::vector.
template<class T>
using vec = vector<T>;
template<class T>
using vvec = vec<vec<T>>;
template<class T>
using vvvec = vec<vvec<T>>;
// Integer and pair abbreviations.
using ll = long long;
using pii = pair<int,int>;
using pll = pair<ll,ll>;
// Replaces `a` with `b` when `b` is strictly smaller than `a`.
// Returns true exactly when `a` was overwritten.
template<class T>
bool chmin(T &a, T b){
  if (!(a>b)) return false;
  a = b;
  return true;
}
// Replaces `a` with `b` when `b` is strictly larger than `a`.
// Returns true exactly when `a` was overwritten.
template<class T>
bool chmax(T &a, T b){
  if (!(a<b)) return false;
  a = b;
  return true;
}
// Returns the total of all elements of `x`, accumulated front to back with `+=`
// onto a `T` initialized from 0; an empty vector therefore yields that zero value.
// The vector is taken by const reference and elements are visited by reference,
// so neither the container nor its elements are copied.
template<class T>
T sum(const std::vector<T>& x){
  T res=0;
  for (const auto& e:x){
    res += e;
  }
  return res;
}
// Writes every element of `x` to cout, each followed by a single space, then ends
// the line with endl (which also flushes). An empty vector prints just the newline.
// The vector is taken by const reference and elements are visited by reference,
// so printing does not copy the container or its elements.
template<class T>
void printv(const std::vector<T>& x){
  for (const auto& e:x){
    std::cout<<e<<" ";
  }
  std::cout<<std::endl;
}
// Streams a pair in the form "(first, second)" and returns the stream.
template<class T,class U>
std::ostream& operator<<(std::ostream& os, const std::pair<T,U>& A){
  const auto& [head, tail] = A;
  os << "(" << head << ", " << tail << ")";
  return os;
}
// Streams a set as "set{e1, e2, ...}" in the set's iteration order; an empty set
// prints "set{}". Returns the stream.
// Elements are walked once by iterator: the separator is emitted before every
// element except the first, so no per-element lookup or element copy is needed.
template<class T>
std::ostream& operator<<(std::ostream& os, const std::set<T>& S){
  os << "set{";
  for (auto it = S.begin(); it != S.end(); ++it){
    if (it != S.begin()){
      os << ", ";
    }
    os << *it;
  }
  os << "}";
  return os;
}
using mint = modint998244353;
// Streams a mint by writing the integer returned from its val() accessor.
ostream& operator<<(ostream& os, const mint& a){
  return os << a.val();
}
// Streams a vector as "[e0, e1, ...]"; an empty vector prints "[]". Returns the stream.
// The index is a std::size_t so it is compared against A.size() without a
// signed/unsigned mix, and the separator is written before every element except
// the first, which avoids the unsigned `A.size()-1` computation.
template<class T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& A){
  os << "[";
  for (std::size_t i = 0; i < A.size(); ++i){
    if (i != 0){
      os << ", ";
    }
    os << A[i];
  }
  os << "]" ;
  return os;
}
// Element-wise in-place addition of `b` into `a`.
// `a` is first grown (with value-initialized elements) to at least b.size();
// elements of `a` beyond b.size() are left unchanged. Returns `a`.
template <typename T> std::vector<T>& operator+=(std::vector<T>& a, const std::vector<T>& b) {
    const std::size_t needed = b.size();
    if (needed > a.size()) a.resize(needed);
    std::size_t slot = 0;
    for (const T& term : b) {
        a[slot] += term;
        ++slot;
    }
    return a;
}
// Element-wise sum of two vectors: copies `a`, adds `b` into the copy with the
// vector operator+= and returns the copy. Neither argument is modified.
template <typename T> std::vector<T> operator+(const std::vector<T>& a, const std::vector<T>& b) {
    std::vector<T> total(a);
    total += b;
    return total;
}
// Multiplies all coefficient vectors in `polys` together with convolution() and
// returns the product. An empty list yields the one-element vector {1}.
// The two front entries of a queue are repeatedly replaced by their product,
// which is appended at the back, until one entry remains.
// `polys` is already a by-value copy, so its elements are moved into the queue
// and moved back out instead of being deep-copied at every step.
vector<mint> calc_prod(vector<vector<mint>> polys){
  if (polys.empty()) return {1};
  deque<vec<mint>> deq;
  for (auto& f:polys) deq.push_back(std::move(f));
  while (deq.size() > 1){
    vec<mint> f = std::move(deq.front()); deq.pop_front();
    vec<mint> g = std::move(deq.front()); deq.pop_front();
    deq.push_back(convolution(std::move(f),std::move(g)));
  }
  return std::move(deq.front());
}
// Folds the (first, second) polynomial pairs of a vertex's children into the pair
// for the vertex itself.
//
// Two pairs (f0,f1) and (g0,g1) combine into (g0*f1 + f0*g1, f1*g1). After all
// children are combined this gives
//   A = sum over i of first_i * (product of second_j for j != i),
//   B = product of all second_j.
// The returned pair is (B, B*(1-x) - x^2*A). With no children the result is
// ({1}, {1,-1}), i.e. the same formula with A = 0 and B = 1.
//
// `child_polys` is already a by-value copy, so its entries are moved through the
// work queue rather than deep-copied on every step.
pair<vector<mint>,vector<mint>> merge_child(vector<pair<vec<mint>,vec<mint>>> child_polys){
  if (child_polys.empty()){
    return {{1},{1,-1}};
  }
  deque<pair<vec<mint>,vec<mint>>> deq;
  for (auto& e:child_polys) deq.push_back(std::move(e));

  // Repeatedly replace the two front entries by their combination at the back.
  while (deq.size() > 1){
    auto [f0,f1] = std::move(deq.front()); deq.pop_front();
    auto [g0,g1] = std::move(deq.front()); deq.pop_front();
    auto h1 = convolution(f1,g1);
    auto h0 = convolution(g0,f1) + convolution(f0,g1);
    deq.emplace_back(std::move(h0),std::move(h1));
  }
  auto [A,B] = std::move(deq.front());
  /* res1 = A * (-x^2) + B * (1-x) */
  vector<mint> res1(max(A.size()+2,B.size()+1),0);
  for (size_t i=0;i<A.size();i++) res1[i+2] -= A[i];
  for (size_t i=0;i<B.size();i++){
    res1[i] += B[i];
    res1[i+1] -= B[i];
  }
  // The first component of the result is B itself.
  return {std::move(B),std::move(res1)};
}
// Combines the (first, second) polynomial pairs of consecutive path segments into
// one pair for the whole path.
//
// Every segment is tracked in four variants indexed by two bits: bit 2 is the
// state of its left end and bit 1 the state of its right end, where state 0 uses
// a pair's `first` polynomial and state 1 uses its `second` (the original note
// labels them "0: free, 1: not free"). Joining two segments looks at the two ends
// that meet in the middle:
//   both 0 -> product multiplied by -x^2,
//   both 1 -> plain product,
//   mixed  -> contributes nothing.
// The outer ends become the ends of the joined segment.
// Returns (variant 01, variant 11) of the whole path.
//
// `polys` must be non-empty: with no segments the divide-and-conquer below has no
// base case, so this is asserted up front.
pair<vector<mint>,vector<mint>> merge_path(vector<pair<vector<mint>,vector<mint>>> polys){
  assert(!polys.empty());
  // Pad to an even count. Pairing (f0,f1) with ({0},{1}) gives variants
  // 01 = f0 and 11 = f1, and 00 = 10 = 0.
  if (polys.size() & 1){
    polys.push_back({{0},{1}});
  }
  const int n = polys.size();
  vector<array<vector<mint>,4>> nxt_polys;
  nxt_polys.reserve(n/2);
  for (int i=0;i<n;i+=2){
    /* 0: free, 1: not free
    0+0 -> 00, and also 11 after multiplying by -x^2
    0+1 -> 01
    1+0 -> 10
    1+1 -> 11
    */
    const auto& [f0,f1] = polys[i];
    const auto& [g0,g1] = polys[i+1];
    array<vector<mint>,4> merge;
    merge[0] = convolution(f0,g0);
    merge[1] = convolution(f0,g1);
    merge[2] = convolution(f1,g0);
    merge[3] = convolution(f1,g1);

    // merge[3] -= x^2 * merge[0]
    const vector<mint>& h = merge[0];
    if (merge[3].size() < h.size()+2){
      merge[3].resize(h.size()+2);
    }
    for (size_t k=0;k<h.size();k++) merge[3][k+2] -= h[k];

    nxt_polys.push_back(std::move(merge));
  }

  // Joins the blocks nxt_polys[l..r) by recursive halving. Each base-case block is
  // requested exactly once, so it is moved out of nxt_polys rather than copied.
  auto calc = [&](auto self,int l,int r)->array<vector<mint>,4> {
    if (r-l==1){
      return std::move(nxt_polys[l]);
    }
    int mid = (l+r)>>1;
    auto left = self(self,l,mid);
    auto right = self(self,mid,r);
    // Sizing pass: the largest degree any contributing product can reach,
    // including the x^2 shift applied when both middle ends are 0.
    int merge_n = 0;
    rep(i,4){
      rep(j,4){
        int mid_ocu = (i & 1) + (j & 2);
        int h_deg = left[i].size() + right[j].size() - 2;
        if (mid_ocu == 0){
          chmax(merge_n,h_deg+2);
        }
        else if (mid_ocu == 3){
          chmax(merge_n,h_deg);
        }
      }
    }
    merge_n += 1;
    array<vector<mint>,4> merge;
    rep(i,4) merge[i] = vector<mint>(merge_n,0);
    rep(i,4){
      rep(j,4){
        // nxt_ocu: the outer ends (left end of `left`, right end of `right`).
        // mid_ocu: the two ends that meet (right end of `left`, left end of `right`).
        int nxt_ocu = (i & 2) + (j & 1);
        int mid_ocu = (i & 1) + (j & 2);
        if (mid_ocu!=0 && mid_ocu!=3) continue;
        auto h = convolution(left[i],right[j]);
        if (mid_ocu == 0){
          rep(k,h.size()){
            merge[nxt_ocu][k+2] -= h[k];
          }
        }
        else if (mid_ocu == 3){
          rep(k,h.size()){
            merge[nxt_ocu][k] += h[k];
          }
        }
      }
    }
    return merge;
  };
  auto merged = calc(calc,0,nxt_polys.size());
  return {std::move(merged[1]),std::move(merged[3])};
}
// Halving iteration on the pair of coefficient vectors (P, Q).
// While N is non-zero, each round:
//   1. builds Q(-x) by negating the odd-index coefficients of Q,
//   2. multiplies both P and Q by Q(-x),
//   3. keeps from P only the coefficients whose index has the parity of N, and
//      from Q only the even-index coefficients, halving both index ranges,
//   4. shifts N right by one bit.
// Returns P[0] once N has reached 0; no division by Q[0] is performed.
// NOTE(review): presumably this is the N-th coefficient of P(x)/Q(x) under the
// assumption Q[0] == 1, and N is expected to be non-negative — confirm with callers.
mint BostanMori2(vec<mint> P,vec<mint> Q, long long N){
  while (N != 0){
    vec<mint> Q_neg(Q.begin(), Q.end());
    for (size_t odd = 1; odd < Q_neg.size(); odd += 2) Q_neg[odd] *= -1;
    P = convolution(P,Q_neg);
    Q = convolution(Q,Q_neg);
    const int deg_P = P.size() - 1;
    const int deg_Q = Q.size() - 1;
    vec<mint> half_P(deg_P/2+1,0);
    vec<mint> half_Q(deg_Q/2+1,0);
    const int parity = (N & 1);
    for (int src = parity; src <= deg_P; src += 2){
      half_P[src/2] = P[src];
    }
    for (int src = 0; src <= deg_Q; src += 2){
      half_Q[src/2] = Q[src];
    }
    P.swap(half_P);
    Q.swap(half_Q);
    N >>= 1;
  }
  return P[0];
}
// Reads one instance from cin and prints one answer line to cout.
// Input: N M S T, then N-1 lines "a b" describing undirected edges. Vertex ids
// are 1-based in the input and shifted to 0-based here.
// Outline: root the tree at T, build a polynomial pair per vertex bottom-up
// (merge_child for light children, merge_path along heavy chains), combine the
// pairs along the S..T path into a numerator p_det and denominator q_det, and
// print BostanMori2(p_det, q_det, M - D) where D is the number of edges on that path.
// NOTE(review): presumably the answer counts or weights length-M walks from S to T —
// confirm against the problem statement.
void solve(){
  int N,M,S,T;
  cin>>N>>M>>S>>T;
  S--; T--;
  vec<vec<int>> edge(N);
  rep(i,N-1){
    int a,b;
    cin>>a>>b;
    a--; b--;
    edge[a].push_back(b);
    edge[b].push_back(a);
  }
  // parent[v]: parent in the tree rooted at T (-1 for the root).
  // sz[v]: subtree size. heavy_child[v]: child with the largest subtree (-1 for a leaf).
  vec<int> parent(N,-1), sz(N,1);
  vec<int> heavy_child(N,-1);
  int root = T;
  // Recursive DFS from the root filling parent, sz and heavy_child.
  auto dfs0 = [&](auto self,int v,int pv)->void {
    int tmp_max_sz = 0;
    for (auto nv:edge[v]){
      if (nv == pv) continue;
      parent[nv] = v;
      self(self,nv,v);
      sz[v] += sz[nv];
      // Strict '>' keeps the first child encountered among equally large subtrees.
      if (sz[nv] > tmp_max_sz){
        tmp_max_sz = sz[nv];
        heavy_child[v] = nv;
      }
    }
  };
  dfs0(dfs0,root,-1);
  // Mark every vertex on the S..T path by climbing parent links from S up to T.
  vector<int> on_st_path(N,0);
  on_st_path[T] = 1;
  {
    int pos = S;
    while (pos!=T){
      on_st_path[pos] = 1;
      pos = parent[pos];
    }
  }
  
  
  // to_parent_is_light[v] is 0 only when v is some vertex's heavy child and v is
  // not on the S..T path; path vertices are always forced to 1.
  vector<int> to_parent_is_light(N,1);
  rep(v,N){
    if (heavy_child[v]!=-1){
      to_parent_is_light[heavy_child[v]] = 0;
    }
  }
  rep(v,N){
    if (on_st_path[v]) to_parent_is_light[v] = 1;
  }
  // dp[v]: polynomial pair for v, filled by dfs1.
  vec<pair<vec<mint>,vec<mint>>> dp(N);
  // Processes v after all of its children that are not on the S..T path.
  // The pv argument is not read; parent[] is consulted instead.
  auto dfs1 = [&](auto self,int v,int pv)->void {
    vector<pair<vec<mint>,vec<mint>>> child_polys;
    for (auto nv:edge[v]){
      if (nv == parent[v] || on_st_path[nv]) continue;
      self(self,nv,v);
      // Only non-heavy children are folded in here; their dp storage is cleared
      // once it has been copied into child_polys.
      if (nv!=heavy_child[v]){
        child_polys.push_back(dp[nv]);
        dp[nv].first.clear();
        dp[nv].second.clear();
      }
    }
    dp[v] = merge_child(child_polys);
    // When v's edge to its parent is light, gather dp along the heavy-child chain
    // starting at v (stopping at a leaf or at an S..T path vertex) and merge it.
    if (to_parent_is_light[v]){
      vector<pair<vec<mint>,vec<mint>>> path_polys = {dp[v]};
      int pos = heavy_child[v];
      while (pos!=-1 && !on_st_path[pos]){
        path_polys.push_back(dp[pos]);
        pos = heavy_child[pos];
      }
      dp[v] = merge_path(path_polys);
    }
  };
  // Each S..T path vertex is processed as the top of its own off-path subtree.
  rep(v,N){
    if (on_st_path[v]) dfs1(dfs1,v,-1);
  }
  // Collect dp along the path from S up to the root (which is T):
  // full pairs in st_path_polys, only the first components in st_path_polys_non.
  vector<pair<vec<mint>,vec<mint>>> st_path_polys;
  vector<vec<mint>> st_path_polys_non;
  {
    int pos = S;
    while (pos!=-1){
      st_path_polys.push_back(dp[pos]);
      st_path_polys_non.push_back(dp[pos].first);
      pos = parent[pos];
    }
  }
  //debug(st_path_polys);
  //debug(st_path_polys_non);
  
  // Denominator: second component of the merged path. Numerator: product of the
  // first components.
  vector<mint> q_det = merge_path(st_path_polys).second;
  vector<mint> p_det = calc_prod(st_path_polys_non);
  //debug(p_det);
  //debug(q_det);
  // D = number of edges on the S..T path; for M < D the answer printed is 0.
  int D = int(st_path_polys.size()) - 1;
  if (M < D){
    cout << 0 << "\n";
    return ;
  }
  
  auto ans = BostanMori2(p_det,q_det,M-D);
  cout << ans << "\n";
}
// Entry point: unties the C++ streams from C stdio and from each other, then
// runs solve() for a fixed count of one test case.
int main(){
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  for (int remaining = 1; remaining > 0; --remaining){
    solve();
  }
}
            
            
            
        