結果
問題 | No.2587 Random Walk on Tree |
ユーザー | tko919 |
提出日時 | 2023-12-25 04:47:05 |
言語 | C++17 (gcc 12.3.0 + boost 1.83.0) |
結果 |
WA
|
実行時間 | - |
コード長 | 34,477 bytes |
コンパイル時間 | 4,568 ms |
コンパイル使用メモリ | 268,168 KB |
実行使用メモリ | 17,992 KB |
最終ジャッジ日時 | 2024-09-27 14:08:53 |
合計ジャッジ時間 | 17,467 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 2 ms
6,816 KB |
testcase_01 | AC | 2 ms
6,944 KB |
testcase_02 | AC | 2 ms
6,944 KB |
testcase_03 | WA | - |
testcase_04 | WA | - |
testcase_05 | WA | - |
testcase_06 | WA | - |
testcase_07 | WA | - |
testcase_08 | WA | - |
testcase_09 | WA | - |
testcase_10 | WA | - |
testcase_11 | WA | - |
testcase_12 | WA | - |
testcase_13 | WA | - |
testcase_14 | WA | - |
testcase_15 | TLE | - |
testcase_16 | -- | - |
testcase_17 | -- | - |
testcase_18 | -- | - |
testcase_19 | -- | - |
testcase_20 | -- | - |
testcase_21 | -- | - |
testcase_22 | -- | - |
testcase_23 | -- | - |
testcase_24 | -- | - |
testcase_25 | -- | - |
testcase_26 | -- | - |
testcase_27 | -- | - |
testcase_28 | -- | - |
testcase_29 | -- | - |
testcase_30 | -- | - |
testcase_31 | -- | - |
testcase_32 | -- | - |
testcase_33 | -- | - |
testcase_34 | -- | - |
testcase_35 | -- | - |
testcase_36 | -- | - |
testcase_37 | -- | - |
testcase_38 | -- | - |
testcase_39 | -- | - |
ソースコード
#line 1 "library/Template/template.hpp" #include <bits/stdc++.h> using namespace std; #define rep(i,a,b) for(int i=(int)(a);i<(int)(b);i++) #define ALL(v) (v).begin(),(v).end() #define UNIQUE(v) sort(ALL(v)),(v).erase(unique(ALL(v)),(v).end()) #define SZ(v) (int)v.size() #define MIN(v) *min_element(ALL(v)) #define MAX(v) *max_element(ALL(v)) #define LB(v,x) int(lower_bound(ALL(v),(x))-(v).begin()) #define UB(v,x) int(upper_bound(ALL(v),(x))-(v).begin()) using ll=long long int; using ull=unsigned long long; const int inf = 0x3fffffff; const ll INF = 0x1fffffffffffffff; template<typename T>inline bool chmax(T& a,T b){if(a<b){a=b;return 1;}return 0;} template<typename T>inline bool chmin(T& a,T b){if(a>b){a=b;return 1;}return 0;} template<typename T,typename U>T ceil(T x,U y){assert(y!=0); if(y<0)x=-x,y=-y; return (x>0?(x+y-1)/y:x/y);} template<typename T,typename U>T floor(T x,U y){assert(y!=0); if(y<0)x=-x,y=-y; return (x>0?x/y:(x-y+1)/y);} template<typename T>int popcnt(T x){return __builtin_popcountll(x);} template<typename T>int topbit(T x){return (x==0?-1:63-__builtin_clzll(x));} template<typename T>int lowbit(T x){return (x==0?-1:__builtin_ctzll(x));} #line 2 "library/Utility/fastio.hpp" #include <unistd.h> class FastIO { static constexpr int L = 1 << 16; char rdbuf[L]; int rdLeft = 0, rdRight = 0; inline void reload() { int len = rdRight - rdLeft; memmove(rdbuf, rdbuf + rdLeft, len); rdLeft = 0, rdRight = len; rdRight += fread(rdbuf + len, 1, L - len, stdin); } inline bool skip() { for (;;) { while (rdLeft != rdRight and rdbuf[rdLeft] <= ' ') rdLeft++; if (rdLeft == rdRight) { reload(); if (rdLeft == rdRight) return false; } else break; } return true; } template <typename T, enable_if_t<is_integral<T>::value, int> = 0> inline bool _read(T &x) { if (!skip()) return false; if (rdLeft + 20 >= rdRight) reload(); bool neg = false; if (rdbuf[rdLeft] == '-') { neg = true; rdLeft++; } x = 0; while (rdbuf[rdLeft] >= '0' and rdLeft < rdRight) { x = x * 10 + (neg ? -(rdbuf[rdLeft++] ^ 48) : (rdbuf[rdLeft++] ^ 48)); } return true; } inline bool _read(__int128_t &x) { if (!skip()) return false; if (rdLeft + 40 >= rdRight) reload(); bool neg = false; if (rdbuf[rdLeft] == '-') { neg = true; rdLeft++; } x = 0; while (rdbuf[rdLeft] >= '0' and rdLeft < rdRight) { x = x * 10 + (neg ? -(rdbuf[rdLeft++] ^ 48) : (rdbuf[rdLeft++] ^ 48)); } return true; } inline bool _read(__uint128_t &x) { if (!skip()) return false; if (rdLeft + 40 >= rdRight) reload(); x = 0; while (rdbuf[rdLeft] >= '0' and rdLeft < rdRight) { x = x * 10 + (rdbuf[rdLeft++] ^ 48); } return true; } template <typename T, enable_if_t<is_floating_point<T>::value, int> = 0> inline bool _read(T &x) { if (!skip()) return false; if (rdLeft + 20 >= rdRight) reload(); bool neg = false; if (rdbuf[rdLeft] == '-') { neg = true; rdLeft++; } x = 0; while (rdbuf[rdLeft] >= '0' and rdbuf[rdLeft] <= '9' and rdLeft < rdRight) { x = x * 10 + (rdbuf[rdLeft++] ^ 48); } if (rdbuf[rdLeft] != '.') return true; rdLeft++; T base = .1; while (rdbuf[rdLeft] >= '0' and rdbuf[rdLeft] <= '9' and rdLeft < rdRight) { x += base * (rdbuf[rdLeft++] ^ 48); base *= .1; } if (neg) x = -x; return true; } inline bool _read(char &x) { if (!skip()) return false; if (rdLeft + 1 >= rdRight) reload(); x = rdbuf[rdLeft++]; return true; } inline bool _read(string &x) { if (!skip()) return false; for (;;) { int pos = rdLeft; while (pos < rdRight and rdbuf[pos] > ' ') pos++; x.append(rdbuf + rdLeft, pos - rdLeft); if (rdLeft == pos) break; rdLeft = pos; if (rdLeft == rdRight) reload(); else break; } return true; } template <typename T> inline bool _read(vector<T> &v) { for (auto &x : v) { if (!_read(x)) return false; } return true; } char wtbuf[L], tmp[50]; int wtRight = 0; inline void _write(const char &x) { if (wtRight > L - 32) flush(); wtbuf[wtRight++] = x; } inline void _write(const string &x) { for (auto &c : x) _write(c); } template <typename T, enable_if_t<is_integral<T>::value, int> = 0> inline void _write(T x) { if (wtRight > L - 32) flush(); if (x == 0) { _write('0'); return; } else if (x < 0) { _write('-'); if (__builtin_expect(x == std::numeric_limits<T>::min(), 0)) { switch (sizeof(x)) { case 2: _write("32768"); return; case 4: _write("2147483648"); return; case 8: _write("9223372036854775808"); return; } } x = -x; } int pos = 0; while (x != 0) { tmp[pos++] = char((x % 10) | 48); x /= 10; } rep(i, 0, pos) wtbuf[wtRight + i] = tmp[pos - 1 - i]; wtRight += pos; } inline void _write(__int128_t x) { if (wtRight > L - 40) flush(); if (x == 0) { _write('0'); return; } else if (x < 0) { _write('-'); x = -x; } int pos = 0; while (x != 0) { tmp[pos++] = char((x % 10) | 48); x /= 10; } rep(i, 0, pos) wtbuf[wtRight + i] = tmp[pos - 1 - i]; wtRight += pos; } inline void _write(__uint128_t x) { if (wtRight > L - 40) flush(); if (x == 0) { _write('0'); return; } int pos = 0; while (x != 0) { tmp[pos++] = char((x % 10) | 48); x /= 10; } rep(i, 0, pos) wtbuf[wtRight + i] = tmp[pos - 1 - i]; wtRight += pos; } inline void _write(double x) { ostringstream oss; oss << fixed << setprecision(15) << double(x); string s = oss.str(); _write(s); } template <typename T> inline void _write(const vector<T> &v) { rep(i, 0, v.size()) { if (i) _write(' '); _write(v[i]); } } public: FastIO() {} ~FastIO() { flush(); } inline void read() {} template <typename Head, typename... Tail> inline void read(Head &head, Tail &...tail) { assert(_read(head)); read(tail...); } template <bool ln = true, bool space = false> inline void write() { if (ln) _write('\n'); } template <bool ln = true, bool space = true, typename Head, typename... Tail> inline void write(const Head &head, const Tail &...tail) { _write(head); if (space) _write(' '); write<ln, true>(tail...); } inline void flush() { fwrite(wtbuf, 1, wtRight, stdout); wtRight = 0; } }; /** * @brief Fast IO */ #line 3 "sol.cpp" #line 2 "library/Convolution/ntt.hpp" template <typename T> struct NTT { static constexpr int rank2 = __builtin_ctzll(T::get_mod() - 1); std::array<T, rank2 + 1> root; // root[i]^(2^i) == 1 std::array<T, rank2 + 1> iroot; // root[i] * iroot[i] == 1 std::array<T, std::max(0, rank2 - 2 + 1)> rate2; std::array<T, std::max(0, rank2 - 2 + 1)> irate2; std::array<T, std::max(0, rank2 - 3 + 1)> rate3; std::array<T, std::max(0, rank2 - 3 + 1)> irate3; NTT() { T g = 2; while (g.pow((T::get_mod() - 1) >> 1) == 1) { g += 1; } root[rank2] = g.pow((T::get_mod() - 1) >> rank2); iroot[rank2] = root[rank2].inv(); for (int i = rank2 - 1; i >= 0; i--) { root[i] = root[i + 1] * root[i + 1]; iroot[i] = iroot[i + 1] * iroot[i + 1]; } { T prod = 1, iprod = 1; for (int i = 0; i <= rank2 - 2; i++) { rate2[i] = root[i + 2] * prod; irate2[i] = iroot[i + 2] * iprod; prod *= iroot[i + 2]; iprod *= root[i + 2]; } } { T prod = 1, iprod = 1; for (int i = 0; i <= rank2 - 3; i++) { rate3[i] = root[i + 3] * prod; irate3[i] = iroot[i + 3] * iprod; prod *= iroot[i + 3]; iprod *= root[i + 3]; } } } void ntt(std::vector<T> &a, bool type = 0) { int n = int(a.size()); int h = __builtin_ctzll((unsigned int)n); if (type) { int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed while (len) { if (len == 1) { int p = 1 << (h - len); T irot = 1; for (int s = 0; s < (1 << (len - 1)); s++) { int offset = s << (h - len + 1); for (int i = 0; i < p; i++) { auto l = a[i + offset]; auto r = a[i + offset + p]; a[i + offset] = l + r; a[i + offset + p] = (unsigned long long)(T::get_mod() + l.v - r.v) * irot.v; ; } if (s + 1 != (1 << (len - 1))) irot *= irate2[__builtin_ctzll(~(unsigned int)(s))]; } len--; } else { // 4-base int p = 1 << (h - len); T irot = 1, iimag = iroot[2]; for (int s = 0; s < (1 << (len - 2)); s++) { T irot2 = irot * irot; T irot3 = irot2 * irot; int offset = s << (h - len + 2); for (int i = 0; i < p; i++) { auto a0 = 1ULL * a[i + offset + 0 * p].v; auto a1 = 1ULL * a[i + offset + 1 * p].v; auto a2 = 1ULL * a[i + offset + 2 * p].v; auto a3 = 1ULL * a[i + offset + 3 * p].v; auto a2na3iimag = 1ULL * T((T::get_mod() + a2 - a3) * iimag.v).v; a[i + offset] = a0 + a1 + a2 + a3; a[i + offset + 1 * p] = (a0 + (T::get_mod() - a1) + a2na3iimag) * irot.v; a[i + offset + 2 * p] = (a0 + a1 + (T::get_mod() - a2) + (T::get_mod() - a3)) * irot2.v; a[i + offset + 3 * p] = (a0 + (T::get_mod() - a1) + (T::get_mod() - a2na3iimag)) * irot3.v; } if (s + 1 != (1 << (len - 2))) irot *= irate3[__builtin_ctzll(~(unsigned int)(s))]; } len -= 2; } } T e = T(n).inv(); for (auto &x : a) x *= e; } else { int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed while (len < h) { if (h - len == 1) { int p = 1 << (h - len - 1); T rot = 1; for (int s = 0; s < (1 << len); s++) { int offset = s << (h - len); for (int i = 0; i < p; i++) { auto l = a[i + offset]; auto r = a[i + offset + p] * rot; a[i + offset] = l + r; a[i + offset + p] = l - r; } if (s + 1 != (1 << len)) rot *= rate2[__builtin_ctzll(~(unsigned int)(s))]; } len++; } else { // 4-base int p = 1 << (h - len - 2); T rot = 1, imag = root[2]; for (int s = 0; s < (1 << len); s++) { T rot2 = rot * rot; T rot3 = rot2 * rot; int offset = s << (h - len); for (int i = 0; i < p; i++) { auto mod2 = 1ULL * T::get_mod() * T::get_mod(); auto a0 = 1ULL * a[i + offset].v; auto a1 = 1ULL * a[i + offset + p].v * rot.v; auto a2 = 1ULL * a[i + offset + 2 * p].v * rot2.v; auto a3 = 1ULL * a[i + offset + 3 * p].v * rot3.v; auto a1na3imag = 1ULL * T(a1 + mod2 - a3).v * imag.v; auto na2 = mod2 - a2; a[i + offset] = a0 + a2 + a1 + a3; a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3)); a[i + offset + 2 * p] = a0 + na2 + a1na3imag; a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag); } if (s + 1 != (1 << len)) rot *= rate3[__builtin_ctzll(~(unsigned int)(s))]; } len += 2; } } } } vector<T> mult(const vector<T> &a, const vector<T> &b) { if (a.empty() or b.empty()) return vector<T>(); int as = a.size(), bs = b.size(); int n = as + bs - 1; if (as <= 30 or bs <= 30) { if (as > 30) return mult(b, a); vector<T> res(n); rep(i, 0, as) rep(j, 0, bs) res[i + j] += a[i] * b[j]; return res; } int m = 1; while (m < n) m <<= 1; vector<T> res(m); rep(i, 0, as) res[i] = a[i]; ntt(res); if (a == b) rep(i, 0, m) res[i] *= res[i]; else { vector<T> c(m); rep(i, 0, bs) c[i] = b[i]; ntt(c); rep(i, 0, m) res[i] *= c[i]; } ntt(res, 1); res.resize(n); return res; } }; /** * @brief Number Theoretic Transform */ #line 2 "library/FPS/fps.hpp" template <typename T> struct Poly : vector<T> { Poly(int n = 0) { this->assign(n, T()); } Poly(const initializer_list<T> f) : vector<T>::vector(f) {} Poly(const vector<T> &f) { this->assign(ALL(f)); } T eval(const T &x) { T res; for (int i = this->size() - 1; i >= 0; i--) res *= x, res += this->at(i); return res; } Poly rev() const { Poly res = *this; reverse(ALL(res)); return res; } void shrink() { while (!this->empty() and this->back() == 0) this->pop_back(); } Poly operator>>(int sz) const { if ((int)this->size() <= sz) return {}; Poly ret(*this); ret.erase(ret.begin(), ret.begin() + sz); return ret; } Poly operator<<(int sz) const { Poly ret(*this); ret.insert(ret.begin(), sz, T(0)); return ret; } Poly<T> mult(const Poly<T> &a, const Poly<T> &b) { if (a.empty() or b.empty()) return {}; int as = a.size(), bs = b.size(); int n = as + bs - 1; if (as <= 30 or bs <= 30) { if (as > 30) return mult(b, a); Poly<T> res(n); rep(i, 0, as) rep(j, 0, bs) res[i + j] += a[i] * b[j]; return res; } int m = 1; while (m < n) m <<= 1; Poly<T> res(m); rep(i, 0, as) res[i] = a[i]; NTT(res, 0); if (a == b) rep(i, 0, m) res[i] *= res[i]; else { Poly<T> c(m); rep(i, 0, bs) c[i] = b[i]; NTT(c, 0); rep(i, 0, m) res[i] *= c[i]; } NTT(res, 1); res.resize(n); return res; } Poly square() const { return Poly(mult(*this, *this)); } Poly operator-() const { return Poly() - *this; } Poly operator+(const Poly &g) const { return Poly(*this) += g; } Poly operator+(const T &g) const { return Poly(*this) += g; } Poly operator-(const Poly &g) const { return Poly(*this) -= g; } Poly operator-(const T &g) const { return Poly(*this) -= g; } Poly operator*(const Poly &g) const { return Poly(*this) *= g; } Poly operator*(const T &g) const { return Poly(*this) *= g; } Poly operator/(const Poly &g) const { return Poly(*this) /= g; } Poly operator/(const T &g) const { return Poly(*this) /= g; } Poly operator%(const Poly &g) const { return Poly(*this) %= g; } pair<Poly, Poly> divmod(const Poly &g) const { Poly q = *this / g, r = *this - g * q; r.shrink(); return {q, r}; } Poly &operator+=(const Poly &g) { if (g.size() > this->size()) this->resize(g.size()); rep(i, 0, g.size()) { (*this)[i] += g[i]; } return *this; } Poly &operator+=(const T &g) { if (this->empty()) this->push_back(0); (*this)[0] += g; return *this; } Poly &operator-=(const Poly &g) { if (g.size() > this->size()) this->resize(g.size()); rep(i, 0, g.size()) { (*this)[i] -= g[i]; } return *this; } Poly &operator-=(const T &g) { if (this->empty()) this->push_back(0); (*this)[0] -= g; return *this; } Poly &operator*=(const Poly &g) { *this = mult(*this, g); return *this; } Poly &operator*=(const T &g) { rep(i, 0, this->size())(*this)[i] *= g; return *this; } Poly &operator/=(const Poly &g) { if (g.size() > this->size()) { this->clear(); return *this; } Poly g2 = g; reverse(ALL(*this)); reverse(ALL(g2)); int n = this->size() - g2.size() + 1; this->resize(n); g2.resize(n); *this *= g2.inv(); this->resize(n); reverse(ALL(*this)); shrink(); return *this; } Poly &operator/=(const T &g) { rep(i, 0, this->size())(*this)[i] /= g; return *this; } Poly &operator%=(const Poly &g) { *this -= *this / g * g; shrink(); return *this; } Poly diff() const { Poly res(this->size() - 1); rep(i, 0, res.size()) res[i] = (*this)[i + 1] * (i + 1); return res; } Poly inte() const { Poly res(this->size() + 1); for (int i = res.size() - 1; i; i--) res[i] = (*this)[i - 1] / i; return res; } Poly log() const { assert(this->front() == 1); const int n = this->size(); Poly res = diff() * inv(); res = res.inte(); res.resize(n); return res; } Poly shift(const int &c) const { const int n = this->size(); Poly res = *this, g(n); g[0] = 1; rep(i, 1, n) g[i] = g[i - 1] * c / i; vector<T> fact(n, 1); rep(i, 0, n) { if (i) fact[i] = fact[i - 1] * i; res[i] *= fact[i]; } res = res.rev(); res *= g; res.resize(n); res = res.rev(); rep(i, 0, n) res[i] /= fact[i]; return res; } Poly inv() const { const int n = this->size(); Poly res(1); res.front() = T(1) / this->front(); for (int k = 1; k < n; k <<= 1) { Poly f(k * 2), g(k * 2); rep(i, 0, min(n, k * 2)) f[i] = (*this)[i]; rep(i, 0, k) g[i] = res[i]; NTT(f, 0); NTT(g, 0); rep(i, 0, k * 2) f[i] *= g[i]; NTT(f, 1); rep(i, 0, k) { f[i] = 0; f[i + k] = -f[i + k]; } NTT(f, 0); rep(i, 0, k * 2) f[i] *= g[i]; NTT(f, 1); rep(i, 0, k) f[i] = res[i]; swap(res, f); } res.resize(n); return res; } Poly exp() const { const int n = this->size(); if (n == 1) return Poly({T(1)}); Poly b(2), c(1), z1, z2(2); b[0] = c[0] = z2[0] = z2[1] = 1; b[1] = (*this)[1]; for (int k = 2; k < n; k <<= 1) { Poly y = b; y.resize(k * 2); NTT(y, 0); z1 = z2; Poly z(k); rep(i, 0, k) z[i] = y[i] * z1[i]; NTT(z, 1); rep(i, 0, k >> 1) z[i] = 0; NTT(z, 0); rep(i, 0, k) z[i] *= -z1[i]; NTT(z, 1); c.insert(c.end(), z.begin() + (k >> 1), z.end()); z2 = c; z2.resize(k * 2); NTT(z2, 0); Poly x = *this; x.resize(k); x = x.diff(); x.resize(k); NTT(x, 0); rep(i, 0, k) x[i] *= y[i]; NTT(x, 1); Poly bb = b.diff(); rep(i, 0, k - 1) x[i] -= bb[i]; x.resize(k * 2); rep(i, 0, k - 1) { x[k + i] = x[i]; x[i] = 0; } NTT(x, 0); rep(i, 0, k * 2) x[i] *= z2[i]; NTT(x, 1); x.pop_back(); x = x.inte(); rep(i, k, min(n, k * 2)) x[i] += (*this)[i]; rep(i, 0, k) x[i] = 0; NTT(x, 0); rep(i, 0, k * 2) x[i] *= y[i]; NTT(x, 1); b.insert(b.end(), x.begin() + k, x.end()); } b.resize(n); return b; } Poly pow(ll t) { if (t == 0) { Poly res(this->size()); res[0] = 1; return res; } int n = this->size(), k = 0; while (k < n and (*this)[k] == 0) k++; Poly res(n); if (__int128_t(t) * k >= n) return res; n -= t * k; Poly g(n); T c = (*this)[k], ic = c.inv(); rep(i, 0, n) g[i] = (*this)[i + k] * ic; g = g.log(); for (auto &x : g) x *= t; g = g.exp(); c = c.pow(t); rep(i, 0, n) res[i + t * k] = g[i] * c; return res; } void NTT(vector<T> &a, bool inv) const; }; /** * @brief Formal Power Series (NTT-friendly mod) */ #line 2 "library/Math/modint.hpp" template <int mod = 1000000007> struct fp { int v; static constexpr int get_mod() { return mod; } int inv() const { int tmp, a = v, b = mod, x = 1, y = 0; while (b) tmp = a / b, a -= tmp * b, swap(a, b), x -= tmp * y, swap(x, y); if (x < 0) { x += mod; } return x; } fp(ll x = 0) : v(x >= 0 ? x % mod : (mod - (-x) % mod) % mod) {} fp operator-() const { return fp() - *this; } fp pow(ll t) { assert(t >= 0); fp res = 1, b = *this; while (t) { if (t & 1) res *= b; b *= b; t >>= 1; } return res; } fp &operator+=(const fp &x) { if ((v += x.v) >= mod) v -= mod; return *this; } fp &operator-=(const fp &x) { if ((v += mod - x.v) >= mod) v -= mod; return *this; } fp &operator*=(const fp &x) { v = ll(v) * x.v % mod; return *this; } fp &operator/=(const fp &x) { v = ll(v) * x.inv() % mod; return *this; } fp operator+(const fp &x) const { return fp(*this) += x; } fp operator-(const fp &x) const { return fp(*this) -= x; } fp operator*(const fp &x) const { return fp(*this) *= x; } fp operator/(const fp &x) const { return fp(*this) /= x; } bool operator==(const fp &x) const { return v == x.v; } bool operator!=(const fp &x) const { return v != x.v; } friend istream &operator>>(istream &is, fp &x) { return is >> x.v; } friend ostream &operator<<(ostream &os, const fp &x) { return os << x.v; } }; template <typename T> T Inv(ll n) { static const int md = T::get_mod(); static vector<T> buf({0, 1}); assert(n > 0); n %= md; while (SZ(buf) <= n) { int k = SZ(buf), q = (md + k - 1) / k; buf.push_back(buf[k * q - md] * q); } return buf[n]; } template <typename T> T Fact(ll n, bool inv = 0) { static const int md = T::get_mod(); static vector<T> buf({1, 1}), ibuf({1, 1}); assert(n >= 0 and n < md); while (SZ(buf) <= n) { buf.push_back(buf.back() * SZ(buf)); ibuf.push_back(ibuf.back() * Inv<T>(SZ(ibuf))); } return inv ? ibuf[n] : buf[n]; } template <typename T> T nPr(int n, int r, bool inv = 0) { if (n < 0 || n < r || r < 0) return 0; return Fact<T>(n, inv) * Fact<T>(n - r, inv ^ 1); } template <typename T> T nCr(int n, int r, bool inv = 0) { if (n < 0 || n < r || r < 0) return 0; return Fact<T>(n, inv) * Fact<T>(r, inv ^ 1) * Fact<T>(n - r, inv ^ 1); } template <typename T> T nHr(int n, int r, bool inv = 0) { return nCr<T>(n + r - 1, r, inv); } /** * @brief Modint */ #line 7 "sol.cpp" using Fp = fp<998244353>; NTT<Fp> ntt; template <> void Poly<Fp>::NTT(vector<Fp> &v, bool inv) const { return ntt.ntt(v, inv); } #line 2 "library/FPS/nthterm.hpp" template<typename T>T nth(Poly<T> p,Poly<T> q,ll n){ while(n){ Poly<T> base(q),np,nq; for(int i=1;i<(int)q.size();i+=2)base[i]=-base[i]; p*=base; q*=base; for(int i=n&1;i<(int)p.size();i+=2)np.emplace_back(p[i]); for(int i=0;i<(int)q.size();i+=2)nq.emplace_back(q[i]); swap(p,np); swap(q,nq); n>>=1; } return p[0]/q[0]; } /** * @brief Bostan-Mori Algorithm */ #line 2 "library/Math/matrix.hpp" template<class T>struct Matrix{ int h,w; vector<vector<T>> val; T det; Matrix(){} Matrix(int n):h(n),w(n),val(vector<vector<T>>(n,vector<T>(n))){} Matrix(int n,int m):h(n),w(m),val(vector<vector<T>>(n,vector<T>(m))){} vector<T>& operator[](const int i){return val[i];} Matrix& operator+=(const Matrix& m){ assert(h==m.h and w==m.w); rep(i,0,h)rep(j,0,w)val[i][j]+=m.val[i][j]; return *this; } Matrix& operator-=(const Matrix& m){ assert(h==m.h and w==m.w); rep(i,0,h)rep(j,0,w)val[i][j]-=m.val[i][j]; return *this; } Matrix& operator*=(const Matrix& m){ assert(w==m.h); Matrix<T> res(h,m.w); rep(i,0,h)rep(j,0,m.w)rep(k,0,w)res.val[i][j]+=val[i][k]*m.val[k][j]; *this=res; return *this; } Matrix operator+(const Matrix& m)const{return Matrix(*this)+=m;} Matrix operator-(const Matrix& m)const{return Matrix(*this)-=m;} Matrix operator*(const Matrix& m)const{return Matrix(*this)*=m;} Matrix pow(ll k){ Matrix<T> res(h,h),c=*this; rep(i,0,h)res.val[i][i]=1; while(k){if(k&1)res*=c; c*=c; k>>=1;} return res; } vector<int> gauss(int c=-1){ if(val.empty())return {}; if(c==-1)c=w; int cur=0; vector<int> res; det=1; rep(i,0,c){ if(cur==h)break; rep(j,cur,h)if(val[j][i]!=0){ swap(val[cur],val[j]); if(cur!=j)det*=-1; break; } det*=val[cur][i]; if(val[cur][i]==0)continue; rep(j,0,h)if(j!=cur){ T z=val[j][i]/val[cur][i]; rep(k,i,w)val[j][k]-=val[cur][k]*z; } res.push_back(i); cur++; } return res; } Matrix inv(){ assert(h==w); Matrix base(h,h*2),res(h,h); rep(i,0,h)rep(j,0,h)base[i][j]=val[i][j]; rep(i,0,h)base[i][h+i]=1; base.gauss(h); det=base.det; rep(i,0,h)rep(j,0,h)res[i][j]=base[i][h+j]/base[i][i]; return res; } bool operator==(const Matrix& m){ assert(h==m.h and w==m.w); rep(i,0,h)rep(j,0,w)if(val[i][j]!=m.val[i][j])return false; return true; } bool operator!=(const Matrix& m){ assert(h==m.h and w==m.w); rep(i,0,h)rep(j,0,w)if(val[i][j]==m.val[i][j])return false; return true; } friend istream& operator>>(istream& is,Matrix& m){ rep(i,0,m.h)rep(j,0,m.w)is>>m[i][j]; return is; } friend ostream& operator<<(ostream& os,Matrix& m){ rep(i,0,m.h){ rep(j,0,m.w)os<<m[i][j]<<(j==m.w-1 and i!=m.h-1?'\n':' '); } return os; } }; /** * @brief Matrix */ #line 15 "sol.cpp" FastIO io; int main() { int n, m, S, T; io.read(n, m, S, T); S--; T--; vector g(n, vector<int>()); rep(_, 0, n - 1) { int u, v; io.read(u, v); u--; v--; g[u].push_back(v); g[v].push_back(u); } vector<int> sz(n, 1); auto dfs1 = [&](auto &dfs1, int v, int p) -> void { int mx = -1; rep(i, 0, SZ(g[v])) { int to = g[v][i]; if (to == p) continue; dfs1(dfs1, to, v); sz[v] += sz[to]; if (chmax(mx, sz[to])) swap(g[v][i], g[v][0]); } }; dfs1(dfs1, T, -1); using P = pair<Poly<Fp>, Poly<Fp>>; using Mat = Matrix<Poly<Fp>>; function<vector<P>(int, int)> rake; function<P(int, int)> compress; rake = [&](int v, int p) -> vector<P> { if (SZ(g[v]) == 1 and g[v][0] == p) { return {P{Poly<Fp>({0}), Poly<Fp>({1})}}; } auto ret = rake(g[v][0], v); deque<P> deq; deq.push_back(P{Poly<Fp>({0}), Poly<Fp>({1})}); rep(i, 1, SZ(g[v])) if (g[v][i] != p) { deq.push_back(compress(g[v][i], v)); } while (deq.size() > 1) { auto [A1, A2] = deq.front(); deq.pop_front(); auto [B1, B2] = deq.front(); deq.pop_front(); deq.push_back(P{A1 * B2 + A2 * B1, A2 * B2}); } ret.push_back(deq.front()); return ret; }; compress = [&](int v, int p) -> P { auto fs = rake(v, p); auto rec = [&](auto &rec, int L, int R) -> Mat { if (R - L == 1) { auto [f, g] = fs[L]; Mat ret(2); ret[0][1] = g; ret[1][0] = -(g << 2); ret[1][1] = g - (g << 1) - (f << 2); return ret; } int mid = (L + R) >> 1; return rec(rec, L, mid) * rec(rec, mid, R); }; auto A = rec(rec, 0, SZ(fs)); A[0][1].shrink(); A[1][1].shrink(); return P{A[0][1], A[1][1]}; }; vector<P> fs; auto dfs2 = [&](auto &dfs2, int v, int p) -> bool { deque<P> deq; deq.push_back(P{Poly<Fp>({0}), Poly<Fp>({1})}); bool onedge = 0; rep(i, 0, SZ(g[v])) if (g[v][i] != p) { if (dfs2(dfs2, g[v][i], v)) { onedge = 1; } else { deq.push_back(compress(g[v][i], v)); } } onedge |= (v == S); if (onedge) { while (deq.size() > 1) { auto [A1, A2] = deq.front(); deq.pop_front(); auto [B1, B2] = deq.front(); deq.pop_front(); deq.push_back(P{A1 * B2 + A2 * B1, A2 * B2}); } auto ret = deq.front(); // cerr << v << '\n'; // for (auto &v : ret.first) // cerr << v.v << ' '; // cerr << '\n'; // for (auto &v : ret.second) // cerr << v.v << ' '; // cerr << "\n\n"; fs.push_back(ret); } return onedge; }; dfs2(dfs2, T, -1); auto rec = [&](auto &rec, int L, int R) -> Mat { if (R - L == 1) { auto [f, g] = fs[L]; Mat ret(2); ret[0][1] = g; ret[1][0] = -(g << 2); ret[1][1] = g - (g << 1) - (f << 2); return ret; } int mid = (L + R) >> 1; return rec(rec, L, mid) * rec(rec, mid, R); }; auto A = rec(rec, 0, SZ(fs)); Poly<Fp> num, den; den = A[1][1]; den.shrink(); deque<Poly<Fp>> deq; for (auto &[f, g] : fs) deq.push_back(g); while (deq.size() > 1) { auto A = deq.front(); deq.pop_front(); auto B = deq.front(); deq.pop_front(); deq.push_back(A * B); } num = deq.front(); num.shrink(); for (auto &v : num) cerr << v.v << ' '; cerr << '\n'; for (auto &v : den) cerr << v.v << ' '; cerr << "\n\n"; if (m < (SZ(fs) - 1)) { io.write(0); return 0; } Fp ret = nth(num, den, m - (SZ(fs) - 1)); io.write(ret.v); return 0; }