結果
| 問題 | No.2587 Random Walk on Tree |
|---|---|
| コンテスト | |
| ユーザー | tko919 |
| 提出日時 | 2023-12-25 04:47:05 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | WA |
| 実行時間 | - |
| コード長 | 34,477 bytes |
| コンパイル時間 | 4,786 ms |
| コンパイル使用メモリ | 259,496 KB |
| 最終ジャッジ日時 | 2025-02-18 14:27:37 |
| ジャッジサーバーID (参考情報) | judge3 / judge3 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | WA * 12 TLE * 1 -- * 24 |
ソースコード
#line 1 "library/Template/template.hpp"
#include <bits/stdc++.h>
using namespace std;
// Loop i over the half-open range [a, b) using int arithmetic.
#define rep(i,a,b) for(int i=(int)(a);i<(int)(b);i++)
// Begin/end iterator pair of a container, for <algorithm> calls.
#define ALL(v) (v).begin(),(v).end()
// Sort v and erase adjacent duplicates: v ends up holding its distinct values ascending.
#define UNIQUE(v) sort(ALL(v)),(v).erase(unique(ALL(v)),(v).end())
// Container size as int. The argument is parenthesized so that e.g. SZ(*p)
// expands to (int)(*p).size() rather than the ill-formed (int)*p.size().
#define SZ(v) ((int)(v).size())
// Smallest / largest element of a non-empty container.
#define MIN(v) (*min_element(ALL(v)))
#define MAX(v) (*max_element(ALL(v)))
// Index of the first element >= x (LB) or > x (UB) in a sorted container.
#define LB(v,x) int(lower_bound(ALL(v),(x))-(v).begin())
#define UB(v,x) int(upper_bound(ALL(v),(x))-(v).begin())
using ll=long long int;
using ull=unsigned long long;
// Large sentinels that still leave headroom for adding two of them.
const int inf = 0x3fffffff;
const ll INF = 0x1fffffffffffffff;
// Overwrite a with b when b is strictly greater; report whether a changed.
template <typename T> inline bool chmax(T &a, T b) {
    if (!(a < b))
        return false;
    a = b;
    return true;
}
// Overwrite a with b when b is strictly smaller; report whether a changed.
template <typename T> inline bool chmin(T &a, T b) {
    if (!(b < a))
        return false;
    a = b;
    return true;
}
// Integer division x / y rounded toward +infinity (y != 0, either sign).
template <typename T, typename U> T ceil(T x, U y) {
    assert(y != 0);
    if (y < 0) {
        x = -x;
        y = -y;
    }
    if (x > 0)
        return (x + y - 1) / y;
    return x / y;
}
// Integer division x / y rounded toward -infinity (y != 0, either sign).
template <typename T, typename U> T floor(T x, U y) {
    assert(y != 0);
    if (y < 0) {
        x = -x;
        y = -y;
    }
    if (x > 0)
        return x / y;
    return (x - y + 1) / y;
}
// Number of set bits.
template <typename T> int popcnt(T x) {
    return static_cast<int>(__builtin_popcountll(x));
}
// Index of the highest set bit, or -1 for zero.
template <typename T> int topbit(T x) {
    if (x == 0)
        return -1;
    return 63 - __builtin_clzll(x);
}
// Index of the lowest set bit, or -1 for zero.
template <typename T> int lowbit(T x) {
    if (x == 0)
        return -1;
    return __builtin_ctzll(x);
}
#line 2 "library/Utility/fastio.hpp"
#include <unistd.h>
class FastIO {
static constexpr int L = 1 << 16;
char rdbuf[L];
int rdLeft = 0, rdRight = 0;
inline void reload() {
int len = rdRight - rdLeft;
memmove(rdbuf, rdbuf + rdLeft, len);
rdLeft = 0, rdRight = len;
rdRight += fread(rdbuf + len, 1, L - len, stdin);
}
inline bool skip() {
for (;;) {
while (rdLeft != rdRight and rdbuf[rdLeft] <= ' ')
rdLeft++;
if (rdLeft == rdRight) {
reload();
if (rdLeft == rdRight)
return false;
} else
break;
}
return true;
}
template <typename T, enable_if_t<is_integral<T>::value, int> = 0>
inline bool _read(T &x) {
if (!skip())
return false;
if (rdLeft + 20 >= rdRight)
reload();
bool neg = false;
if (rdbuf[rdLeft] == '-') {
neg = true;
rdLeft++;
}
x = 0;
while (rdbuf[rdLeft] >= '0' and rdLeft < rdRight) {
x = x * 10 +
(neg ? -(rdbuf[rdLeft++] ^ 48) : (rdbuf[rdLeft++] ^ 48));
}
return true;
}
inline bool _read(__int128_t &x) {
if (!skip())
return false;
if (rdLeft + 40 >= rdRight)
reload();
bool neg = false;
if (rdbuf[rdLeft] == '-') {
neg = true;
rdLeft++;
}
x = 0;
while (rdbuf[rdLeft] >= '0' and rdLeft < rdRight) {
x = x * 10 +
(neg ? -(rdbuf[rdLeft++] ^ 48) : (rdbuf[rdLeft++] ^ 48));
}
return true;
}
inline bool _read(__uint128_t &x) {
if (!skip())
return false;
if (rdLeft + 40 >= rdRight)
reload();
x = 0;
while (rdbuf[rdLeft] >= '0' and rdLeft < rdRight) {
x = x * 10 + (rdbuf[rdLeft++] ^ 48);
}
return true;
}
template <typename T, enable_if_t<is_floating_point<T>::value, int> = 0>
inline bool _read(T &x) {
if (!skip())
return false;
if (rdLeft + 20 >= rdRight)
reload();
bool neg = false;
if (rdbuf[rdLeft] == '-') {
neg = true;
rdLeft++;
}
x = 0;
while (rdbuf[rdLeft] >= '0' and rdbuf[rdLeft] <= '9' and
rdLeft < rdRight) {
x = x * 10 + (rdbuf[rdLeft++] ^ 48);
}
if (rdbuf[rdLeft] != '.')
return true;
rdLeft++;
T base = .1;
while (rdbuf[rdLeft] >= '0' and rdbuf[rdLeft] <= '9' and
rdLeft < rdRight) {
x += base * (rdbuf[rdLeft++] ^ 48);
base *= .1;
}
if (neg)
x = -x;
return true;
}
inline bool _read(char &x) {
if (!skip())
return false;
if (rdLeft + 1 >= rdRight)
reload();
x = rdbuf[rdLeft++];
return true;
}
inline bool _read(string &x) {
if (!skip())
return false;
for (;;) {
int pos = rdLeft;
while (pos < rdRight and rdbuf[pos] > ' ')
pos++;
x.append(rdbuf + rdLeft, pos - rdLeft);
if (rdLeft == pos)
break;
rdLeft = pos;
if (rdLeft == rdRight)
reload();
else
break;
}
return true;
}
template <typename T> inline bool _read(vector<T> &v) {
for (auto &x : v) {
if (!_read(x))
return false;
}
return true;
}
char wtbuf[L], tmp[50];
int wtRight = 0;
inline void _write(const char &x) {
if (wtRight > L - 32)
flush();
wtbuf[wtRight++] = x;
}
inline void _write(const string &x) {
for (auto &c : x)
_write(c);
}
template <typename T, enable_if_t<is_integral<T>::value, int> = 0>
inline void _write(T x) {
if (wtRight > L - 32)
flush();
if (x == 0) {
_write('0');
return;
} else if (x < 0) {
_write('-');
if (__builtin_expect(x == std::numeric_limits<T>::min(), 0)) {
switch (sizeof(x)) {
case 2:
_write("32768");
return;
case 4:
_write("2147483648");
return;
case 8:
_write("9223372036854775808");
return;
}
}
x = -x;
}
int pos = 0;
while (x != 0) {
tmp[pos++] = char((x % 10) | 48);
x /= 10;
}
rep(i, 0, pos) wtbuf[wtRight + i] = tmp[pos - 1 - i];
wtRight += pos;
}
inline void _write(__int128_t x) {
if (wtRight > L - 40)
flush();
if (x == 0) {
_write('0');
return;
} else if (x < 0) {
_write('-');
x = -x;
}
int pos = 0;
while (x != 0) {
tmp[pos++] = char((x % 10) | 48);
x /= 10;
}
rep(i, 0, pos) wtbuf[wtRight + i] = tmp[pos - 1 - i];
wtRight += pos;
}
inline void _write(__uint128_t x) {
if (wtRight > L - 40)
flush();
if (x == 0) {
_write('0');
return;
}
int pos = 0;
while (x != 0) {
tmp[pos++] = char((x % 10) | 48);
x /= 10;
}
rep(i, 0, pos) wtbuf[wtRight + i] = tmp[pos - 1 - i];
wtRight += pos;
}
inline void _write(double x) {
ostringstream oss;
oss << fixed << setprecision(15) << double(x);
string s = oss.str();
_write(s);
}
template <typename T> inline void _write(const vector<T> &v) {
rep(i, 0, v.size()) {
if (i)
_write(' ');
_write(v[i]);
}
}
public:
FastIO() {}
~FastIO() { flush(); }
inline void read() {}
template <typename Head, typename... Tail>
inline void read(Head &head, Tail &...tail) {
assert(_read(head));
read(tail...);
}
template <bool ln = true, bool space = false> inline void write() {
if (ln)
_write('\n');
}
template <bool ln = true, bool space = true, typename Head,
typename... Tail>
inline void write(const Head &head, const Tail &...tail) {
_write(head);
if (space)
_write(' ');
write<ln, true>(tail...);
}
inline void flush() {
fwrite(wtbuf, 1, wtRight, stdout);
wtRight = 0;
}
};
/**
* @brief Fast IO
*/
#line 3 "sol.cpp"
#line 2 "library/Convolution/ntt.hpp"
// Number-theoretic transform over the field T (modulus T::get_mod(), assumed
// prime). The constructor precomputes roots of unity of every power-of-two
// order up to 2^rank2, plus the per-stage twiddle ratios used by the radix-2 /
// radix-4 butterflies in ntt().
template <typename T> struct NTT {
// Exponent of the largest power of two dividing mod - 1.
// Transforms can have length at most 2^rank2.
static constexpr int rank2 = __builtin_ctzll(T::get_mod() - 1);
std::array<T, rank2 + 1> root; // root[i]^(2^i) == 1
std::array<T, rank2 + 1> iroot; // root[i] * iroot[i] == 1
// Twiddle ratios between consecutive radix-2 (rate2/irate2) and radix-4
// (rate3/irate3) blocks. They are indexed below by the number of trailing one
// bits of the block index s.
std::array<T, std::max(0, rank2 - 2 + 1)> rate2;
std::array<T, std::max(0, rank2 - 2 + 1)> irate2;
std::array<T, std::max(0, rank2 - 3 + 1)> rate3;
std::array<T, std::max(0, rank2 - 3 + 1)> irate3;
NTT() {
// Find g with g^((mod-1)/2) != 1, i.e. a quadratic non-residue for prime mod.
// Then g^((mod-1)/2^rank2) has order exactly 2^rank2.
T g = 2;
while (g.pow((T::get_mod() - 1) >> 1) == 1) {
g += 1;
}
root[rank2] = g.pow((T::get_mod() - 1) >> rank2);
iroot[rank2] = root[rank2].inv();
// Lower-order roots by repeated squaring.
for (int i = rank2 - 1; i >= 0; i--) {
root[i] = root[i + 1] * root[i + 1];
iroot[i] = iroot[i + 1] * iroot[i + 1];
}
{
T prod = 1, iprod = 1;
for (int i = 0; i <= rank2 - 2; i++) {
rate2[i] = root[i + 2] * prod;
irate2[i] = iroot[i + 2] * iprod;
prod *= iroot[i + 2];
iprod *= root[i + 2];
}
}
{
T prod = 1, iprod = 1;
for (int i = 0; i <= rank2 - 3; i++) {
rate3[i] = root[i + 3] * prod;
irate3[i] = iroot[i + 3] * iprod;
prod *= iroot[i + 3];
iprod *= root[i + 3];
}
}
}
// In-place transform of a: forward when type == 0, otherwise inverse
// (including the final 1/n scaling).
// NOTE(review): a.size() is assumed to be a power of two no larger than
// 2^rank2; h is taken as its trailing-zero count and nothing checks this.
void ntt(std::vector<T> &a, bool type = 0) {
int n = int(a.size());
int h = __builtin_ctzll((unsigned int)n);
if (type) {
// Inverse: stages are undone from len = h down to 0, two at a time
// (radix-4), with one radix-2 stage when a single stage remains.
int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
while (len) {
if (len == 1) {
int p = 1 << (h - len);
T irot = 1;
for (int s = 0; s < (1 << (len - 1)); s++) {
int offset = s << (h - len + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p];
a[i + offset] = l + r;
// (l - r) * irot as a raw 64-bit product, reduced by T's constructor.
a[i + offset + p] =
(unsigned long long)(T::get_mod() + l.v - r.v) *
irot.v;
;
}
if (s + 1 != (1 << (len - 1)))
irot *= irate2[__builtin_ctzll(~(unsigned int)(s))];
}
len--;
} else {
// 4-base
// Inputs are taken as raw residues; each output is a 64-bit value
// reduced when assigned back to T.
int p = 1 << (h - len);
T irot = 1, iimag = iroot[2];
for (int s = 0; s < (1 << (len - 2)); s++) {
T irot2 = irot * irot;
T irot3 = irot2 * irot;
int offset = s << (h - len + 2);
for (int i = 0; i < p; i++) {
auto a0 = 1ULL * a[i + offset + 0 * p].v;
auto a1 = 1ULL * a[i + offset + 1 * p].v;
auto a2 = 1ULL * a[i + offset + 2 * p].v;
auto a3 = 1ULL * a[i + offset + 3 * p].v;
auto a2na3iimag =
1ULL * T((T::get_mod() + a2 - a3) * iimag.v).v;
a[i + offset] = a0 + a1 + a2 + a3;
a[i + offset + 1 * p] =
(a0 + (T::get_mod() - a1) + a2na3iimag) *
irot.v;
a[i + offset + 2 * p] =
(a0 + a1 + (T::get_mod() - a2) +
(T::get_mod() - a3)) *
irot2.v;
a[i + offset + 3 * p] =
(a0 + (T::get_mod() - a1) +
(T::get_mod() - a2na3iimag)) *
irot3.v;
}
if (s + 1 != (1 << (len - 2)))
irot *= irate3[__builtin_ctzll(~(unsigned int)(s))];
}
len -= 2;
}
}
// Scale by n^{-1} to complete the inverse transform.
T e = T(n).inv();
for (auto &x : a)
x *= e;
} else {
// Forward: stages from len = 0 up to h, radix-4 where two stages remain,
// radix-2 for a final single stage.
int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
while (len < h) {
if (h - len == 1) {
int p = 1 << (h - len - 1);
T rot = 1;
for (int s = 0; s < (1 << len); s++) {
int offset = s << (h - len);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p] * rot;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
}
if (s + 1 != (1 << len))
rot *= rate2[__builtin_ctzll(~(unsigned int)(s))];
}
len++;
} else {
// 4-base
// Twiddled inputs stay as unreduced 64-bit products (below mod^2).
// NOTE(review): the sums below reach roughly 3 * mod^2 and are
// reduced only when assigned to T; this relies on that fitting below
// 2^63, which holds for the 998244353 modulus used in this file.
int p = 1 << (h - len - 2);
T rot = 1, imag = root[2];
for (int s = 0; s < (1 << len); s++) {
T rot2 = rot * rot;
T rot3 = rot2 * rot;
int offset = s << (h - len);
for (int i = 0; i < p; i++) {
auto mod2 = 1ULL * T::get_mod() * T::get_mod();
auto a0 = 1ULL * a[i + offset].v;
auto a1 = 1ULL * a[i + offset + p].v * rot.v;
auto a2 = 1ULL * a[i + offset + 2 * p].v * rot2.v;
auto a3 = 1ULL * a[i + offset + 3 * p].v * rot3.v;
auto a1na3imag =
1ULL * T(a1 + mod2 - a3).v * imag.v;
auto na2 = mod2 - a2;
a[i + offset] = a0 + a2 + a1 + a3;
a[i + offset + 1 * p] =
a0 + a2 + (2 * mod2 - (a1 + a3));
a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
a[i + offset + 3 * p] =
a0 + na2 + (mod2 - a1na3imag);
}
if (s + 1 != (1 << len))
rot *= rate3[__builtin_ctzll(~(unsigned int)(s))];
}
len += 2;
}
}
}
}
// Convolution of a and b (length |a| + |b| - 1; empty if either is empty).
// Uses schoolbook multiplication when either side has at most 30 terms.
// Otherwise it runs an NTT of length the next power of two >= the result
// length, transforming only once when a == b.
vector<T> mult(const vector<T> &a, const vector<T> &b) {
if (a.empty() or b.empty())
return vector<T>();
int as = a.size(), bs = b.size();
int n = as + bs - 1;
if (as <= 30 or bs <= 30) {
if (as > 30)
return mult(b, a);
vector<T> res(n);
rep(i, 0, as) rep(j, 0, bs) res[i + j] += a[i] * b[j];
return res;
}
int m = 1;
while (m < n)
m <<= 1;
vector<T> res(m);
rep(i, 0, as) res[i] = a[i];
ntt(res);
if (a == b)
rep(i, 0, m) res[i] *= res[i];
else {
vector<T> c(m);
rep(i, 0, bs) c[i] = b[i];
ntt(c);
rep(i, 0, m) res[i] *= c[i];
}
ntt(res, 1);
res.resize(n);
return res;
}
};
/**
* @brief Number Theoretic Transform
*/
#line 2 "library/FPS/fps.hpp"
// Formal power series over T, stored low degree first: (*this)[i] is the
// coefficient of x^i. Products go through the NTT() hook declared at the
// bottom, whose definition is supplied per coefficient type elsewhere.
// Truncating operations (inv, log, exp, pow, shift) return as many
// coefficients as the input has; an empty input yields an empty result.
template <typename T> struct Poly : vector<T> {
    Poly(int n = 0) { this->assign(n, T()); }
    Poly(const initializer_list<T> f) : vector<T>::vector(f) {}
    Poly(const vector<T> &f) { this->assign(ALL(f)); }
    // Value at x by Horner's rule.
    T eval(const T &x) const {
        T res;
        for (int i = this->size() - 1; i >= 0; i--)
            res *= x, res += this->at(i);
        return res;
    }
    // Coefficients in reverse order.
    Poly rev() const {
        Poly res = *this;
        reverse(ALL(res));
        return res;
    }
    // Drop trailing zero coefficients.
    void shrink() {
        while (!this->empty() and this->back() == 0)
            this->pop_back();
    }
    // Divide by x^sz, discarding the sz lowest coefficients.
    Poly operator>>(int sz) const {
        if ((int)this->size() <= sz)
            return {};
        Poly ret(*this);
        ret.erase(ret.begin(), ret.begin() + sz);
        return ret;
    }
    // Multiply by x^sz.
    Poly operator<<(int sz) const {
        Poly ret(*this);
        ret.insert(ret.begin(), sz, T(0));
        return ret;
    }
    // Product a * b: schoolbook when either side has at most 30 terms,
    // otherwise NTT. It is const so that const members such as square()
    // can call it.
    Poly<T> mult(const Poly<T> &a, const Poly<T> &b) const {
        if (a.empty() or b.empty())
            return {};
        int as = a.size(), bs = b.size();
        int n = as + bs - 1;
        if (as <= 30 or bs <= 30) {
            if (as > 30)
                return mult(b, a);
            Poly<T> res(n);
            rep(i, 0, as) rep(j, 0, bs) res[i + j] += a[i] * b[j];
            return res;
        }
        int m = 1;
        while (m < n)
            m <<= 1;
        Poly<T> res(m);
        rep(i, 0, as) res[i] = a[i];
        NTT(res, 0);
        if (a == b)
            rep(i, 0, m) res[i] *= res[i];
        else {
            Poly<T> c(m);
            rep(i, 0, bs) c[i] = b[i];
            NTT(c, 0);
            rep(i, 0, m) res[i] *= c[i];
        }
        NTT(res, 1);
        res.resize(n);
        return res;
    }
    Poly square() const { return Poly(mult(*this, *this)); }
    Poly operator-() const { return Poly() - *this; }
    Poly operator+(const Poly &g) const { return Poly(*this) += g; }
    Poly operator+(const T &g) const { return Poly(*this) += g; }
    Poly operator-(const Poly &g) const { return Poly(*this) -= g; }
    Poly operator-(const T &g) const { return Poly(*this) -= g; }
    Poly operator*(const Poly &g) const { return Poly(*this) *= g; }
    Poly operator*(const T &g) const { return Poly(*this) *= g; }
    Poly operator/(const Poly &g) const { return Poly(*this) /= g; }
    Poly operator/(const T &g) const { return Poly(*this) /= g; }
    Poly operator%(const Poly &g) const { return Poly(*this) %= g; }
    // Quotient and (shrunk) remainder of polynomial division by g.
    pair<Poly, Poly> divmod(const Poly &g) const {
        Poly q = *this / g, r = *this - g * q;
        r.shrink();
        return {q, r};
    }
    Poly &operator+=(const Poly &g) {
        if (g.size() > this->size())
            this->resize(g.size());
        rep(i, 0, g.size()) { (*this)[i] += g[i]; }
        return *this;
    }
    Poly &operator+=(const T &g) {
        if (this->empty())
            this->push_back(0);
        (*this)[0] += g;
        return *this;
    }
    Poly &operator-=(const Poly &g) {
        if (g.size() > this->size())
            this->resize(g.size());
        rep(i, 0, g.size()) { (*this)[i] -= g[i]; }
        return *this;
    }
    Poly &operator-=(const T &g) {
        if (this->empty())
            this->push_back(0);
        (*this)[0] -= g;
        return *this;
    }
    Poly &operator*=(const Poly &g) {
        *this = mult(*this, g);
        return *this;
    }
    Poly &operator*=(const T &g) {
        rep(i, 0, this->size())(*this)[i] *= g;
        return *this;
    }
    // Polynomial quotient via reversed-series inversion. g's leading
    // coefficient is assumed to be nonzero (call shrink() on g first).
    Poly &operator/=(const Poly &g) {
        if (g.size() > this->size()) {
            this->clear();
            return *this;
        }
        Poly g2 = g;
        reverse(ALL(*this));
        reverse(ALL(g2));
        int n = this->size() - g2.size() + 1;
        this->resize(n);
        g2.resize(n);
        *this *= g2.inv();
        this->resize(n);
        reverse(ALL(*this));
        shrink();
        return *this;
    }
    Poly &operator/=(const T &g) {
        rep(i, 0, this->size())(*this)[i] /= g;
        return *this;
    }
    Poly &operator%=(const Poly &g) {
        *this -= *this / g * g;
        shrink();
        return *this;
    }
    // Formal derivative (empty for an empty series: no size_t underflow).
    Poly diff() const {
        if (this->empty())
            return Poly();
        Poly res(this->size() - 1);
        rep(i, 0, res.size()) res[i] = (*this)[i + 1] * (i + 1);
        return res;
    }
    // Formal integral with zero constant term.
    Poly inte() const {
        Poly res(this->size() + 1);
        for (int i = res.size() - 1; i; i--)
            res[i] = (*this)[i - 1] / i;
        return res;
    }
    // log of the series. The constant term must be 1.
    Poly log() const {
        if (this->empty())
            return Poly();
        assert(this->front() == 1);
        const int n = this->size();
        Poly res = diff() * inv();
        res = res.inte();
        res.resize(n);
        return res;
    }
    // Coefficients of p(x + c), computed through a convolution with c^i / i!.
    Poly shift(const int &c) const {
        const int n = this->size();
        if (n == 0)
            return Poly();
        Poly res = *this, g(n);
        g[0] = 1;
        rep(i, 1, n) g[i] = g[i - 1] * c / i;
        vector<T> fact(n, 1);
        rep(i, 0, n) {
            if (i)
                fact[i] = fact[i - 1] * i;
            res[i] *= fact[i];
        }
        res = res.rev();
        res *= g;
        res.resize(n);
        res = res.rev();
        rep(i, 0, n) res[i] /= fact[i];
        return res;
    }
    // Multiplicative inverse by Newton doubling. front() must be invertible.
    Poly inv() const {
        const int n = this->size();
        if (n == 0)
            return Poly();
        Poly res(1);
        res.front() = T(1) / this->front();
        for (int k = 1; k < n; k <<= 1) {
            Poly f(k * 2), g(k * 2);
            rep(i, 0, min(n, k * 2)) f[i] = (*this)[i];
            rep(i, 0, k) g[i] = res[i];
            NTT(f, 0);
            NTT(g, 0);
            rep(i, 0, k * 2) f[i] *= g[i];
            NTT(f, 1);
            rep(i, 0, k) {
                f[i] = 0;
                f[i + k] = -f[i + k];
            }
            NTT(f, 0);
            rep(i, 0, k * 2) f[i] *= g[i];
            NTT(f, 1);
            rep(i, 0, k) f[i] = res[i];
            swap(res, f);
        }
        res.resize(n);
        return res;
    }
    // exp of the series.
    // NOTE(review): the constant term is presumably expected to be 0;
    // this is not checked here.
    Poly exp() const {
        const int n = this->size();
        if (n == 0)
            return Poly();
        if (n == 1)
            return Poly({T(1)});
        Poly b(2), c(1), z1, z2(2);
        b[0] = c[0] = z2[0] = z2[1] = 1;
        b[1] = (*this)[1];
        for (int k = 2; k < n; k <<= 1) {
            Poly y = b;
            y.resize(k * 2);
            NTT(y, 0);
            z1 = z2;
            Poly z(k);
            rep(i, 0, k) z[i] = y[i] * z1[i];
            NTT(z, 1);
            rep(i, 0, k >> 1) z[i] = 0;
            NTT(z, 0);
            rep(i, 0, k) z[i] *= -z1[i];
            NTT(z, 1);
            c.insert(c.end(), z.begin() + (k >> 1), z.end());
            z2 = c;
            z2.resize(k * 2);
            NTT(z2, 0);
            Poly x = *this;
            x.resize(k);
            x = x.diff();
            x.resize(k);
            NTT(x, 0);
            rep(i, 0, k) x[i] *= y[i];
            NTT(x, 1);
            Poly bb = b.diff();
            rep(i, 0, k - 1) x[i] -= bb[i];
            x.resize(k * 2);
            rep(i, 0, k - 1) {
                x[k + i] = x[i];
                x[i] = 0;
            }
            NTT(x, 0);
            rep(i, 0, k * 2) x[i] *= z2[i];
            NTT(x, 1);
            x.pop_back();
            x = x.inte();
            rep(i, k, min(n, k * 2)) x[i] += (*this)[i];
            rep(i, 0, k) x[i] = 0;
            NTT(x, 0);
            rep(i, 0, k * 2) x[i] *= y[i];
            NTT(x, 1);
            b.insert(b.end(), x.begin() + k, x.end());
        }
        b.resize(n);
        return b;
    }
    // t-th power truncated to size(). Leading zero coefficients are factored
    // out first, then log/exp is applied to the normalized remainder.
    Poly pow(ll t) const {
        if (t == 0) {
            Poly res(this->size());
            if (!res.empty())
                res[0] = 1;
            return res;
        }
        int n = this->size(), k = 0;
        while (k < n and (*this)[k] == 0)
            k++;
        Poly res(n);
        if (__int128_t(t) * k >= n)
            return res;
        n -= t * k;
        Poly g(n);
        T c = (*this)[k], ic = c.inv();
        rep(i, 0, n) g[i] = (*this)[i + k] * ic;
        g = g.log();
        for (auto &x : g)
            x *= t;
        g = g.exp();
        c = c.pow(t);
        rep(i, 0, n) res[i + t * k] = g[i] * c;
        return res;
    }
    // In-place transform hook (inv selects the inverse), defined per T elsewhere.
    void NTT(vector<T> &a, bool inv) const;
};
/**
* @brief Formal Power Series (NTT-friendly mod)
*/
#line 2 "library/Math/modint.hpp"
// Residue modulo `mod`, always stored canonically in v, 0 <= v < mod.
template <int mod = 1000000007> struct fp {
    int v;
    static constexpr int get_mod() { return mod; }
    // Inverse of v as a plain int in [0, mod), via the extended Euclidean
    // algorithm. Meaningful only when gcd(v, mod) == 1; v == 0 yields 0.
    int inv() const {
        int a = v, b = mod, x = 1, y = 0;
        while (b) {
            const int q = a / b;
            a -= q * b;
            std::swap(a, b);
            x -= q * y;
            std::swap(x, y);
        }
        if (x < 0)
            x += mod;
        return x;
    }
    // Any 64-bit value, negative ones included, reduced into [0, mod).
    // Taking % before adding mod avoids negating x, which overflowed for
    // the minimum long long.
    fp(long long x = 0) : v(int((x % mod + mod) % mod)) {}
    fp operator-() const { return fp() - *this; }
    // this^t by binary exponentiation (t >= 0). const, so it works on const values.
    fp pow(long long t) const {
        assert(t >= 0);
        fp res = 1, b = *this;
        while (t) {
            if (t & 1)
                res *= b;
            b *= b;
            t >>= 1;
        }
        return res;
    }
    fp &operator+=(const fp &x) {
        if ((v += x.v) >= mod)
            v -= mod;
        return *this;
    }
    fp &operator-=(const fp &x) {
        if ((v += mod - x.v) >= mod)
            v -= mod;
        return *this;
    }
    fp &operator*=(const fp &x) {
        v = int((long long)v * x.v % mod);
        return *this;
    }
    fp &operator/=(const fp &x) {
        v = int((long long)v * x.inv() % mod);
        return *this;
    }
    fp operator+(const fp &x) const { return fp(*this) += x; }
    fp operator-(const fp &x) const { return fp(*this) -= x; }
    fp operator*(const fp &x) const { return fp(*this) *= x; }
    fp operator/(const fp &x) const { return fp(*this) /= x; }
    bool operator==(const fp &x) const { return v == x.v; }
    bool operator!=(const fp &x) const { return v != x.v; }
    // Reads an integer and reduces it. Storing the raw value directly would
    // break the 0 <= v < mod invariant for negative or large inputs.
    friend std::istream &operator>>(std::istream &is, fp &x) {
        long long t;
        if (is >> t)
            x = fp(t);
        return is;
    }
    friend std::ostream &operator<<(std::ostream &os, const fp &x) {
        return os << x.v;
    }
};
// Inverse of n modulo T::get_mod() (n > 0), read from a table that grows on
// demand. Each new entry uses inv(k) = inv(k*q - mod) * q, where
// q = ceil(mod / k) and 0 <= k*q - mod < k.
template <typename T> T Inv(ll n) {
    static const int md = T::get_mod();
    static vector<T> table({0, 1});
    assert(n > 0);
    n %= md;
    for (int k = SZ(table); k <= n; k = SZ(table)) {
        const int q = (md + k - 1) / k;
        table.push_back(table[k * q - md] * q);
    }
    return table[n];
}
// n! modulo T::get_mod(), or its inverse when inv is set (0 <= n < mod).
// Both tables grow on demand and are kept the same length.
template <typename T> T Fact(ll n, bool inv = 0) {
    static const int md = T::get_mod();
    static vector<T> fact({1, 1}), ifact({1, 1});
    assert(n >= 0 and n < md);
    while (SZ(fact) <= n) {
        const int k = SZ(fact);
        fact.push_back(fact.back() * k);
        ifact.push_back(ifact.back() * Inv<T>(k));
    }
    if (inv)
        return ifact[n];
    return fact[n];
}
// Number of ordered selections n! / (n-r)!, or its inverse when inv is set.
// Returns 0 outside 0 <= r <= n.
template <typename T> T nPr(int n, int r, bool inv = 0) {
    const bool inRange = n >= 0 && r >= 0 && r <= n;
    if (!inRange)
        return 0;
    return Fact<T>(n, inv) * Fact<T>(n - r, !inv);
}
// Binomial coefficient n! / (r! (n-r)!), or its inverse when inv is set.
// Returns 0 outside 0 <= r <= n.
template <typename T> T nCr(int n, int r, bool inv = 0) {
    if (r < 0 || n < r || n < 0)
        return 0;
    const T top = Fact<T>(n, inv);
    const T left = Fact<T>(r, !inv);
    const T right = Fact<T>(n - r, !inv);
    return top * left * right;
}
// Multiset coefficient H(n, r) = C(n + r - 1, r): the number of ways to pick r
// items from n kinds with repetition, or its inverse when inv is set.
template <typename T> T nHr(int n, int r, bool inv = 0) {
    // Picking nothing from nothing is exactly one way; the general formula
    // would ask for C(-1, 0), which nCr reports as 0.
    if (n == 0 and r == 0)
        return T(1);
    return nCr<T>(n + r - 1, r, inv);
}
/**
* @brief Modint
*/
#line 7 "sol.cpp"
using Fp = fp<998244353>;
// Single transform engine for Fp. Its root tables are built by the NTT
// constructor when this global is initialized.
NTT<Fp> ntt;
// Poly<Fp>'s transform hook: delegate every forward/inverse transform to the
// shared engine above.
template <> void Poly<Fp>::NTT(vector<Fp> &a, bool inverse) const {
    ntt.ntt(a, inverse);
}
#line 2 "library/FPS/nthterm.hpp"
// Bostan–Mori: coefficient of x^n in the power series p(x) / q(x).
// Each round multiplies p and q by q(-x) and keeps the coefficients of matching
// parity, halving n. Requires q[0] != 0.
template<typename T>T nth(Poly<T> p,Poly<T> q,ll n){
    while (n) {
        // A zero (empty) numerator stays zero in every later round. Stop early:
        // the final p[0] below would otherwise index an empty vector.
        if (p.empty())
            return T(0);
        Poly<T> base(q), np, nq;
        for (int i = 1; i < (int)q.size(); i += 2)
            base[i] = -base[i];
        p *= base;
        q *= base;
        for (int i = n & 1; i < (int)p.size(); i += 2)
            np.emplace_back(p[i]);
        for (int i = 0; i < (int)q.size(); i += 2)
            nq.emplace_back(q[i]);
        swap(p, np);
        swap(q, nq);
        n >>= 1;
    }
    if (p.empty())
        return T(0);
    return p[0] / q[0];
}
/**
* @brief Bostan-Mori Algorithm
*/
#line 2 "library/Math/matrix.hpp"
// Dense h x w matrix over T. Supports element-wise +/-, products, fast power,
// Gauss-Jordan elimination and inversion.
template <class T> struct Matrix {
    int h = 0, w = 0;
    std::vector<std::vector<T>> val;
    // Written by gauss() as the signed product of the pivots (0 as soon as a
    // column has no pivot); inv() copies it from its augmented matrix.
    T det{};
    Matrix() {}
    Matrix(int n)
        : h(n), w(n), val(std::vector<std::vector<T>>(n, std::vector<T>(n))) {}
    Matrix(int n, int m)
        : h(n), w(m), val(std::vector<std::vector<T>>(n, std::vector<T>(m))) {}
    std::vector<T> &operator[](const int i) { return val[i]; }
    const std::vector<T> &operator[](const int i) const { return val[i]; }
    Matrix &operator+=(const Matrix &m) {
        assert(h == m.h and w == m.w);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                val[i][j] += m.val[i][j];
        return *this;
    }
    Matrix &operator-=(const Matrix &m) {
        assert(h == m.h and w == m.w);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                val[i][j] -= m.val[i][j];
        return *this;
    }
    Matrix &operator*=(const Matrix &m) {
        assert(w == m.h);
        Matrix<T> res(h, m.w);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < m.w; j++)
                for (int k = 0; k < w; k++)
                    res.val[i][j] += val[i][k] * m.val[k][j];
        *this = std::move(res);
        return *this;
    }
    Matrix operator+(const Matrix &m) const { return Matrix(*this) += m; }
    Matrix operator-(const Matrix &m) const { return Matrix(*this) -= m; }
    Matrix operator*(const Matrix &m) const { return Matrix(*this) *= m; }
    // this^k (square matrix) by binary exponentiation.
    Matrix pow(long long k) const {
        Matrix<T> res(h, h), c = *this;
        for (int i = 0; i < h; i++)
            res.val[i][i] = 1;
        while (k) {
            if (k & 1)
                res *= c;
            c *= c;
            k >>= 1;
        }
        return res;
    }
    // Gauss-Jordan elimination over the first c columns (default: all).
    // Returns the pivot columns in order and updates det.
    std::vector<int> gauss(int c = -1) {
        if (val.empty())
            return {};
        if (c == -1)
            c = w;
        int cur = 0;
        std::vector<int> res;
        det = 1;
        for (int i = 0; i < c; i++) {
            if (cur == h)
                break;
            for (int j = cur; j < h; j++)
                if (val[j][i] != 0) {
                    std::swap(val[cur], val[j]);
                    if (cur != j)
                        det *= -1;
                    break;
                }
            det *= val[cur][i];
            if (val[cur][i] == 0)
                continue;
            for (int j = 0; j < h; j++)
                if (j != cur) {
                    T z = val[j][i] / val[cur][i];
                    for (int k = i; k < w; k++)
                        val[j][k] -= val[cur][k] * z;
                }
            res.push_back(i);
            cur++;
        }
        return res;
    }
    // Inverse of a square matrix via elimination of [A | I]. det is updated.
    // When det ends up 0 the matrix is singular and the result is not a valid
    // inverse, so callers should check det.
    Matrix inv() {
        assert(h == w);
        Matrix base(h, h * 2), res(h, h);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < h; j++)
                base[i][j] = val[i][j];
        for (int i = 0; i < h; i++)
            base[i][h + i] = 1;
        base.gauss(h);
        det = base.det;
        for (int i = 0; i < h; i++)
            for (int j = 0; j < h; j++)
                res[i][j] = base[i][h + j] / base[i][i];
        return res;
    }
    bool operator==(const Matrix &m) const {
        assert(h == m.h and w == m.w);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                if (val[i][j] != m.val[i][j])
                    return false;
        return true;
    }
    // True as soon as any entry differs. The previous version returned false
    // whenever any single entry was equal.
    bool operator!=(const Matrix &m) const { return !(*this == m); }
    friend std::istream &operator>>(std::istream &is, Matrix &m) {
        for (int i = 0; i < m.h; i++)
            for (int j = 0; j < m.w; j++)
                is >> m[i][j];
        return is;
    }
    // Rows separated by '\n'. Every entry is followed by a space except the
    // last one in each non-final row.
    friend std::ostream &operator<<(std::ostream &os, const Matrix &m) {
        for (int i = 0; i < m.h; i++) {
            for (int j = 0; j < m.w; j++)
                os << m[i][j] << (j == m.w - 1 and i != m.h - 1 ? '\n' : ' ');
        }
        return os;
    }
};
/**
* @brief Matrix
*/
#line 15 "sol.cpp"
FastIO io;
int main() {
int n, m, S, T;
io.read(n, m, S, T);
S--;
T--;
vector g(n, vector<int>());
rep(_, 0, n - 1) {
int u, v;
io.read(u, v);
u--;
v--;
g[u].push_back(v);
g[v].push_back(u);
}
vector<int> sz(n, 1);
auto dfs1 = [&](auto &dfs1, int v, int p) -> void {
int mx = -1;
rep(i, 0, SZ(g[v])) {
int to = g[v][i];
if (to == p)
continue;
dfs1(dfs1, to, v);
sz[v] += sz[to];
if (chmax(mx, sz[to]))
swap(g[v][i], g[v][0]);
}
};
dfs1(dfs1, T, -1);
using P = pair<Poly<Fp>, Poly<Fp>>;
using Mat = Matrix<Poly<Fp>>;
function<vector<P>(int, int)> rake;
function<P(int, int)> compress;
rake = [&](int v, int p) -> vector<P> {
if (SZ(g[v]) == 1 and g[v][0] == p) {
return {P{Poly<Fp>({0}), Poly<Fp>({1})}};
}
auto ret = rake(g[v][0], v);
deque<P> deq;
deq.push_back(P{Poly<Fp>({0}), Poly<Fp>({1})});
rep(i, 1, SZ(g[v])) if (g[v][i] != p) {
deq.push_back(compress(g[v][i], v));
}
while (deq.size() > 1) {
auto [A1, A2] = deq.front();
deq.pop_front();
auto [B1, B2] = deq.front();
deq.pop_front();
deq.push_back(P{A1 * B2 + A2 * B1, A2 * B2});
}
ret.push_back(deq.front());
return ret;
};
compress = [&](int v, int p) -> P {
auto fs = rake(v, p);
auto rec = [&](auto &rec, int L, int R) -> Mat {
if (R - L == 1) {
auto [f, g] = fs[L];
Mat ret(2);
ret[0][1] = g;
ret[1][0] = -(g << 2);
ret[1][1] = g - (g << 1) - (f << 2);
return ret;
}
int mid = (L + R) >> 1;
return rec(rec, L, mid) * rec(rec, mid, R);
};
auto A = rec(rec, 0, SZ(fs));
A[0][1].shrink();
A[1][1].shrink();
return P{A[0][1], A[1][1]};
};
vector<P> fs;
auto dfs2 = [&](auto &dfs2, int v, int p) -> bool {
deque<P> deq;
deq.push_back(P{Poly<Fp>({0}), Poly<Fp>({1})});
bool onedge = 0;
rep(i, 0, SZ(g[v])) if (g[v][i] != p) {
if (dfs2(dfs2, g[v][i], v)) {
onedge = 1;
} else {
deq.push_back(compress(g[v][i], v));
}
}
onedge |= (v == S);
if (onedge) {
while (deq.size() > 1) {
auto [A1, A2] = deq.front();
deq.pop_front();
auto [B1, B2] = deq.front();
deq.pop_front();
deq.push_back(P{A1 * B2 + A2 * B1, A2 * B2});
}
auto ret = deq.front();
// cerr << v << '\n';
// for (auto &v : ret.first)
// cerr << v.v << ' ';
// cerr << '\n';
// for (auto &v : ret.second)
// cerr << v.v << ' ';
// cerr << "\n\n";
fs.push_back(ret);
}
return onedge;
};
dfs2(dfs2, T, -1);
auto rec = [&](auto &rec, int L, int R) -> Mat {
if (R - L == 1) {
auto [f, g] = fs[L];
Mat ret(2);
ret[0][1] = g;
ret[1][0] = -(g << 2);
ret[1][1] = g - (g << 1) - (f << 2);
return ret;
}
int mid = (L + R) >> 1;
return rec(rec, L, mid) * rec(rec, mid, R);
};
auto A = rec(rec, 0, SZ(fs));
Poly<Fp> num, den;
den = A[1][1];
den.shrink();
deque<Poly<Fp>> deq;
for (auto &[f, g] : fs)
deq.push_back(g);
while (deq.size() > 1) {
auto A = deq.front();
deq.pop_front();
auto B = deq.front();
deq.pop_front();
deq.push_back(A * B);
}
num = deq.front();
num.shrink();
for (auto &v : num)
cerr << v.v << ' ';
cerr << '\n';
for (auto &v : den)
cerr << v.v << ' ';
cerr << "\n\n";
if (m < (SZ(fs) - 1)) {
io.write(0);
return 0;
}
Fp ret = nth(num, den, m - (SZ(fs) - 1));
io.write(ret.v);
return 0;
}
tko919