結果
問題 | No.2587 Random Walk on Tree |
ユーザー |
![]() |
提出日時 | 2023-12-15 02:18:19 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 7,038 ms / 10,000 ms |
コード長 | 25,803 bytes |
コンパイル時間 | 4,338 ms |
コンパイル使用メモリ | 247,840 KB |
最終ジャッジ日時 | 2025-02-18 11:11:06 |
ジャッジサーバーID (参考情報) | judge4 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 37 |
ソースコード
// https://ei1333.github.io/library/test/verify/yosupo-exp-of-formal-power-series.test.cpp
// #line 1 "test/verify/yosupo-exp-of-formal-power-series.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/exp_of_formal_power_series"

// #line 1 "template/template.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <set>
#include <type_traits>
#include <utility>
#include <vector>

using namespace std;

using int64 = long long;
const int mod = 1000000007;
const int64 infll = (1LL << 62) - 1;
const int inf = (1 << 30) - 1;

/// Configures the standard streams once, before main() runs.
struct IoSetup {
  IoSetup() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    cout << fixed << setprecision(10);
    cerr << fixed << setprecision(10);
  }
} iosetup;

/// Writes a pair as "first second".
template <typename T1, typename T2>
ostream& operator<<(ostream& os, const pair<T1, T2>& p) {
  os << p.first << " " << p.second;
  return os;
}

/// Reads a pair as two whitespace-separated values.
template <typename T1, typename T2>
istream& operator>>(istream& is, pair<T1, T2>& p) {
  is >> p.first >> p.second;
  return is;
}

/// Writes the elements of a vector separated by single spaces.
template <typename T> ostream& operator<<(ostream& os, const vector<T>& v) {
  const int n = (int)v.size();
  for (int i = 0; i < n; i++) {
    os << v[i] << (i + 1 != n ? " " : "");
  }
  return os;
}

/// Reads v.size() elements into an already-sized vector.
template <typename T> istream& operator>>(istream& is, vector<T>& v) {
  for (T& in : v) is >> in;
  return is;
}

/// a = max(a, b); returns true when a was changed.
template <typename T1, typename T2> inline bool chmax(T1& a, T2 b) {
  return a < b && (a = b, true);
}

/// a = min(a, b); returns true when a was changed.
template <typename T1, typename T2> inline bool chmin(T1& a, T2 b) {
  return a > b && (a = b, true);
}

/// Builds a value-initialized vector of length a.
template <typename T = int64> vector<T> make_v(size_t a) {
  return vector<T>(a);
}

/// Builds a nested vector whose dimensions are (a, ts...).
template <typename T, typename... Ts> auto make_v(size_t a, Ts... ts) {
  return vector<decltype(make_v<T>(ts...))>(a, make_v<T>(ts...));
}

/// Assigns v to a non-class (scalar) target.
template <typename T, typename V>
typename enable_if<is_class<T>::value == 0>::type fill_v(T& t, const V& v) {
  t = v;
}

/// Recursively assigns v to every scalar inside a (nested) container.
template <typename T, typename V>
typename enable_if<is_class<T>::value != 0>::type fill_v(T& t, const V& v) {
  for (auto& e : t) fill_v(e, v);
}

/// Wraps a callable so that it receives itself as its first argument.
template <typename F> struct FixPoint : F {
  explicit FixPoint(F&& f) : F(forward<F>(f)) {}

  template <typename... Args> decltype(auto) operator()(Args&&... args) const {
    return F::operator()(*this, forward<Args>(args)...);
  }
};

template <typename F> inline decltype(auto) MFP(F&& f) {
  return FixPoint<F>{forward<F>(f)};
}

// #line 4 "test/verify/yosupo-exp-of-formal-power-series.test.cpp"
// #line 1 "math/combinatorics/mod-int.hpp"

/// Integer modulo the compile-time constant `mod`; the stored value x is kept in [0, mod).
template <int mod> struct ModInt {
  int x;

  ModInt() : x(0) {}

  // Intentionally implicit: the rest of the file relies on int -> ModInt conversions.
  ModInt(int64_t y) : x(y >= 0 ? y % mod : (mod - (-y) % mod) % mod) {}

  ModInt& operator+=(const ModInt& p) {
    if ((x += p.x) >= mod) x -= mod;
    return *this;
  }

  ModInt& operator-=(const ModInt& p) {
    if ((x += mod - p.x) >= mod) x -= mod;
    return *this;
  }

  ModInt& operator*=(const ModInt& p) {
    x = (int)(1LL * x * p.x % mod);
    return *this;
  }

  ModInt& operator/=(const ModInt& p) {
    *this *= p.inverse();
    return *this;
  }

  ModInt operator-() const { return ModInt(-x); }
  ModInt operator+(const ModInt& p) const { return ModInt(*this) += p; }
  ModInt operator-(const ModInt& p) const { return ModInt(*this) -= p; }
  ModInt operator*(const ModInt& p) const { return ModInt(*this) *= p; }
  ModInt operator/(const ModInt& p) const { return ModInt(*this) /= p; }
  bool operator==(const ModInt& p) const { return x == p.x; }
  bool operator!=(const ModInt& p) const { return x != p.x; }

  /// Multiplicative inverse computed with the extended Euclidean algorithm.
  ModInt inverse() const {
    int a = x, b = mod, u = 1, v = 0, t;
    while (b > 0) {
      t = a / b;
      swap(a -= t * b, b);
      swap(u -= t * v, v);
    }
    return ModInt(u);
  }

  /// Binary exponentiation; returns 1 when n <= 0.
  ModInt pow(int64_t n) const {
    ModInt ret(1), mul(x);
    while (n > 0) {
      if (n & 1) ret *= mul;
      mul *= mul;
      n >>= 1;
    }
    return ret;
  }

  friend ostream& operator<<(ostream& os, const ModInt& p) { return os << p.x; }

  friend istream& operator>>(istream& is, ModInt& a) {
    int64_t t;
    is >> t;
    a = ModInt<mod>(t);
    return (is);
  }

  static int get_mod() { return mod; }
};

using modint = ModInt<mod>;

// #line 6 "test/verify/yosupo-exp-of-formal-power-series.test.cpp"
// #line 2 "math/fps/formal-power-series-friendly-ntt.hpp"
// #line 1 "math/fft/number-theoretic-transform-friendly-mod-int.hpp"

/**
 * @brief Number Theoretic Transform Friendly ModInt
 *
 * Radix-4 transform over Mint. All tables are built lazily by init() from
 * Mint::get_mod(). ntt()/intt() leave the data in the internal (unsorted)
 * order, so they must always be used as a pair.
 */
template <typename Mint> struct NumberTheoreticTransformFriendlyModInt {
  static vector<Mint> roots, iroots, rate3, irate3;
  static int max_base;

  NumberTheoreticTransformFriendlyModInt() = default;

  /// Builds the root tables on first use; later calls are no-ops.
  static void init() {
    if (!roots.empty()) return;
    const unsigned md = Mint::get_mod();
    assert(md >= 3 && md % 2 == 1);
    auto tmp = md - 1;
    max_base = 0;
    while (tmp % 2 == 0) tmp >>= 1, max_base++;
    // Smallest value >= 2 whose ((md - 1) / 2)-th power is not 1.
    Mint root = 2;
    while (root.pow((md - 1) >> 1) == 1) {
      root += 1;
    }
    assert(root.pow(md - 1) == 1);
    roots.resize(max_base + 1);
    iroots.resize(max_base + 1);
    rate3.resize(max_base + 1);
    irate3.resize(max_base + 1);
    roots[max_base] = root.pow((md - 1) >> max_base);
    iroots[max_base] = Mint(1) / roots[max_base];
    for (int i = max_base - 1; i >= 0; i--) {
      roots[i] = roots[i + 1] * roots[i + 1];
      iroots[i] = iroots[i + 1] * iroots[i + 1];
    }
    {
      Mint prod = 1, iprod = 1;
      for (int i = 0; i <= max_base - 3; i++) {
        rate3[i] = roots[i + 3] * prod;
        irate3[i] = iroots[i + 3] * iprod;
        prod *= iroots[i + 3];
        iprod *= roots[i + 3];
      }
    }
  }

  /// Forward transform, in place. a.size() must be a power of two not exceeding 2^max_base.
  static void ntt(vector<Mint>& a) {
    init();
    const int n = (int)a.size();
    // n > 0 is required because __builtin_ctz(0) is undefined.
    assert(n > 0 && (n & (n - 1)) == 0);
    int h = __builtin_ctz(n);
    assert(h <= max_base);
    int len = 0;
    Mint imag = roots[2];
    if (h & 1) {
      // Odd number of levels: peel off one radix-2 step first.
      int p = 1 << (h - 1);
      for (int i = 0; i < p; i++) {
        auto r = a[i + p];
        a[i + p] = a[i] - r;
        a[i] += r;
      }
      len++;
    }
    for (; len + 1 < h; len += 2) {
      int p = 1 << (h - len - 2);
      {  // s = 0 (all twiddle factors are 1)
        for (int i = 0; i < p; i++) {
          auto a0 = a[i];
          auto a1 = a[i + p];
          auto a2 = a[i + 2 * p];
          auto a3 = a[i + 3 * p];
          auto a1na3imag = (a1 - a3) * imag;
          auto a0a2 = a0 + a2;
          auto a1a3 = a1 + a3;
          auto a0na2 = a0 - a2;
          a[i] = a0a2 + a1a3;
          a[i + 1 * p] = a0a2 - a1a3;
          a[i + 2 * p] = a0na2 + a1na3imag;
          a[i + 3 * p] = a0na2 - a1na3imag;
        }
      }
      Mint rot = rate3[0];
      for (int s = 1; s < (1 << len); s++) {
        int offset = s << (h - len);
        Mint rot2 = rot * rot;
        Mint rot3 = rot2 * rot;
        for (int i = 0; i < p; i++) {
          auto a0 = a[i + offset];
          auto a1 = a[i + offset + p] * rot;
          auto a2 = a[i + offset + 2 * p] * rot2;
          auto a3 = a[i + offset + 3 * p] * rot3;
          auto a1na3imag = (a1 - a3) * imag;
          auto a0a2 = a0 + a2;
          auto a1a3 = a1 + a3;
          auto a0na2 = a0 - a2;
          a[i + offset] = a0a2 + a1a3;
          a[i + offset + 1 * p] = a0a2 - a1a3;
          a[i + offset + 2 * p] = a0na2 + a1na3imag;
          a[i + offset + 3 * p] = a0na2 - a1na3imag;
        }
        rot *= rate3[__builtin_ctz(~s)];
      }
    }
  }

  /// Inverse transform, in place. When f is false the final division by n is skipped.
  static void intt(vector<Mint>& a, bool f = true) {
    init();
    const int n = (int)a.size();
    // n > 0 is required because __builtin_ctz(0) is undefined.
    assert(n > 0 && (n & (n - 1)) == 0);
    int h = __builtin_ctz(n);
    assert(h <= max_base);
    int len = h;
    Mint iimag = iroots[2];
    for (; len > 1; len -= 2) {
      int p = 1 << (h - len);
      {  // s = 0 (all twiddle factors are 1)
        for (int i = 0; i < p; i++) {
          auto a0 = a[i];
          auto a1 = a[i + 1 * p];
          auto a2 = a[i + 2 * p];
          auto a3 = a[i + 3 * p];
          auto a2na3iimag = (a2 - a3) * iimag;
          auto a0na1 = a0 - a1;
          auto a0a1 = a0 + a1;
          auto a2a3 = a2 + a3;
          a[i] = a0a1 + a2a3;
          a[i + 1 * p] = (a0na1 + a2na3iimag);
          a[i + 2 * p] = (a0a1 - a2a3);
          a[i + 3 * p] = (a0na1 - a2na3iimag);
        }
      }
      Mint irot = irate3[0];
      for (int s = 1; s < (1 << (len - 2)); s++) {
        int offset = s << (h - len + 2);
        Mint irot2 = irot * irot;
        Mint irot3 = irot2 * irot;
        for (int i = 0; i < p; i++) {
          auto a0 = a[i + offset];
          auto a1 = a[i + offset + 1 * p];
          auto a2 = a[i + offset + 2 * p];
          auto a3 = a[i + offset + 3 * p];
          auto a2na3iimag = (a2 - a3) * iimag;
          auto a0na1 = a0 - a1;
          auto a0a1 = a0 + a1;
          auto a2a3 = a2 + a3;
          a[i + offset] = a0a1 + a2a3;
          a[i + offset + 1 * p] = (a0na1 + a2na3iimag) * irot;
          a[i + offset + 2 * p] = (a0a1 - a2a3) * irot2;
          a[i + offset + 3 * p] = (a0na1 - a2na3iimag) * irot3;
        }
        irot *= irate3[__builtin_ctz(~s)];
      }
    }
    if (len >= 1) {
      // Odd number of levels: finish with one radix-2 step.
      int p = 1 << (h - 1);
      for (int i = 0; i < p; i++) {
        auto ajp = a[i] - a[i + p];
        a[i] += a[i + p];
        a[i + p] = ajp;
      }
    }
    if (f) {
      Mint inv_sz = Mint(1) / n;
      for (int i = 0; i < n; i++) a[i] *= inv_sz;
    }
  }

  /// Returns the product of a and b (length a.size() + b.size() - 1). Both inputs must be non-empty.
  static vector<Mint> multiply(vector<Mint> a, vector<Mint> b) {
    int need = (int)a.size() + (int)b.size() - 1;
    int nbase = 1;
    while ((1 << nbase) < need) nbase++;
    int sz = 1 << nbase;
    a.resize(sz, Mint(0));
    b.resize(sz, Mint(0));
    ntt(a);
    ntt(b);
    // The 1/sz scaling is folded into the pointwise product, hence intt(a, false).
    Mint inv_sz = Mint(1) / sz;
    for (int i = 0; i < sz; i++) a[i] *= b[i] * inv_sz;
    intt(a, false);
    a.resize(need);
    return a;
  }
};

template <typename Mint>
vector<Mint> NumberTheoreticTransformFriendlyModInt<Mint>::roots = vector<Mint>();
template <typename Mint>
vector<Mint> NumberTheoreticTransformFriendlyModInt<Mint>::iroots = vector<Mint>();
template <typename Mint>
vector<Mint> NumberTheoreticTransformFriendlyModInt<Mint>::rate3 = vector<Mint>();
template <typename Mint>
vector<Mint> NumberTheoreticTransformFriendlyModInt<Mint>::irate3 = vector<Mint>();
template <typename Mint>
int NumberTheoreticTransformFriendlyModInt<Mint>::max_base = 0;

// #line 4 "math/fps/formal-power-series-friendly-ntt.hpp"

/// Polynomial / truncated power series stored as a coefficient vector (index = degree).
template <typename T> struct FormalPowerSeriesFriendlyNTT : vector<T> {
  using vector<T>::vector;
  using P = FormalPowerSeriesFriendlyNTT;
  using NTT = NumberTheoreticTransformFriendlyModInt<T>;

  /// First min(size(), deg) coefficients.
  P pre(int deg) const {
    return P(begin(*this), begin(*this) + min((int)this->size(), deg));
  }

  /// Coefficients in reverse order; when deg != -1 the series is first resized to deg.
  P rev(int deg = -1) const {
    P ret(*this);
    if (deg != -1) ret.resize(deg, T(0));
    reverse(begin(ret), end(ret));
    return ret;
  }

  /// Drops trailing zero coefficients.
  void shrink() {
    while (this->size() && this->back() == T(0)) this->pop_back();
  }

  P operator+(const P& r) const { return P(*this) += r; }
  P operator+(const T& v) const { return P(*this) += v; }
  P operator-(const P& r) const { return P(*this) -= r; }
  P operator-(const T& v) const { return P(*this) -= v; }
  P operator*(const P& r) const { return P(*this) *= r; }
  P operator*(const T& v) const { return P(*this) *= v; }
  P operator/(const P& r) const { return P(*this) /= r; }
  P operator%(const P& r) const { return P(*this) %= r; }

  P& operator+=(const P& r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
    return *this;
  }

  P& operator-=(const P& r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
    return *this;
  }

  // https://judge.yosupo.jp/problem/convolution_mod
  /// Polynomial product; the result is empty when either operand is empty.
  P& operator*=(const P& r) {
    if (this->empty() || r.empty()) {
      this->clear();
      return *this;
    }
    auto ret = NTT::multiply(*this, r);
    return *this = P(begin(ret), end(ret));
  }

  /// Polynomial quotient. The leading coefficient r.back() must be non-zero.
  P& operator/=(const P& r) {
    if (this->size() < r.size()) {
      this->clear();
      return *this;
    }
    int n = (int)this->size() - (int)r.size() + 1;
    return *this = (rev().pre(n) * r.rev().inv(n)).pre(n).rev(n);
  }

  /// Polynomial remainder, with trailing zeros removed.
  P& operator%=(const P& r) {
    *this -= *this / r * r;
    shrink();
    return *this;
  }

  // https://judge.yosupo.jp/problem/division_of_polynomials
  /// Returns {quotient, remainder} of *this divided by r.
  pair<P, P> div_mod(const P& r) {
    P q = *this / r;
    P x = *this - q * r;
    x.shrink();
    return make_pair(q, x);
  }

  P operator-() const {
    P ret(this->size());
    for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
    return ret;
  }

  P& operator+=(const T& r) {
    if (this->empty()) this->resize(1);
    (*this)[0] += r;
    return *this;
  }

  P& operator-=(const T& r) {
    if (this->empty()) this->resize(1);
    (*this)[0] -= r;
    return *this;
  }

  P& operator*=(const T& v) {
    for (int i = 0; i < (int)this->size(); i++) (*this)[i] *= v;
    return *this;
  }

  /// Coefficient-wise product, truncated to the shorter operand.
  P dot(P r) const {
    P ret(min(this->size(), r.size()));
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
    return ret;
  }

  /// Drops the sz lowest coefficients (division by x^sz, remainder discarded).
  P operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    P ret(*this);
    ret.erase(ret.begin(), ret.begin() + sz);
    return ret;
  }

  /// Prepends sz zero coefficients (multiplication by x^sz).
  P operator<<(int sz) const {
    P ret(*this);
    ret.insert(ret.begin(), sz, T(0));
    return ret;
  }

  /// Evaluates the polynomial at x.
  T operator()(T x) const {
    T r = 0, w = 1;
    for (auto& v : *this) {
      r += w * v;
      w *= x;
    }
    return r;
  }

  /// Formal derivative.
  P diff() const {
    const int n = (int)this->size();
    P ret(max(0, n - 1));
    for (int i = 1; i < n; i++) ret[i - 1] = (*this)[i] * T(i);
    return ret;
  }

  /// Formal integral with constant term 0.
  P integral() const {
    const int n = (int)this->size();
    P ret(n + 1);
    ret[0] = T(0);
    for (int i = 0; i < n; i++) ret[i + 1] = (*this)[i] / T(i + 1);
    return ret;
  }

  // https://judge.yosupo.jp/problem/inv_of_formal_power_series
  /// First deg coefficients of 1 / F. F(0) must not be 0.
  P inv(int deg = -1) const {
    assert(!this->empty() && (*this)[0] != T(0));
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if (deg == 0) return P();  // nothing requested; avoids writing res[0] below
    P res(deg);
    res[0] = T(1) / (*this)[0];
    // Each pass doubles the number of correct coefficients (d -> 2d).
    for (int d = 1; d < deg; d <<= 1) {
      P f(2 * d), g(2 * d);
      for (int j = 0; j < min(n, 2 * d); j++) f[j] = (*this)[j];
      for (int j = 0; j < d; j++) g[j] = res[j];
      NTT::ntt(f);
      NTT::ntt(g);
      f = f.dot(g);
      NTT::intt(f);
      for (int j = 0; j < d; j++) f[j] = T(0);
      NTT::ntt(f);
      for (int j = 0; j < 2 * d; j++) f[j] *= g[j];
      NTT::intt(f);
      for (int j = d; j < min(2 * d, deg); j++) res[j] = -f[j];
    }
    return res;
  }

  // https://judge.yosupo.jp/problem/log_of_formal_power_series
  /// First deg coefficients of log F. F(0) must be 1.
  P log(int deg = -1) const {
    assert(!this->empty() && (*this)[0] == T(1));
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
  }

  // https://judge.yosupo.jp/problem/sqrt_of_formal_power_series
  /// First deg coefficients of a square root of F, or an empty series when none is found.
  /// get_sqrt must return a square root of the scalar it is given.
  P sqrt(int deg = -1,
         const function<T(T)>& get_sqrt = [](T) { return T(1); }) const {
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if ((*this)[0] == T(0)) {
      for (int i = 1; i < n; i++) {
        if ((*this)[i] != T(0)) {
          if (i & 1) return {};  // lowest non-zero term has odd degree
          if (deg - i / 2 <= 0) break;
          auto ret = (*this >> i).sqrt(deg - i / 2, get_sqrt);
          if (ret.empty()) return {};
          ret = ret << (i / 2);
          if ((int)ret.size() < deg) ret.resize(deg, T(0));
          return ret;
        }
      }
      return P(deg, T(0));
    }
    auto sqr = T(get_sqrt((*this)[0]));
    if (sqr * sqr != (*this)[0]) return {};
    P ret{sqr};
    T inv2 = T(1) / T(2);
    for (int i = 1; i < deg; i <<= 1) {
      ret = (ret + pre(i << 1) * ret.inv(i << 1)) * inv2;
    }
    return ret.pre(deg);
  }

  P sqrt(const function<T(T)>& get_sqrt, int deg = -1) const {
    return sqrt(deg, get_sqrt);
  }

  // https://judge.yosupo.jp/problem/exp_of_formal_power_series
  /// First deg coefficients of exp F. F(0) must be 0.
  P exp(int deg = -1) const {
    if (deg == -1) deg = (int)this->size();
    assert(!this->empty() && (*this)[0] == T(0));
    // recip[i] = 1 / i, extended on demand by inplace_integral.
    P recip;
    recip.reserve(deg + 1);
    recip.push_back(T(0));
    recip.push_back(T(1));
    auto inplace_integral = [&](P& F) -> void {
      const int n = (int)F.size();
      const int md = T::get_mod();
      while ((int)recip.size() <= n) {
        int i = (int)recip.size();
        recip.push_back((-recip[md % i]) * T(md / i));
      }
      F.insert(begin(F), T(0));
      for (int i = 1; i <= n; i++) F[i] *= recip[i];
    };
    auto inplace_diff = [](P& F) -> void {
      if (F.empty()) return;
      F.erase(begin(F));
      T coeff = 1, one = 1;
      for (int i = 0; i < (int)F.size(); i++) {
        F[i] *= coeff;
        coeff += one;
      }
    };
    P b{T(1), 1 < (int)this->size() ? (*this)[1] : T(0)}, c{T(1)}, z1,
        z2{T(1), T(1)};
    for (int m = 2; m < deg; m *= 2) {
      auto y = b;
      y.resize(2 * m);
      NTT::ntt(y);
      z1 = z2;
      P z(m);
      for (int i = 0; i < m; ++i) z[i] = y[i] * z1[i];
      NTT::intt(z);
      fill(begin(z), begin(z) + m / 2, T(0));
      NTT::ntt(z);
      for (int i = 0; i < m; ++i) z[i] *= -z1[i];
      NTT::intt(z);
      c.insert(end(c), begin(z) + m / 2, end(z));
      z2 = c;
      z2.resize(2 * m);
      NTT::ntt(z2);
      P x(begin(*this), begin(*this) + min<int>(this->size(), m));
      // Pad to exactly m terms so that the transform below has length m even
      // when the input series is shorter than m.
      x.resize(m);
      inplace_diff(x);
      x.push_back(T(0));
      NTT::ntt(x);
      for (int i = 0; i < m; ++i) x[i] *= y[i];
      NTT::intt(x);
      x -= b.diff();
      x.resize(2 * m);
      for (int i = 0; i < m - 1; ++i) x[m + i] = x[i], x[i] = T(0);
      NTT::ntt(x);
      for (int i = 0; i < 2 * m; ++i) x[i] *= z2[i];
      NTT::intt(x);
      x.pop_back();
      inplace_integral(x);
      for (int i = m; i < min<int>(this->size(), 2 * m); ++i) x[i] += (*this)[i];
      fill(begin(x), begin(x) + m, T(0));
      NTT::ntt(x);
      for (int i = 0; i < 2 * m; ++i) x[i] *= y[i];
      NTT::intt(x);
      b.insert(end(b), begin(x) + m, end(x));
    }
    return P(begin(b), begin(b) + deg);
  }

  // https://judge.yosupo.jp/problem/pow_of_formal_power_series
  /// First deg coefficients of F^k.
  P pow(int64_t k, int deg = -1) const {
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if (k == 0) {
      P ret(deg, T(0));
      if (deg > 0) ret[0] = T(1);
      return ret;
    }
    for (int i = 0; i < n; i++) {
      if (i * k > deg) return P(deg, T(0));
      if ((*this)[i] != T(0)) {
        // Normalize to constant term 1, then use exp(k * log(.)).
        T rev = T(1) / (*this)[i];
        P ret = (((*this * rev) >> i).log() * k).exp() * ((*this)[i].pow(k));
        ret = (ret << (i * k)).pre(deg);
        if ((int)ret.size() < deg) ret.resize(deg, T(0));
        return ret;
      }
    }
    // Every coefficient is zero, so the power is the zero series of length deg.
    return P(deg, T(0));
  }

  /// (*this)^k reduced modulo the polynomial g.
  P mod_pow(int64_t k, P g) const {
    P modinv = g.rev().inv();
    auto get_div = [&](P base) {
      if (base.size() < g.size()) {
        base.clear();
        return base;
      }
      int n = (int)base.size() - (int)g.size() + 1;
      return (base.rev().pre(n) * modinv.pre(n)).pre(n).rev(n);
    };
    P x(*this), ret{T(1)};
    while (k > 0) {
      if (k & 1) {
        ret *= x;
        ret -= get_div(ret) * g;
        ret.shrink();
      }
      x *= x;
      x -= get_div(x) * g;
      x.shrink();
      k >>= 1;
    }
    return ret;
  }

  // https://judge.yosupo.jp/problem/polynomial_taylor_shift
  /// Coefficients of F(x + c).
  P taylor_shift(T c) const {
    int n = (int)this->size();
    if (n == 0) return {};  // the factorial tables below need at least one slot
    vector<T> fact(n), rfact(n);
    fact[0] = rfact[0] = T(1);
    for (int i = 1; i < n; i++) fact[i] = fact[i - 1] * T(i);
    rfact[n - 1] = T(1) / fact[n - 1];
    for (int i = n - 1; i > 1; i--) rfact[i - 1] = rfact[i] * T(i);
    P p(*this);
    for (int i = 0; i < n; i++) p[i] *= fact[i];
    p = p.rev();
    P bs(n, T(1));
    for (int i = 1; i < n; i++) bs[i] = bs[i - 1] * c * rfact[i] * fact[i - 1];
    p = (p * bs).pre(n);
    p = p.rev();
    for (int i = 0; i < n; i++) p[i] *= rfact[i];
    return p;
  }
};

template <typename Mint> using FPS = FormalPowerSeriesFriendlyNTT<Mint>;

// #line 8 "test/verify/yosupo-exp-of-formal-power-series.test.cpp"

/**
 * @brief Enumeration (combinatorics)
 *
 * Lazily grown tables of factorials, inverse factorials and inverses.
 */
template <typename T> struct Enumeration {
 private:
  static vector<T> _fact, _finv, _inv;

  /// Grows the tables so that index sz is valid.
  inline static void expand(size_t sz) {
    if (_fact.size() < sz + 1) {
      int pre_sz = max(1, (int)_fact.size());
      _fact.resize(sz + 1, T(1));
      _finv.resize(sz + 1, T(1));
      _inv.resize(sz + 1, T(1));
      for (int i = pre_sz; i <= (int)sz; i++) {
        _fact[i] = _fact[i - 1] * T(i);
      }
      _finv[sz] = T(1) / _fact[sz];
      for (int i = (int)sz - 1; i >= pre_sz; i--) {
        _finv[i] = _finv[i + 1] * T(i + 1);
      }
      for (int i = pre_sz; i <= (int)sz; i++) {
        _inv[i] = _finv[i] * _fact[i - 1];
      }
    }
  }

 public:
  explicit Enumeration(size_t sz = 0) { expand(sz); }

  static inline T fact(int k) {
    expand(k);
    return _fact[k];
  }

  static inline T finv(int k) {
    expand(k);
    return _finv[k];
  }

  static inline T inv(int k) {
    expand(k);
    return _inv[k];
  }

  /// n! / (n - r)!, or 0 when r < 0 or n < r.
  static T P(int n, int r) {
    if (r < 0 || n < r) return 0;
    return fact(n) * finv(n - r);
  }

  /// Binomial coefficient, or 0 when q < 0 or p < q.
  static T C(int p, int q) {
    if (q < 0 || p < q) return 0;
    return fact(p) * finv(q) * finv(p - q);
  }

  /// C(n + r - 1, r), with H(n, 0) = 1 and 0 for negative arguments.
  static T H(int n, int r) {
    if (n < 0 || r < 0) return 0;
    return r == 0 ? 1 : C(n + r - 1, r);
  }
};

template <typename T> vector<T> Enumeration<T>::_fact = vector<T>();
template <typename T> vector<T> Enumeration<T>::_finv = vector<T>();
template <typename T> vector<T> Enumeration<T>::_inv = vector<T>();

const int MOD = 998244353;
using mint = ModInt<MOD>;
using enu = Enumeration<mint>;
using poly = FPS<mint>;

/// Returns f multiplied by x^w (w zero coefficients are prepended).
poly shift(poly f, int w) {
  f.insert(f.begin(), w, mint(0));
  return f;
}

/// Polynomial quotient a / b, resized to exactly n coefficients.
/// Trailing zeros of b are stripped first; b must not be the zero polynomial.
poly div(poly a, poly b, int n) {
  while (!b.empty() && b.back() == 0) b.pop_back();
  assert(!b.empty());
  a /= b;
  a.resize(n);
  return a;
}

/// Polynomials attached to a path between two endpoints; combined by ToptreeDP::solve_path.
struct Path {
  poly D, L, R, P;
};

/// Polynomials attached to a set of subtrees hanging off one vertex; combined by ToptreeDP::solve_point.
struct Point {
  poly D, P;
};

/// Divide-and-conquer over the tree. Edges are erased from g as the recursion
/// splits the tree, so a ToptreeDP instance can be solved only once.
struct ToptreeDP {
  vector<set<int>> g;
  vector<int> sz;

  // g is declared before sz, so g is already initialized when sz reads g.size().
  ToptreeDP(vector<set<int>> g_) : g(std::move(g_)), sz(g.size()) {}

  /// Solves the component that currently contains S and T, treated as the path S..T
  /// with everything else hanging off it.
  Path solve_path(int S, int T) {
    if (S == T) {
      // Single path vertex: detach it from its neighbours and merge their subtrees.
      for (int c : g[S]) g[c].erase(S);
      Point p = solve_point(g[S]);
      Path ret;
      ret.D = p.D - shift(p.D, 1) - shift(p.P, 2);
      ret.L = ret.R = ret.P = p.D;
      return ret;
    }
    {
      // Subtree sizes with S as the root.
      auto f = [&](auto& self, int v, int p) -> void {
        sz[v] = 1;
        for (int c : g[v])
          if (c != p) self(self, c, v), sz[v] += sz[c];
      };
      f(f, S, -1);
    }
    int N = sz[S];
    int sep = S, par = -1;
    auto score = [&](int v) { return (long long)(sz[v]) * (N - sz[v]); };
    {
      // Among the vertices on the S..T path, pick the one whose parent edge
      // splits the component most evenly.
      auto f = [&](auto& self, int v, int p) -> bool {
        bool ret = v == T;
        for (int c : g[v])
          if (c != p) ret |= self(self, c, v);
        if (ret && score(v) > score(sep)) sep = v, par = p;
        return ret;
      };
      f(f, S, -1);
    }
    g[sep].erase(par);
    g[par].erase(sep);
    Path X = solve_path(S, par);
    Path Y = solve_path(sep, T);
    Path ret;
    ret.D = X.D * Y.D - shift(X.R * Y.L, 2);
    ret.P = shift(X.P * Y.P, 1);
    ret.L = div(X.L * ret.D + shift(X.P * X.P * Y.L, 2), X.D,
                ret.D.size() - 1);
    ret.R = div(Y.R * ret.D + shift(Y.P * Y.P * X.R, 2), Y.D,
                ret.D.size() - 1);
    return ret;
  }

  /// Solves the subtrees rooted at the vertices in V (already detached from their
  /// common parent) and merges them.
  Point solve_point(set<int> V) {
    if (V.size() == 0) {
      Point ret;
      ret.D = poly{mint(1)};
      ret.P = poly();
      return ret;
    }
    if (V.size() == 1) {
      int S = *V.begin(), T = -1;
      {
        // Walk from S always into the largest child subtree; T is where the walk ends.
        auto f = [&](auto& self, int v, int p) -> int {
          sz[v] = 1;
          int ret = v;
          int ma = 0;
          for (int c : g[v]) {
            if (c != p) {
              int t = self(self, c, v);
              sz[v] += sz[c];
              if (ma < sz[c]) ma = sz[c], ret = t;
            }
          }
          return ret;
        };
        T = f(f, S, -1);
      }
      Path p = solve_path(S, T);
      Point ret;
      ret.D = p.D;
      ret.P = p.L;
      return ret;
    }
    // Several subtrees: split them into two groups of similar total size
    // (largest first, always into the currently lighter group) and recurse.
    vector<pair<int, int>> sz_vt;
    auto f = [&](auto& self, int v, int p) -> int {
      int ret = 1;
      for (int c : g[v])
        if (c != p) ret += self(self, c, v);
      return ret;
    };
    for (int v : V) {
      sz_vt.push_back({f(f, v, -1), v});
    }
    sort(sz_vt.rbegin(), sz_vt.rend());
    int ls = 0, rs = 0;
    set<int> lset, rset;
    for (auto& [s, v] : sz_vt) {
      if (ls < rs) {
        ls += s;
        lset.insert(v);
      } else {
        rs += s;
        rset.insert(v);
      }
    }
    Point X = solve_point(lset);
    Point Y = solve_point(rset);
    Point ret;
    ret.D = X.D * Y.D;
    ret.P = X.P * Y.D + Y.P * X.D;
    return ret;
  }
};

/// Coefficient of x^k in the power series n(x) / d(x); d[0] must be non-zero.
/// Each round multiplies numerator and denominator by d with its even-index
/// coefficients negated, then keeps every second coefficient and halves k.
mint bm(poly n, poly d, int k) {
  while (k) {
    if (n.empty()) return mint(0);  // zero numerator: every coefficient is 0
    poly md = d;
    for (int i = 0; i < (int)md.size(); i += 2) md[i] = -md[i];
    n *= md;
    d *= md;
    if (k & 1) n.erase(n.begin());
    poly nn, dd;
    for (int i = 0; i < (int)n.size(); i += 2) nn.push_back(n[i]);
    for (int i = 0; i < (int)d.size(); i += 2) dd.push_back(d[i]);
    n = std::move(nn);
    d = std::move(dd);
    k /= 2;
  }
  if (n.empty()) return mint(0);
  return n[0] * d[0].inverse();
}

/// Reads "N M S T" followed by N - 1 edges (1-indexed vertices) and prints
/// bm(P, D, M) for the Path returned by solve_path(S, T).
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int N, M, S, T;
  if (!(cin >> N >> M >> S >> T) || N < 1 || S < 1 || S > N || T < 1 || T > N) {
    cerr << "invalid input header" << "\n";
    return 1;
  }
  S--, T--;
  vector<set<int>> g(N);
  for (int i = 0; i < N - 1; i++) {
    int A, B;
    if (!(cin >> A >> B) || A < 1 || A > N || B < 1 || B > N) {
      cerr << "invalid edge" << "\n";
      return 1;
    }
    A--, B--;
    g[A].insert(B);
    g[B].insert(A);
  }
  ToptreeDP t(std::move(g));
  auto ans = t.solve_path(S, T);
  cout << bm(ans.P, ans.D, M) << "\n";
  return 0;
}