結果

問題 No.2587 Random Walk on Tree
ユーザー chineristAC
提出日時 2023-12-16 00:33:26
言語 C++17 (gcc 12.3.0 + boost 1.83.0)
結果 AC
実行時間 7,242 ms / 10,000 ms
コード長 9,039 bytes
コンパイル時間 6,272 ms
コンパイル使用メモリ 244,936 KB
実行使用メモリ 69,924 KB
最終ジャッジ日時 2023-12-23 23:50:01
合計ジャッジ時間 108,837 ms
ジャッジサーバーID (参考情報) judge13 / judge12
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 6,676 KB
testcase_01 AC 2 ms 6,676 KB
testcase_02 AC 2 ms 6,676 KB
testcase_03 AC 2 ms 6,676 KB
testcase_04 AC 3 ms 6,676 KB
testcase_05 AC 4 ms 6,676 KB
testcase_06 AC 4 ms 6,676 KB
testcase_07 AC 2 ms 6,676 KB
testcase_08 AC 5 ms 6,676 KB
testcase_09 AC 3 ms 6,676 KB
testcase_10 AC 3 ms 6,676 KB
testcase_11 AC 11 ms 6,676 KB
testcase_12 AC 51 ms 6,676 KB
testcase_13 AC 58 ms 6,676 KB
testcase_14 AC 20 ms 6,676 KB
testcase_15 AC 4,353 ms 40,304 KB
testcase_16 AC 2,244 ms 27,644 KB
testcase_17 AC 2,636 ms 28,948 KB
testcase_18 AC 302 ms 7,680 KB
testcase_19 AC 4,894 ms 53,648 KB
testcase_20 AC 4,220 ms 45,688 KB
testcase_21 AC 5,272 ms 49,052 KB
testcase_22 AC 4,439 ms 69,924 KB
testcase_23 AC 4,903 ms 52,056 KB
testcase_24 AC 4,422 ms 43,296 KB
testcase_25 AC 3,182 ms 57,812 KB
testcase_26 AC 7,242 ms 52,580 KB
testcase_27 AC 6,026 ms 47,144 KB
testcase_28 AC 6,356 ms 48,888 KB
testcase_29 AC 6,294 ms 50,272 KB
testcase_30 AC 5,170 ms 52,760 KB
testcase_31 AC 5,091 ms 54,324 KB
testcase_32 AC 6,205 ms 53,816 KB
testcase_33 AC 2 ms 6,676 KB
testcase_34 AC 1,836 ms 45,616 KB
testcase_35 AC 1,838 ms 45,860 KB
testcase_36 AC 1,840 ms 46,168 KB
testcase_37 AC 2,940 ms 45,508 KB
testcase_38 AC 4,347 ms 46,132 KB
testcase_39 AC 4,388 ms 46,232 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

// -fsanitize=undefined,
// #define _GLIBCXX_DEBUG


#pragma GCC target("avx2")
#pragma GCC optimize("unroll-loops")

#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <set>
#include <queue>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <random>
#include <stdio.h>
#include <fstream>
#include <functional>
#include <cassert>
#include <unordered_map>
#include <bitset>
#include <chrono>
#include <atcoder/modint>
#include <atcoder/convolution>


using namespace std;
using namespace atcoder;


// Loop helpers: rep(i,n) runs i = 0 .. n-1, rrep(i,n) runs i = n-1 .. 0.
// The bound argument is parenthesized so that expressions such as
// rep(i, a & b) or rrep(i, x << 1) expand with the intended precedence.
#define rep(i,n) for (int i=0;i<(n);i+=1)
#define rrep(i,n) for (int i=(n)-1;i>-1;i--)
#define pb push_back
#define all(x) (x).begin(), (x).end()

// Prints "<expr> = <value> (L<line> )" to stderr.
// NOTE: the expansion already ends with ';'.
#define debug(x) cerr << #x << " = " << (x) << " (L" << __LINE__ << " )\n";

// Short aliases for frequently used containers and integer pairs.
template<typename T> using vec = std::vector<T>;
template<typename T> using vvec = vec<vec<T>>;
template<typename T> using vvvec = vec<vvec<T>>;
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;


/// Lowers `a` to `b` when `a > b`.
/// @return true iff `a` was overwritten.
template<class T>
bool chmin(T &a, T b){
  const bool improves = a > b;
  if (improves) a = b;
  return improves;
}

/// Raises `a` to `b` when `a < b`.
/// @return true iff `a` was overwritten.
template<class T>
bool chmax(T &a, T b){
  const bool improves = a < b;
  if (improves) a = b;
  return improves;
}

/// Returns the sum of all elements of `x`, starting from T(0).
/// Takes the vector by const reference (the original copied the whole vector
/// and then every element).
template<class T>
T sum(const vector<T>& x){
  T res = 0;
  for (const auto& e : x){
    res += e;
  }
  return res;
}

/// Prints every element of `x` to stdout followed by a space, then a newline
/// (flushing, as before). Takes `x` by const reference to avoid copying the
/// vector and each element.
template<class T>
void printv(const vector<T>& x){
  for (const auto& e : x){
    cout << e << " ";
  }
  cout << endl;
}



/// Writes a pair as "(first, second)".
template<class T,class U>
ostream& operator<<(ostream& os, const pair<T,U>& A){
  return os << "(" << A.first << ", " << A.second << ")";
}

/// Writes a set as "set{a, b, c}" in iteration (sorted) order.
/// Separators are decided by iterator position. The original copied each
/// element and re-looked it up with S.find() just to detect the last one.
template<class T>
ostream& operator<<(ostream& os, const set<T>& S){
  os << "set{";
  for (auto it = S.begin(); it != S.end(); ++it){
    if (it != S.begin()){
      os << ", ";
    }
    os << *it;
  }
  os << "}";
  return os;
}

// Modular integer type with modulus 998244353 (AtCoder Library modint).
using mint = modint998244353;

// Prints a mint through its integer representative val().
ostream& operator<<(ostream& os, const mint& a){
  return os << a.val();
}

/// Writes a vector as "[a, b, c]" (nested vectors are printed recursively).
/// Uses a size_t index, which avoids the original signed/unsigned comparison
/// and the per-iteration `A.size()-1` computation.
template<class T>
ostream& operator<<(ostream& os, const vector<T>& A){
  os << "[";
  for (size_t i = 0; i < A.size(); ++i){
    if (i != 0){
      os << ", ";
    }
    os << A[i];
  }
  os << "]";
  return os;
}

/// Adds `b` into `a` element-wise. If `b` is longer, `a` is first grown with
/// value-initialized entries. Returns `a`.
template <typename T> vector<T>& operator+=(vector<T>& a, const vector<T>& b) {
    const size_t len = b.size();
    if (len > a.size()) a.resize(len);
    for (size_t k = 0; k < len; ++k) a[k] += b[k];
    return a;
}

/// Returns the element-wise sum of `a` and `b`; the result has the length of
/// the longer operand.
template <typename T> vector<T> operator+(const vector<T>& a, const vector<T>& b) {
    vector<T> c = a;
    // Accumulate first, then return the named local so NRVO / implicit move
    // applies. `return c += b;` returns a reference and forces a full copy.
    c += b;
    return c;
}


/// Returns the product of all polynomials in `polys`, or {1} when empty.
/// Repeatedly multiplies the two front entries of a queue and appends the
/// product, until one polynomial remains. Polynomials are moved rather than
/// copied at every step.
vector<mint> calc_prod(vector<vector<mint>> polys){
  if (polys.empty()) return {1};
  deque<vec<mint>> deq;
  for (auto& f : polys) deq.push_back(std::move(f));
  while (deq.size() > 1){
    auto f = std::move(deq.front()); deq.pop_front();
    auto g = std::move(deq.front()); deq.pop_front();
    // f and g are dead after this call, so hand them over as rvalues.
    deq.push_back(convolution(std::move(f), std::move(g)));
  }
  return std::move(deq.front());
}

/// Combines the (first, second) polynomial pairs of a vertex's children into
/// the pair for the vertex itself.
///
/// The children are folded pairwise through a queue while maintaining
///   A = sum over children c of c.first * prod_{d != c} d.second
///   B = prod over children c of c.second
/// and the vertex's pair is then
///   first  = B
///   second = A * (-x^2) + B * (1 - x).
/// With no children this is ({1}, {1, -1}).
/// Entries are moved through the queue instead of copied.
pair<vector<mint>,vector<mint>> merge_child(vector<pair<vec<mint>,vec<mint>>> child_polys){
  if (child_polys.empty()){
    return {{1},{1,-1}};
  }

  deque<pair<vec<mint>,vec<mint>>> deq;
  for (auto& e : child_polys) deq.push_back(std::move(e));

  // Merge the two front entries and append the result until one remains.
  while (deq.size() > 1){
    auto [f0,f1] = std::move(deq.front()); deq.pop_front();
    auto [g0,g1] = std::move(deq.front()); deq.pop_front();

    auto h1 = convolution(f1,g1);
    // Exactly one of the two groups contributes its "first" polynomial.
    auto h0 = convolution(g0,f1) + convolution(f0,g1);
    deq.push_back({std::move(h0),std::move(h1)});
  }

  auto [A,B] = std::move(deq.front());

  // res1 = A * (-x^2) + B * (1 - x)
  vector<mint> res1(max(A.size()+2,B.size()+1),0);
  for (size_t i = 0; i < A.size(); i++) res1[i+2] -= A[i];
  for (size_t i = 0; i < B.size(); i++) {
    res1[i] += B[i];
    res1[i+1] -= B[i];
  }

  return {std::move(B), std::move(res1)};
}

pair<vector<mint>,vector<mint>> merge_path(vector<pair<vector<mint>,vector<mint>>> polys){
  int n = polys.size();
  if (n & 1){
    polys.push_back({{0},{1}});
  }
  n = polys.size();

  //debug(polys);

  vector<array<vector<mint>,4>> nxt_polys;
  for (int i=0;i<n;i+=2){
    /*0:空いてる 1:空いてない
    0+0->00 or 11(-x^2倍)
    0+1->01
    1+0->10
    1+1->11
    */
    auto [f0,f1] = polys[i];
    auto [g0,g1] = polys[i+1];

    array<vector<mint>,4> merge;
    merge[1] = convolution(f0,g1);
    merge[2] = convolution(f1,g0);
    merge[3] = convolution(f1,g1);

    auto h = convolution(f0,g0);
    merge[0] = h;
    
    int p = h.size();
    if (merge[3].size() < p+2){
      merge[3].resize(p+2);
    }
    rep(i,p) merge[3][i+2] -= h[i];
    
    nxt_polys.push_back(merge);
  }



  

  auto calc = [&](auto self,int l,int r)->array<vector<mint>,4> {
    if (r-l==1){
      return nxt_polys[l];
    }

    int mid = (l+r)>>1;
    auto left = self(self,l,mid);
    auto right = self(self,mid,r);

    int merge_n = 0;
    rep(i,4){
      rep(j,4){
        int nxt_ocu = (i & 2) + (j & 1);
        int mid_ocu = (i & 1) + (j & 2);

        int h_deg = left[i].size() + right[j].size() - 2;
        //debug(i);
        //debug(j);
        //debug(h_deg);
        if (mid_ocu == 0){
          chmax(merge_n,h_deg+2);
        }
        else if (mid_ocu == 3){
          chmax(merge_n,h_deg);
        }
      }
    }
    merge_n += 1;
    //debug(merge_n);

    array<vector<mint>,4> merge;
    rep(i,4) merge[i] = vector<mint>(merge_n,0);


    rep(i,4){
      rep(j,4){
        int nxt_ocu = (i & 2) + (j & 1);
        int mid_ocu = (i & 1) + (j & 2);
        if (mid_ocu!=0 && mid_ocu!=3) continue;

        auto h = convolution(left[i],right[j]);
        //assert ((h.size()-1) == left[i].size() + right[j].size() - 2);
        if (mid_ocu == 0){
          //assert (h.size()+1 < merge_n);
          rep(k,h.size()){
            merge[nxt_ocu][k+2] -= h[k];
          }
        }
        else if (mid_ocu == 3){
          //debug(i);
          //debug(j);
          //debug(h.size());
          //assert (h.size()-1 < merge_n);
          rep(k,h.size()){
            merge[nxt_ocu][k] += h[k];
          }
        }
      }
    }

    return merge;
  };

  auto merged = calc(calc,0,nxt_polys.size());

  vector<mint> res0 = merged[1], res1 = merged[3];
  return {res0,res1};
}

/// Returns the coefficient [x^N] of the formal power series P(x) / Q(x)
/// using the Bostan–Mori method. Each round multiplies numerator and
/// denominator by Q(-x), which makes the denominator even in x, and then keeps
/// only the coefficients of the needed parity.
///
/// @param P numerator coefficients, lowest degree first.
/// @param Q denominator coefficients, lowest degree first; Q[0] must be non-zero.
/// @param N target exponent; a negative N yields 0.
mint BostanMori2(vec<mint> P,vec<mint> Q, long long N){
  assert(!Q.empty() && Q[0].val() != 0);
  // Empty P is the zero series. Also avoids indexing P[0] on an empty vector,
  // and stops a negative N from looping forever.
  if (N < 0 || P.empty()) return 0;

  for (; N; N >>= 1){
    // Q_neg(x) = Q(-x): negate the odd-degree coefficients.
    vec<mint> Q_neg = Q;
    for (size_t i = 1; i < Q_neg.size(); i += 2) Q_neg[i] = -Q_neg[i];
    P = convolution(P,Q_neg);
    Q = convolution(Q,Q_neg);

    const int deg_P = int(P.size()) - 1, deg_Q = int(Q.size()) - 1;

    // P keeps the coefficients whose parity matches N; Q (now even) keeps the even ones.
    vec<mint> nxt_P(deg_P/2+1, 0), nxt_Q(deg_Q/2+1, 0);
    for (int i = int(N & 1); i <= deg_P; i += 2) nxt_P[i>>1] = P[i];
    for (int i = 0; i <= deg_Q; i += 2) nxt_Q[i>>1] = Q[i];

    swap(P,nxt_P);
    swap(Q,nxt_Q);
  }
  // At N == 0 the answer is P(0)/Q(0). Q[0] is the original constant term raised
  // to 2^rounds, so dividing (instead of assuming it is 1) keeps the result
  // correct for any invertible constant term.
  return P[0] / Q[0];
}

// Reads one instance, builds polynomial pairs bottom-up along heavy chains of
// the tree rooted at T, and prints BostanMori2(p_det, q_det, M - D), where D is
// the number of edges on the S-T path. Prints 0 when M < D.
//
// Input: "N M S T", then N-1 edges "a b" (1-indexed vertices).
//
// NOTE(review): dfs0 and dfs1 are recursive lambdas whose recursion depth
// follows the depth of the rooted tree, so a path-shaped tree recurses about
// N levels deep. Confirm the stack limit is enough for the maximum N.
void solve(){
  int N,M,S,T;
  cin>>N>>M>>S>>T;
  // Switch to 0-indexed vertices.
  S--; T--;
  vec<vec<int>> edge(N);
  rep(i,N-1){
    int a,b;
    cin>>a>>b;
    a--; b--;
    edge[a].push_back(b);
    edge[b].push_back(a);
  }

  vec<int> parent(N,-1), sz(N,1);
  vec<int> heavy_child(N,-1);
  int root = T;

  // Roots the tree at T and fills parent[], subtree sizes sz[] and
  // heavy_child[]: the child with the strictly largest subtree (the first such
  // child in adjacency order wins ties). Leaves keep heavy_child == -1.
  auto dfs0 = [&](auto self,int v,int pv)->void {
    int tmp_max_sz = 0;
    for (auto nv:edge[v]){
      if (nv == pv) continue;
      parent[nv] = v;
      self(self,nv,v);
      sz[v] += sz[nv];
      if (sz[nv] > tmp_max_sz){
        tmp_max_sz = sz[nv];
        heavy_child[v] = nv;
      }
    }
  };

  dfs0(dfs0,root,-1);

  // Mark every vertex on the S-T path by walking parent pointers from S up to T.
  vector<int> on_st_path(N,0);
  on_st_path[T] = 1;
  {
    int pos = S;
    while (pos!=T){
      on_st_path[pos] = 1;
      pos = parent[pos];
    }
  }
  
  
  
  // to_parent_is_light[v] == 1 means v starts its own chain: v is not its
  // parent's heavy child, or v lies on the S-T path (forced, so every path
  // vertex heads a separate chain).
  vector<int> to_parent_is_light(N,1);
  rep(v,N){
    if (heavy_child[v]!=-1){
      to_parent_is_light[heavy_child[v]] = 0;
    }
  }
  rep(v,N){
    if (on_st_path[v]) to_parent_is_light[v] = 1;
  }

  // dp[v]: (first, second) polynomial pair for v. After dfs1, a chain head
  // holds the merge_path result of its whole chain.
  vec<pair<vec<mint>,vec<mint>>> dp(N);

  // Processes the part of v's subtree that does not contain S-T path vertices.
  // Light children are folded in with merge_child, and their dp is cleared
  // afterwards to release memory. The heavy child's dp stays in place for the
  // chain walk. If v heads a chain, the pairs down its heavy chain (stopping
  // before any path vertex) are combined with merge_path.
  // The pv parameter is unused; parent[] is consulted instead.
  auto dfs1 = [&](auto self,int v,int pv)->void {
    vector<pair<vec<mint>,vec<mint>>> child_polys;
    for (auto nv:edge[v]){
      if (nv == parent[v] || on_st_path[nv]) continue;
      self(self,nv,v);
      if (nv!=heavy_child[v]){
        child_polys.push_back(dp[nv]);
        dp[nv].first.clear();
        dp[nv].second.clear();
      }
    }
    dp[v] = merge_child(child_polys);

    if (to_parent_is_light[v]){
      vector<pair<vec<mint>,vec<mint>>> path_polys = {dp[v]};
      int pos = heavy_child[v];
      while (pos!=-1 && !on_st_path[pos]){
        path_polys.push_back(dp[pos]);
        pos = heavy_child[pos];
      }
      dp[v] = merge_path(path_polys);
    }
  };

  // Start dfs1 from every S-T path vertex. Each call covers the off-path
  // subtrees hanging from that vertex.
  rep(v,N){
    if (on_st_path[v]) dfs1(dfs1,v,-1);
  }

  // Collect the path vertices' pairs in order S, ..., T (parent[T] == -1).
  vector<pair<vec<mint>,vec<mint>>> st_path_polys;
  vector<vec<mint>> st_path_polys_non;
  {
    int pos = S;
    while (pos!=-1){
      st_path_polys.push_back(dp[pos]);
      st_path_polys_non.push_back(dp[pos].first);
      pos = parent[pos];
    }
  }

  //debug(st_path_polys);
  //debug(st_path_polys_non);
  

  // q_det: the whole S-T path merged as one chain, both-endpoints-closed part.
  // p_det: product of each path vertex's "first" polynomial.
  vector<mint> q_det = merge_path(st_path_polys).second;
  vector<mint> p_det = calc_prod(st_path_polys_non);

  //debug(p_det);
  //debug(q_det);

  // D = number of edges between S and T. With fewer than D steps the code
  // prints 0 (presumably because T cannot be reached; confirm against the statement).
  int D = int(st_path_polys.size()) - 1;
  if (M < D){
    cout << 0 << "\n";
    return ;
  }

  

  // Answer: coefficient of x^(M-D) in p_det / q_det.
  auto ans = BostanMori2(p_det,q_det,M-D);
  cout << ans << "\n";





}

// Program entry: unsynchronized, untied C++ streams, then a single call to solve().
int main(){
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  constexpr int kTestCases = 1;
  for (int tc = 0; tc < kTestCases; ++tc){
    solve();
  }
}
0