結果

問題 No.2587 Random Walk on Tree
ユーザー tko919tko919
提出日時 2023-12-25 05:43:39
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 4,170 ms / 10,000 ms
コード長 34,997 bytes
コンパイル時間 5,130 ms
コンパイル使用メモリ 270,012 KB
実行使用メモリ 100,096 KB
最終ジャッジ日時 2024-09-27 14:11:27
合計ジャッジ時間 67,184 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 6,812 KB
testcase_01 AC 1 ms 6,816 KB
testcase_02 AC 2 ms 6,944 KB
testcase_03 AC 2 ms 6,940 KB
testcase_04 AC 3 ms 6,944 KB
testcase_05 AC 4 ms 6,940 KB
testcase_06 AC 4 ms 6,940 KB
testcase_07 AC 2 ms 6,940 KB
testcase_08 AC 4 ms 6,940 KB
testcase_09 AC 2 ms 6,940 KB
testcase_10 AC 3 ms 6,944 KB
testcase_11 AC 7 ms 6,944 KB
testcase_12 AC 32 ms 6,940 KB
testcase_13 AC 51 ms 6,940 KB
testcase_14 AC 14 ms 6,944 KB
testcase_15 AC 2,521 ms 18,120 KB
testcase_16 AC 1,431 ms 12,540 KB
testcase_17 AC 1,531 ms 13,152 KB
testcase_18 AC 268 ms 6,940 KB
testcase_19 AC 3,574 ms 61,312 KB
testcase_20 AC 2,659 ms 18,612 KB
testcase_21 AC 3,339 ms 39,200 KB
testcase_22 AC 4,170 ms 100,096 KB
testcase_23 AC 3,393 ms 49,536 KB
testcase_24 AC 2,704 ms 19,856 KB
testcase_25 AC 1,607 ms 20,920 KB
testcase_26 AC 4,091 ms 26,284 KB
testcase_27 AC 3,914 ms 33,384 KB
testcase_28 AC 3,368 ms 20,224 KB
testcase_29 AC 3,376 ms 19,892 KB
testcase_30 AC 2,774 ms 19,188 KB
testcase_31 AC 2,857 ms 18,860 KB
testcase_32 AC 3,367 ms 18,412 KB
testcase_33 AC 2 ms 6,944 KB
testcase_34 AC 1,173 ms 47,200 KB
testcase_35 AC 1,140 ms 47,152 KB
testcase_36 AC 1,128 ms 46,960 KB
testcase_37 AC 1,761 ms 47,964 KB
testcase_38 AC 2,624 ms 18,076 KB
testcase_39 AC 2,682 ms 18,548 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#line 1 "library/Template/template.hpp"
#include <bits/stdc++.h>
using namespace std;

// Loop i over the half-open integer range [a, b).
#define rep(i,a,b) for(int i=(int)(a);i<(int)(b);i++)
// The begin/end iterator pair of a container, for passing to algorithms.
#define ALL(v) (v).begin(),(v).end()
// Sort v and erase duplicate elements in place.
#define UNIQUE(v) sort(ALL(v)),(v).erase(unique(ALL(v)),(v).end())
// Container size as int. Both the argument and the whole expansion are
// parenthesized so that e.g. SZ(*p) and -SZ(v) parse as intended.
#define SZ(v) ((int)(v).size())
// Smallest / largest element of a non-empty container (by value).
#define MIN(v) *min_element(ALL(v))
#define MAX(v) *max_element(ALL(v))
// Index of the first element >= x (LB) / > x (UB) in a sorted container.
#define LB(v,x) int(lower_bound(ALL(v),(x))-(v).begin())
#define UB(v,x) int(upper_bound(ALL(v),(x))-(v).begin())

using ll=long long int;
using ull=unsigned long long;
// Large sentinels; presumably chosen so that inf + inf / INF + INF do not
// overflow their types (0x3fffffff * 2 < INT_MAX, 0x1fff... * 2 < LLONG_MAX).
const int inf = 0x3fffffff;
const ll INF = 0x1fffffffffffffff;

// Raise a to b when b is larger; reports whether a changed.
template <typename T> inline bool chmax(T &a, T b) {
    if (!(a < b))
        return false;
    a = b;
    return true;
}
// Lower a to b when b is smaller; reports whether a changed.
template <typename T> inline bool chmin(T &a, T b) {
    if (!(b < a))
        return false;
    a = b;
    return true;
}
// Integer ceiling of x / y for any sign combination (y must be non-zero).
template <typename T, typename U> T ceil(T x, U y) {
    assert(y != 0);
    if (y < 0) {
        x = -x;
        y = -y;
    }
    if (x > 0)
        return (x + y - 1) / y;
    return x / y;
}
// Integer floor of x / y for any sign combination (y must be non-zero).
template <typename T, typename U> T floor(T x, U y) {
    assert(y != 0);
    if (y < 0) {
        x = -x;
        y = -y;
    }
    if (x > 0)
        return x / y;
    return (x - y + 1) / y;
}
// Number of set bits of x after conversion to unsigned long long.
template <typename T> int popcnt(T x) {
    const unsigned long long bits = x;
    return __builtin_popcountll(bits);
}
// Index of the highest set bit of x (as a 64-bit value), or -1 when x == 0.
template <typename T> int topbit(T x) {
    if (x == 0)
        return -1;
    return 63 - __builtin_clzll(x);
}
// Index of the lowest set bit of x, or -1 when x == 0.
template <typename T> int lowbit(T x) {
    if (x == 0)
        return -1;
    return __builtin_ctzll(x);
}
#line 2 "library/Utility/fastio.hpp"
#include <unistd.h>

// Buffered reader/writer over stdin/stdout using fread/fwrite.
// Input tokens are separated by bytes <= ' '. Output is buffered in wtbuf and
// written on flush() and in the destructor.
class FastIO {
    static constexpr int L = 1 << 16;
    char rdbuf[L];
    int rdLeft = 0, rdRight = 0;

    // Move the unread bytes [rdLeft, rdRight) to the front of rdbuf and
    // append as much fresh input from stdin as fits.
    inline void reload() {
        int len = rdRight - rdLeft;
        std::memmove(rdbuf, rdbuf + rdLeft, len);
        rdLeft = 0, rdRight = len;
        rdRight += std::fread(rdbuf + len, 1, L - len, stdin);
    }
    // True iff an unread byte exists and it is a decimal digit. The bounds
    // test comes first so rdbuf is never read at index rdRight (== L possible).
    inline bool atDigit() const {
        return rdLeft < rdRight and rdbuf[rdLeft] >= '0' and
               rdbuf[rdLeft] <= '9';
    }
    // Skip bytes <= ' ', refilling the buffer as needed.
    // Returns false when input is exhausted.
    inline bool skip() {
        for (;;) {
            while (rdLeft != rdRight and rdbuf[rdLeft] <= ' ')
                rdLeft++;
            if (rdLeft == rdRight) {
                reload();
                if (rdLeft == rdRight)
                    return false;
            } else
                break;
        }
        return true;
    }
    // Optional leading '-', then decimal digits. Negative values are
    // accumulated negatively so the type's minimum value parses correctly.
    template <typename T,
              std::enable_if_t<std::is_integral<T>::value, int> = 0>
    inline bool _read(T &x) {
        if (!skip())
            return false;
        if (rdLeft + 20 >= rdRight)
            reload();
        bool neg = false;
        if (rdbuf[rdLeft] == '-') {
            neg = true;
            rdLeft++;
        }
        x = 0;
        while (atDigit()) {
            x = x * 10 +
                (neg ? -(rdbuf[rdLeft++] ^ 48) : (rdbuf[rdLeft++] ^ 48));
        }
        return true;
    }
    inline bool _read(__int128_t &x) {
        if (!skip())
            return false;
        if (rdLeft + 40 >= rdRight)
            reload();
        bool neg = false;
        if (rdbuf[rdLeft] == '-') {
            neg = true;
            rdLeft++;
        }
        x = 0;
        while (atDigit()) {
            x = x * 10 +
                (neg ? -(rdbuf[rdLeft++] ^ 48) : (rdbuf[rdLeft++] ^ 48));
        }
        return true;
    }
    inline bool _read(__uint128_t &x) {
        if (!skip())
            return false;
        if (rdLeft + 40 >= rdRight)
            reload();
        x = 0;
        while (atDigit()) {
            x = x * 10 + (rdbuf[rdLeft++] ^ 48);
        }
        return true;
    }
    // Optional '-', integer digits, optional '.' and fraction digits.
    // Exponent notation is not supported.
    template <typename T,
              std::enable_if_t<std::is_floating_point<T>::value, int> = 0>
    inline bool _read(T &x) {
        if (!skip())
            return false;
        if (rdLeft + 20 >= rdRight)
            reload();
        bool neg = false;
        if (rdbuf[rdLeft] == '-') {
            neg = true;
            rdLeft++;
        }
        x = 0;
        while (atDigit()) {
            x = x * 10 + (rdbuf[rdLeft++] ^ 48);
        }
        if (rdLeft < rdRight and rdbuf[rdLeft] == '.') {
            rdLeft++;
            T base = .1;
            while (atDigit()) {
                x += base * (rdbuf[rdLeft++] ^ 48);
                base *= .1;
            }
        }
        // Apply the sign in every case (a value such as "-5" has no '.').
        if (neg)
            x = -x;
        return true;
    }
    // Next non-whitespace byte.
    inline bool _read(char &x) {
        if (!skip())
            return false;
        if (rdLeft + 1 >= rdRight)
            reload();
        x = rdbuf[rdLeft++];
        return true;
    }
    // Appends the next token (maximal run of bytes > ' ') to x.
    inline bool _read(std::string &x) {
        if (!skip())
            return false;
        for (;;) {
            int pos = rdLeft;
            while (pos < rdRight and rdbuf[pos] > ' ')
                pos++;
            x.append(rdbuf + rdLeft, pos - rdLeft);
            if (rdLeft == pos)
                break;
            rdLeft = pos;
            // Token may continue past the buffer end: refill and keep going.
            if (rdLeft == rdRight)
                reload();
            else
                break;
        }
        return true;
    }
    // Reads every element of a pre-sized vector.
    template <typename T> inline bool _read(std::vector<T> &v) {
        for (auto &x : v) {
            if (!_read(x))
                return false;
        }
        return true;
    }

    char wtbuf[L], tmp[50];
    int wtRight = 0;
    inline void _write(const char &x) {
        if (wtRight > L - 32)
            flush();
        wtbuf[wtRight++] = x;
    }
    inline void _write(const std::string &x) {
        for (auto &c : x)
            _write(c);
    }
    template <typename T,
              std::enable_if_t<std::is_integral<T>::value, int> = 0>
    inline void _write(T x) {
        if (wtRight > L - 32)
            flush();
        if (x == 0) {
            _write('0');
            return;
        } else if (x < 0) {
            _write('-');
            // -min() overflows, so the magnitude is emitted literally.
            if (__builtin_expect(x == std::numeric_limits<T>::min(), 0)) {
                switch (sizeof(x)) {
                case 1:
                    _write("128");
                    return;
                case 2:
                    _write("32768");
                    return;
                case 4:
                    _write("2147483648");
                    return;
                case 8:
                    _write("9223372036854775808");
                    return;
                }
            }
            x = -x;
        }
        int pos = 0;
        while (x != 0) {
            tmp[pos++] = char((x % 10) | 48);
            x /= 10;
        }
        for (int i = 0; i < pos; i++)
            wtbuf[wtRight + i] = tmp[pos - 1 - i];
        wtRight += pos;
    }
    inline void _write(__int128_t x) {
        if (wtRight > L - 40)
            flush();
        if (x == 0) {
            _write('0');
            return;
        } else if (x < 0) {
            _write('-');
            x = -x;
        }
        int pos = 0;
        while (x != 0) {
            tmp[pos++] = char((x % 10) | 48);
            x /= 10;
        }
        for (int i = 0; i < pos; i++)
            wtbuf[wtRight + i] = tmp[pos - 1 - i];
        wtRight += pos;
    }
    inline void _write(__uint128_t x) {
        if (wtRight > L - 40)
            flush();
        if (x == 0) {
            _write('0');
            return;
        }
        int pos = 0;
        while (x != 0) {
            tmp[pos++] = char((x % 10) | 48);
            x /= 10;
        }
        for (int i = 0; i < pos; i++)
            wtbuf[wtRight + i] = tmp[pos - 1 - i];
        wtRight += pos;
    }
    // Fixed notation with 15 digits after the decimal point.
    inline void _write(double x) {
        std::ostringstream oss;
        oss << std::fixed << std::setprecision(15) << double(x);
        std::string s = oss.str();
        _write(s);
    }
    // Space-separated elements.
    template <typename T> inline void _write(const std::vector<T> &v) {
        for (std::size_t i = 0; i < v.size(); i++) {
            if (i)
                _write(' ');
            _write(v[i]);
        }
    }

  public:
    FastIO() {}
    ~FastIO() { flush(); }
    inline void read() {}
    // Reads each argument in order; asserts that input was available.
    template <typename Head, typename... Tail>
    inline void read(Head &head, Tail &...tail) {
        // _read must run outside assert(): under NDEBUG the assert
        // expression is discarded and nothing would be read.
        [[maybe_unused]] const bool ok = _read(head);
        assert(ok);
        read(tail...);
    }
    template <bool ln = true, bool space = false> inline void write() {
        if (ln)
            _write('\n');
    }
    // Writes each argument; with the default `space` every item is followed
    // by ' ' (including the last), then '\n' when `ln` is true.
    template <bool ln = true, bool space = true, typename Head,
              typename... Tail>
    inline void write(const Head &head, const Tail &...tail) {
        _write(head);
        if (space)
            _write(' ');
        write<ln, true>(tail...);
    }
    inline void flush() {
        std::fwrite(wtbuf, 1, wtRight, stdout);
        wtRight = 0;
    }
};

/**
 * @brief Fast IO
 */
#line 3 "sol.cpp"

#line 2 "library/Convolution/ntt.hpp"

// Number Theoretic Transform over the modint type T.
// T must provide a static get_mod(), pow(), inv() and a public residue `.v`.
// The transform uses radix-4 passes ("4-base"), with a single radix-2 pass
// when the remaining number of levels is odd.
template <typename T> struct NTT {
    // Largest k with 2^k | (mod - 1). NOTE(review): transforms longer than
    // 2^rank2 cannot be supported by these root tables; nothing checks this.
    static constexpr int rank2 = __builtin_ctzll(T::get_mod() - 1);
    std::array<T, rank2 + 1> root;  // root[i]^(2^i) == 1
    std::array<T, rank2 + 1> iroot; // root[i] * iroot[i] == 1

    // Per-block twiddle increments for the radix-2 passes (and inverses).
    std::array<T, std::max(0, rank2 - 2 + 1)> rate2;
    std::array<T, std::max(0, rank2 - 2 + 1)> irate2;

    // Per-block twiddle increments for the radix-4 passes (and inverses).
    std::array<T, std::max(0, rank2 - 3 + 1)> rate3;
    std::array<T, std::max(0, rank2 - 3 + 1)> irate3;

    NTT() {
        // Search for g >= 2 with g^((mod-1)/2) != 1 (a quadratic non-residue
        // when mod is prime), then derive the 2^rank2-th root from it.
        T g = 2;
        while (g.pow((T::get_mod() - 1) >> 1) == 1) {
            g += 1;
        }
        root[rank2] = g.pow((T::get_mod() - 1) >> rank2);
        iroot[rank2] = root[rank2].inv();
        // Square down to obtain the 2^i-th roots for every smaller i.
        for (int i = rank2 - 1; i >= 0; i--) {
            root[i] = root[i + 1] * root[i + 1];
            iroot[i] = iroot[i + 1] * iroot[i + 1];
        }

        {
            T prod = 1, iprod = 1;
            for (int i = 0; i <= rank2 - 2; i++) {
                rate2[i] = root[i + 2] * prod;
                irate2[i] = iroot[i + 2] * iprod;
                prod *= iroot[i + 2];
                iprod *= root[i + 2];
            }
        }
        {
            T prod = 1, iprod = 1;
            for (int i = 0; i <= rank2 - 3; i++) {
                rate3[i] = root[i + 3] * prod;
                irate3[i] = iroot[i + 3] * iprod;
                prod *= iroot[i + 3];
                iprod *= root[i + 3];
            }
        }
    }

    // In-place transform of `a`. type == 0: forward; type == 1: inverse,
    // including the final multiplication by n^{-1}.
    // Assumes a.size() is a power of two (only its trailing-zero count h is
    // used). NOTE(review): no bit-reversal permutation is applied anywhere, so
    // the forward output order is presumably permuted; it is only meant to be
    // combined pointwise and fed back through ntt(.., 1) — confirm before
    // using the spectrum directly.
    void ntt(std::vector<T> &a, bool type = 0) {
        int n = int(a.size());
        int h = __builtin_ctzll((unsigned int)n);

        if (type) {
            int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
            while (len) {
                if (len == 1) {
                    int p = 1 << (h - len);
                    T irot = 1;
                    for (int s = 0; s < (1 << (len - 1)); s++) {
                        int offset = s << (h - len + 1);
                        for (int i = 0; i < p; i++) {
                            auto l = a[i + offset];
                            auto r = a[i + offset + p];
                            a[i + offset] = l + r;
                            // (mod + l - r) * irot stays below 2^63, so the
                            // product is reduced once by T's constructor.
                            a[i + offset + p] =
                                (unsigned long long)(T::get_mod() + l.v - r.v) *
                                irot.v;
                            ;
                        }
                        if (s + 1 != (1 << (len - 1)))
                            irot *= irate2[__builtin_ctzll(~(unsigned int)(s))];
                    }
                    len--;
                } else {
                    // 4-base
                    int p = 1 << (h - len);
                    T irot = 1, iimag = iroot[2];
                    for (int s = 0; s < (1 << (len - 2)); s++) {
                        T irot2 = irot * irot;
                        T irot3 = irot2 * irot;
                        int offset = s << (h - len + 2);
                        for (int i = 0; i < p; i++) {
                            auto a0 = 1ULL * a[i + offset + 0 * p].v;
                            auto a1 = 1ULL * a[i + offset + 1 * p].v;
                            auto a2 = 1ULL * a[i + offset + 2 * p].v;
                            auto a3 = 1ULL * a[i + offset + 3 * p].v;

                            auto a2na3iimag =
                                1ULL * T((T::get_mod() + a2 - a3) * iimag.v).v;

                            // Sums are kept unreduced (each term < mod) and
                            // reduced once when converted back to T.
                            a[i + offset] = a0 + a1 + a2 + a3;
                            a[i + offset + 1 * p] =
                                (a0 + (T::get_mod() - a1) + a2na3iimag) *
                                irot.v;
                            a[i + offset + 2 * p] =
                                (a0 + a1 + (T::get_mod() - a2) +
                                 (T::get_mod() - a3)) *
                                irot2.v;
                            a[i + offset + 3 * p] =
                                (a0 + (T::get_mod() - a1) +
                                 (T::get_mod() - a2na3iimag)) *
                                irot3.v;
                        }
                        if (s + 1 != (1 << (len - 2)))
                            irot *= irate3[__builtin_ctzll(~(unsigned int)(s))];
                    }
                    len -= 2;
                }
            }
            // Scale by n^{-1} to complete the inverse transform.
            T e = T(n).inv();
            for (auto &x : a)
                x *= e;
        } else {
            int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
            while (len < h) {
                if (h - len == 1) {
                    int p = 1 << (h - len - 1);
                    T rot = 1;
                    for (int s = 0; s < (1 << len); s++) {
                        int offset = s << (h - len);
                        for (int i = 0; i < p; i++) {
                            auto l = a[i + offset];
                            auto r = a[i + offset + p] * rot;
                            a[i + offset] = l + r;
                            a[i + offset + p] = l - r;
                        }
                        if (s + 1 != (1 << len))
                            rot *= rate2[__builtin_ctzll(~(unsigned int)(s))];
                    }
                    len++;
                } else {
                    // 4-base
                    int p = 1 << (h - len - 2);
                    T rot = 1, imag = root[2];
                    for (int s = 0; s < (1 << len); s++) {
                        T rot2 = rot * rot;
                        T rot3 = rot2 * rot;
                        int offset = s << (h - len);
                        for (int i = 0; i < p; i++) {
                            // Products a1..a3 are left unreduced (< mod^2);
                            // adding multiples of mod2 keeps every
                            // difference non-negative before the final
                            // reduction in T's constructor.
                            auto mod2 = 1ULL * T::get_mod() * T::get_mod();
                            auto a0 = 1ULL * a[i + offset].v;
                            auto a1 = 1ULL * a[i + offset + p].v * rot.v;
                            auto a2 = 1ULL * a[i + offset + 2 * p].v * rot2.v;
                            auto a3 = 1ULL * a[i + offset + 3 * p].v * rot3.v;
                            auto a1na3imag =
                                1ULL * T(a1 + mod2 - a3).v * imag.v;
                            auto na2 = mod2 - a2;
                            a[i + offset] = a0 + a2 + a1 + a3;
                            a[i + offset + 1 * p] =
                                a0 + a2 + (2 * mod2 - (a1 + a3));
                            a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
                            a[i + offset + 3 * p] =
                                a0 + na2 + (mod2 - a1na3imag);
                        }
                        if (s + 1 != (1 << len))
                            rot *= rate3[__builtin_ctzll(~(unsigned int)(s))];
                    }
                    len += 2;
                }
            }
        }
    }
    // Polynomial product of a and b (size a.size() + b.size() - 1; empty if
    // either input is empty). Schoolbook multiplication when either side has
    // at most 30 coefficients; otherwise zero-pad to the next power of two,
    // transform, multiply pointwise (squaring when a == b), and invert.
    vector<T> mult(const vector<T> &a, const vector<T> &b) {
        if (a.empty() or b.empty())
            return vector<T>();
        int as = a.size(), bs = b.size();
        int n = as + bs - 1;
        if (as <= 30 or bs <= 30) {
            if (as > 30)
                return mult(b, a);
            vector<T> res(n);
            rep(i, 0, as) rep(j, 0, bs) res[i + j] += a[i] * b[j];
            return res;
        }
        int m = 1;
        while (m < n)
            m <<= 1;
        vector<T> res(m);
        rep(i, 0, as) res[i] = a[i];
        ntt(res);
        if (a == b)
            rep(i, 0, m) res[i] *= res[i];
        else {
            vector<T> c(m);
            rep(i, 0, bs) c[i] = b[i];
            ntt(c);
            rep(i, 0, m) res[i] *= c[i];
        }
        ntt(res, 1);
        res.resize(n);
        return res;
    }
};

/**
 * @brief Number Theoretic Transform
 */
#line 2 "library/FPS/fps.hpp"

// Formal power series / polynomial over an NTT-friendly modint T.
// Coefficient i is the coefficient of x^i. Transforms go through the member
// NTT(), which is declared here and defined (specialized) per T elsewhere.
template <typename T> struct Poly : vector<T> {
    // n zero coefficients.
    Poly(int n = 0) { this->assign(n, T()); }
    Poly(const initializer_list<T> f) : vector<T>::vector(f) {}
    Poly(const vector<T> &f) { this->assign(ALL(f)); }
    // Horner evaluation at x.
    T eval(const T &x) {
        T res;
        for (int i = this->size() - 1; i >= 0; i--)
            res *= x, res += this->at(i);
        return res;
    }
    Poly rev() const {
        Poly res = *this;
        reverse(ALL(res));
        return res;
    }
    // Drop trailing zero coefficients.
    void shrink() {
        while (!this->empty() and this->back() == 0)
            this->pop_back();
    }
    // Divide by x^sz, discarding the low sz coefficients.
    Poly operator>>(int sz) const {
        if ((int)this->size() <= sz)
            return {};
        Poly ret(*this);
        ret.erase(ret.begin(), ret.begin() + sz);
        return ret;
    }
    // Multiply by x^sz.
    Poly operator<<(int sz) const {
        Poly ret(*this);
        ret.insert(ret.begin(), sz, T(0));
        return ret;
    }
    // Product a * b (size a.size() + b.size() - 1, empty if either is empty).
    // Schoolbook when either side has at most 30 coefficients, NTT otherwise.
    // const so that const members such as square() can call it.
    Poly<T> mult(const Poly<T> &a, const Poly<T> &b) const {
        if (a.empty() or b.empty())
            return {};
        int as = a.size(), bs = b.size();
        int n = as + bs - 1;
        if (as <= 30 or bs <= 30) {
            if (as > 30)
                return mult(b, a);
            Poly<T> res(n);
            rep(i, 0, as) rep(j, 0, bs) res[i + j] += a[i] * b[j];
            return res;
        }
        int m = 1;
        while (m < n)
            m <<= 1;
        Poly<T> res(m);
        rep(i, 0, as) res[i] = a[i];
        NTT(res, 0);
        if (a == b)
            rep(i, 0, m) res[i] *= res[i];
        else {
            Poly<T> c(m);
            rep(i, 0, bs) c[i] = b[i];
            NTT(c, 0);
            rep(i, 0, m) res[i] *= c[i];
        }
        NTT(res, 1);
        res.resize(n);
        return res;
    }
    Poly square() const { return Poly(mult(*this, *this)); }
    Poly operator-() const { return Poly() - *this; }
    Poly operator+(const Poly &g) const { return Poly(*this) += g; }
    Poly operator+(const T &g) const { return Poly(*this) += g; }
    Poly operator-(const Poly &g) const { return Poly(*this) -= g; }
    Poly operator-(const T &g) const { return Poly(*this) -= g; }
    Poly operator*(const Poly &g) const { return Poly(*this) *= g; }
    Poly operator*(const T &g) const { return Poly(*this) *= g; }
    Poly operator/(const Poly &g) const { return Poly(*this) /= g; }
    Poly operator/(const T &g) const { return Poly(*this) /= g; }
    Poly operator%(const Poly &g) const { return Poly(*this) %= g; }
    // Quotient and (shrunk) remainder of polynomial division by g.
    pair<Poly, Poly> divmod(const Poly &g) const {
        Poly q = *this / g, r = *this - g * q;
        r.shrink();
        return {q, r};
    }
    Poly &operator+=(const Poly &g) {
        if (g.size() > this->size())
            this->resize(g.size());
        rep(i, 0, g.size()) { (*this)[i] += g[i]; }
        return *this;
    }
    Poly &operator+=(const T &g) {
        if (this->empty())
            this->push_back(0);
        (*this)[0] += g;
        return *this;
    }
    Poly &operator-=(const Poly &g) {
        if (g.size() > this->size())
            this->resize(g.size());
        rep(i, 0, g.size()) { (*this)[i] -= g[i]; }
        return *this;
    }
    Poly &operator-=(const T &g) {
        if (this->empty())
            this->push_back(0);
        (*this)[0] -= g;
        return *this;
    }
    Poly &operator*=(const Poly &g) {
        *this = mult(*this, g);
        return *this;
    }
    Poly &operator*=(const T &g) {
        rep(i, 0, this->size())(*this)[i] *= g;
        return *this;
    }
    // Polynomial quotient via reversed power-series inversion.
    // Assumes g has no trailing zeros (its last entry is the leading term).
    Poly &operator/=(const Poly &g) {
        if (g.size() > this->size()) {
            this->clear();
            return *this;
        }
        Poly g2 = g;
        reverse(ALL(*this));
        reverse(ALL(g2));
        int n = this->size() - g2.size() + 1;
        this->resize(n);
        g2.resize(n);
        *this *= g2.inv();
        this->resize(n);
        reverse(ALL(*this));
        shrink();
        return *this;
    }
    Poly &operator/=(const T &g) {
        rep(i, 0, this->size())(*this)[i] /= g;
        return *this;
    }
    Poly &operator%=(const Poly &g) {
        *this -= *this / g * g;
        shrink();
        return *this;
    }
    // Formal derivative (size n - 1; empty for an empty or constant poly).
    Poly diff() const {
        if (this->empty())
            return Poly();
        Poly res(this->size() - 1);
        rep(i, 0, res.size()) res[i] = (*this)[i + 1] * (i + 1);
        return res;
    }
    // Formal integral with zero constant term (size n + 1).
    Poly inte() const {
        Poly res(this->size() + 1);
        for (int i = res.size() - 1; i; i--)
            res[i] = (*this)[i - 1] / i;
        return res;
    }
    // log f mod x^n; requires f[0] == 1.
    Poly log() const {
        assert(this->front() == 1);
        const int n = this->size();
        Poly res = diff() * inv();
        res = res.inte();
        res.resize(n);
        return res;
    }
    // Coefficients of f(x + c) truncated to the same size.
    Poly shift(const int &c) const {
        const int n = this->size();
        Poly res = *this, g(n);
        g[0] = 1;
        rep(i, 1, n) g[i] = g[i - 1] * c / i;
        vector<T> fact(n, 1);
        rep(i, 0, n) {
            if (i)
                fact[i] = fact[i - 1] * i;
            res[i] *= fact[i];
        }
        res = res.rev();
        res *= g;
        res.resize(n);
        res = res.rev();
        rep(i, 0, n) res[i] /= fact[i];
        return res;
    }
    // 1 / f mod x^n by Newton iteration; requires f[0] invertible.
    // An empty series yields an empty result.
    Poly inv() const {
        const int n = this->size();
        if (n == 0)
            return Poly();
        Poly res(1);
        res.front() = T(1) / this->front();
        for (int k = 1; k < n; k <<= 1) {
            Poly f(k * 2), g(k * 2);
            rep(i, 0, min(n, k * 2)) f[i] = (*this)[i];
            rep(i, 0, k) g[i] = res[i];
            NTT(f, 0);
            NTT(g, 0);
            rep(i, 0, k * 2) f[i] *= g[i];
            NTT(f, 1);
            rep(i, 0, k) {
                f[i] = 0;
                f[i + k] = -f[i + k];
            }
            NTT(f, 0);
            rep(i, 0, k * 2) f[i] *= g[i];
            NTT(f, 1);
            rep(i, 0, k) f[i] = res[i];
            swap(res, f);
        }
        res.resize(n);
        return res;
    }
    // exp f mod x^n. The constant term is ignored (the result starts with 1),
    // so callers presumably pass f[0] == 0. An empty series yields empty.
    Poly exp() const {
        const int n = this->size();
        if (n == 0)
            return Poly();
        if (n == 1)
            return Poly({T(1)});
        Poly b(2), c(1), z1, z2(2);
        b[0] = c[0] = z2[0] = z2[1] = 1;
        b[1] = (*this)[1];
        for (int k = 2; k < n; k <<= 1) {
            Poly y = b;
            y.resize(k * 2);
            NTT(y, 0);
            z1 = z2;
            Poly z(k);
            rep(i, 0, k) z[i] = y[i] * z1[i];
            NTT(z, 1);
            rep(i, 0, k >> 1) z[i] = 0;
            NTT(z, 0);
            rep(i, 0, k) z[i] *= -z1[i];
            NTT(z, 1);
            c.insert(c.end(), z.begin() + (k >> 1), z.end());
            z2 = c;
            z2.resize(k * 2);
            NTT(z2, 0);
            Poly x = *this;
            x.resize(k);
            x = x.diff();
            x.resize(k);
            NTT(x, 0);
            rep(i, 0, k) x[i] *= y[i];
            NTT(x, 1);
            Poly bb = b.diff();
            rep(i, 0, k - 1) x[i] -= bb[i];
            x.resize(k * 2);
            rep(i, 0, k - 1) {
                x[k + i] = x[i];
                x[i] = 0;
            }
            NTT(x, 0);
            rep(i, 0, k * 2) x[i] *= z2[i];
            NTT(x, 1);
            x.pop_back();
            x = x.inte();
            rep(i, k, min(n, k * 2)) x[i] += (*this)[i];
            rep(i, 0, k) x[i] = 0;
            NTT(x, 0);
            rep(i, 0, k * 2) x[i] *= y[i];
            NTT(x, 1);
            b.insert(b.end(), x.begin() + k, x.end());
        }
        b.resize(n);
        return b;
    }
    // f^t mod x^n via log/exp after factoring out the lowest non-zero term.
    Poly pow(ll t) {
        if (t == 0) {
            Poly res(this->size());
            if (!res.empty())
                res[0] = 1;
            return res;
        }
        int n = this->size(), k = 0;
        while (k < n and (*this)[k] == 0)
            k++;
        Poly res(n);
        // The leading x^(t*k) factor already exceeds the requested length.
        if (__int128_t(t) * k >= n)
            return res;
        n -= t * k;
        Poly g(n);
        T c = (*this)[k], ic = c.inv();
        rep(i, 0, n) g[i] = (*this)[i + k] * ic;
        g = g.log();
        for (auto &x : g)
            x *= t;
        g = g.exp();
        c = c.pow(t);
        rep(i, 0, n) res[i + t * k] = g[i] * c;
        return res;
    }
    // In-place transform used by the members above; inv selects the inverse.
    void NTT(vector<T> &a, bool inv) const;
};

/**
 * @brief Formal Power Series (NTT-friendly mod)
 */
#line 2 "library/Math/modint.hpp"

// Integer modulo the compile-time constant `mod`; the residue v always lies in
// [0, mod). Arithmetic assumes 2 * mod fits in int.
template <int mod = 1000000007> struct fp {
    int v;
    static constexpr int get_mod() { return mod; }
    // Modular inverse of v via the extended Euclidean algorithm, as a plain
    // int in [0, mod) (0 when v == 0). Requires gcd(v, mod) == 1 otherwise.
    int inv() const {
        int tmp, a = v, b = mod, x = 1, y = 0;
        while (b)
            tmp = a / b, a -= tmp * b, std::swap(a, b), x -= tmp * y,
            std::swap(x, y);
        if (x < 0) {
            x += mod;
        }
        return x;
    }
    // Normalizes any long long (including negatives) into [0, mod).
    fp(long long x = 0)
        : v(x >= 0 ? x % mod : (mod - (-x) % mod) % mod) {}
    fp operator-() const { return fp() - *this; }
    // Binary exponentiation; t must be non-negative.
    fp pow(long long t) const {
        assert(t >= 0);
        fp res = 1, b = *this;
        while (t) {
            if (t & 1)
                res *= b;
            b *= b;
            t >>= 1;
        }
        return res;
    }
    fp &operator+=(const fp &x) {
        if ((v += x.v) >= mod)
            v -= mod;
        return *this;
    }
    fp &operator-=(const fp &x) {
        if ((v += mod - x.v) >= mod)
            v -= mod;
        return *this;
    }
    fp &operator*=(const fp &x) {
        v = (long long)(v) * x.v % mod;
        return *this;
    }
    fp &operator/=(const fp &x) {
        v = (long long)(v) * x.inv() % mod;
        return *this;
    }
    fp operator+(const fp &x) const { return fp(*this) += x; }
    fp operator-(const fp &x) const { return fp(*this) -= x; }
    fp operator*(const fp &x) const { return fp(*this) *= x; }
    fp operator/(const fp &x) const { return fp(*this) /= x; }
    bool operator==(const fp &x) const { return v == x.v; }
    bool operator!=(const fp &x) const { return v != x.v; }
    // Reads any integer and reduces it, so out-of-range or negative input
    // cannot break the v in [0, mod) invariant. On failure x is unchanged.
    friend std::istream &operator>>(std::istream &is, fp &x) {
        long long t;
        if (is >> t)
            x = fp(t);
        return is;
    }
    friend std::ostream &operator<<(std::ostream &os, const fp &x) {
        return os << x.v;
    }
};

// Modular inverse of n (mod T::get_mod()), memoized across calls.
// n must be positive; values divisible by the modulus yield 0.
template <typename T> T Inv(ll n) {
    static const int md = T::get_mod();
    // table[k] * k == 1 (mod md) for every cached k >= 1; table[0] = 0.
    static vector<T> table({0, 1});
    assert(n > 0);
    n %= md;
    for (int k = SZ(table); k <= n; k = SZ(table)) {
        // With q = ceil(md / k), k * q - md lies in [0, k) and is congruent
        // to k * q, hence inv(k) = q * inv(k * q - md).
        const int q = (md + k - 1) / k;
        table.push_back(table[k * q - md] * q);
    }
    return table[n];
}

// n! (or (n!)^{-1} when inv is set) modulo T::get_mod(), memoized.
// Requires 0 <= n < modulus.
template <typename T> T Fact(ll n, bool inv = 0) {
    static const int md = T::get_mod();
    static vector<T> fwd({1, 1}), bwd({1, 1});
    assert(n >= 0 and n < md);
    while (SZ(fwd) <= n) {
        const int k = SZ(fwd);
        fwd.push_back(fwd.back() * k);
        bwd.push_back(bwd.back() * Inv<T>(k));
    }
    if (inv)
        return bwd[n];
    return fwd[n];
}

// Number of ordered selections n! / (n-r)! (its inverse when inv is set);
// 0 when the arguments are out of range.
template <typename T> T nPr(int n, int r, bool inv = 0) {
    const bool in_range = n >= 0 && r >= 0 && r <= n;
    if (!in_range)
        return 0;
    return Fact<T>(n, inv) * Fact<T>(n - r, !inv);
}
// Binomial coefficient C(n, r) (its inverse when inv is set);
// 0 when the arguments are out of range.
template <typename T> T nCr(int n, int r, bool inv = 0) {
    const bool in_range = n >= 0 && r >= 0 && r <= n;
    if (!in_range)
        return 0;
    const T top = Fact<T>(n, inv);
    return top * Fact<T>(r, !inv) * Fact<T>(n - r, !inv);
}
// Multiset coefficient H(n, r) = C(n + r - 1, r): the number of ways to pick
// r items from n kinds with repetition (its inverse when inv is set).
template <typename T> T nHr(int n, int r, bool inv = 0) {
    // H(0, 0) = 1 (the empty multiset). nCr(-1, 0) would report 0 here.
    if (n == 0 and r == 0)
        return 1;
    return nCr<T>(n + r - 1, r, inv);
}

/**
 * @brief Modint
 */
#line 7 "sol.cpp"
using Fp = fp<998244353>;
// Single NTT engine (root tables built once at static-init time) shared by
// every Poly<Fp> transform.
NTT<Fp> ntt;
// Poly<Fp> delegates its in-place transform to the global engine above;
// `inverse` selects the inverse transform.
template <> void Poly<Fp>::NTT(vector<Fp> &a, bool inverse) const {
    ::ntt.ntt(a, inverse);
}

#line 2 "library/FPS/nthterm.hpp"

// Coefficient of x^n in p(x) / q(x) (Bostan–Mori).
// Requires n >= 0 and q[0] != 0. Each round multiplies numerator and
// denominator by q(-x), which makes the denominator even, then keeps only
// the coefficients whose parity matches n.
template <typename T> T nth(Poly<T> p, Poly<T> q, ll n) {
    assert(n >= 0);
    assert(!q.empty() and q[0] != T(0));
    // Once p is empty every remaining coefficient is 0, so stop early.
    while (n and !p.empty()) {
        Poly<T> base(q), np, nq;
        for (int i = 1; i < (int)q.size(); i += 2)
            base[i] = -base[i];
        p *= base;
        q *= base;
        for (int i = n & 1; i < (int)p.size(); i += 2)
            np.emplace_back(p[i]);
        for (int i = 0; i < (int)q.size(); i += 2)
            nq.emplace_back(q[i]);
        swap(p, np);
        swap(q, nq);
        n >>= 1;
    }
    // An empty numerator means the requested coefficient is 0 (the original
    // code read p[0] out of bounds here).
    return p.empty() ? T(0) : p[0] / q[0];
}

/**
 * @brief Bostan-Mori Algorithm
 */
#line 2 "library/Math/matrix.hpp"

// Dense h x w matrix over T with arithmetic, powers, Gaussian elimination and
// inversion. `det` is filled in by gauss()/inv().
template <class T> struct Matrix {
    int h = 0, w = 0;
    std::vector<std::vector<T>> val;
    T det{};
    Matrix() {}
    Matrix(int n)
        : h(n), w(n), val(std::vector<std::vector<T>>(n, std::vector<T>(n))) {}
    Matrix(int n, int m)
        : h(n), w(m), val(std::vector<std::vector<T>>(n, std::vector<T>(m))) {}
    std::vector<T> &operator[](const int i) { return val[i]; }
    const std::vector<T> &operator[](const int i) const { return val[i]; }
    Matrix &operator+=(const Matrix &m) {
        assert(h == m.h and w == m.w);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                val[i][j] += m.val[i][j];
        return *this;
    }
    Matrix &operator-=(const Matrix &m) {
        assert(h == m.h and w == m.w);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                val[i][j] -= m.val[i][j];
        return *this;
    }
    Matrix &operator*=(const Matrix &m) {
        assert(w == m.h);
        Matrix<T> res(h, m.w);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < m.w; j++)
                for (int k = 0; k < w; k++)
                    res.val[i][j] += val[i][k] * m.val[k][j];
        *this = res;
        return *this;
    }
    Matrix operator+(const Matrix &m) const { return Matrix(*this) += m; }
    Matrix operator-(const Matrix &m) const { return Matrix(*this) -= m; }
    Matrix operator*(const Matrix &m) const { return Matrix(*this) *= m; }
    // k-th power by repeated squaring; the matrix must be square.
    Matrix pow(long long k) {
        assert(h == w);
        Matrix<T> res(h, h), c = *this;
        for (int i = 0; i < h; i++)
            res.val[i][i] = 1;
        while (k) {
            if (k & 1)
                res *= c;
            c *= c;
            k >>= 1;
        }
        return res;
    }
    // Gauss-Jordan elimination over the first c columns (default: all).
    // Returns the pivot columns; det becomes the product of the pivots with
    // the row-swap sign (0 if some column has no pivot).
    std::vector<int> gauss(int c = -1) {
        if (val.empty())
            return {};
        if (c == -1)
            c = w;
        int cur = 0;
        std::vector<int> res;
        det = 1;
        for (int i = 0; i < c; i++) {
            if (cur == h)
                break;
            for (int j = cur; j < h; j++)
                if (val[j][i] != 0) {
                    std::swap(val[cur], val[j]);
                    if (cur != j)
                        det *= -1;
                    break;
                }
            det *= val[cur][i];
            if (val[cur][i] == 0)
                continue;
            for (int j = 0; j < h; j++)
                if (j != cur) {
                    T z = val[j][i] / val[cur][i];
                    for (int k = i; k < w; k++)
                        val[j][k] -= val[cur][k] * z;
                }
            res.push_back(i);
            cur++;
        }
        return res;
    }
    // Inverse via elimination on [A | I]; also sets det. The result is only
    // meaningful when det != 0.
    Matrix inv() {
        assert(h == w);
        Matrix base(h, h * 2), res(h, h);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < h; j++)
                base[i][j] = val[i][j];
        for (int i = 0; i < h; i++)
            base[i][h + i] = 1;
        base.gauss(h);
        det = base.det;
        for (int i = 0; i < h; i++)
            for (int j = 0; j < h; j++)
                res[i][j] = base[i][h + j] / base[i][i];
        return res;
    }
    bool operator==(const Matrix &m) const {
        assert(h == m.h and w == m.w);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                if (val[i][j] != m.val[i][j])
                    return false;
        return true;
    }
    // Unequal iff some entry differs (the previous version returned false as
    // soon as any single entry was equal).
    bool operator!=(const Matrix &m) const { return !(*this == m); }
    friend std::istream &operator>>(std::istream &is, Matrix &m) {
        for (int i = 0; i < m.h; i++)
            for (int j = 0; j < m.w; j++)
                is >> m[i][j];
        return is;
    }
    // Rows separated by '\n', entries by ' ' (a ' ' also follows the very
    // last entry).
    friend std::ostream &operator<<(std::ostream &os, const Matrix &m) {
        for (int i = 0; i < m.h; i++) {
            for (int j = 0; j < m.w; j++)
                os << m[i][j] << (j == m.w - 1 and i != m.h - 1 ? '\n' : ' ');
        }
        return os;
    }
};

/**
 * @brief Matrix
 */
#line 15 "sol.cpp"

// Process-wide fast I/O object used by main(); its destructor flushes any
// still-buffered output when the program exits.
FastIO io;
int main() {
    int n, m, S, T;
    io.read(n, m, S, T);
    S--;
    T--;
    vector g(n, vector<int>());
    rep(_, 0, n - 1) {
        int u, v;
        io.read(u, v);
        u--;
        v--;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    vector<int> sz(n, 1);
    auto dfs1 = [&](auto &dfs1, int v, int p) -> void {
        int mx = -1;
        rep(i, 0, SZ(g[v])) {
            int to = g[v][i];
            if (to == p)
                continue;
            dfs1(dfs1, to, v);
            sz[v] += sz[to];
            if (chmax(mx, sz[to]))
                swap(g[v][i], g[v][0]);
        }
    };
    dfs1(dfs1, T, -1);

    using P = pair<Poly<Fp>, Poly<Fp>>;
    using Mat = Matrix<Poly<Fp>>;

    function<vector<P>(int, int)> rake;
    function<P(int, int)> compress;

    rake = [&](int v, int p) -> vector<P> {
        if (SZ(g[v]) == 1 and g[v][0] == p) {
            return {P{Poly<Fp>({0}), Poly<Fp>({1})}};
        }
        auto ret = rake(g[v][0], v);
        deque<P> deq;
        deq.push_back(P{Poly<Fp>({0}), Poly<Fp>({1})});
        rep(i, 1, SZ(g[v])) if (g[v][i] != p) {
            deq.push_back(compress(g[v][i], v));
        }
        while (deq.size() > 1) {
            auto [A1, A2] = deq.front();
            deq.pop_front();
            auto [B1, B2] = deq.front();
            deq.pop_front();
            deq.push_back(P{A1 * B2 + A2 * B1, A2 * B2});
        }
        ret.push_back(deq.front());
        // cerr << "rake: " << v << '\n';
        return ret;
    };
    compress = [&](int v, int p) -> P {
        auto fs = rake(v, p);
        reverse(ALL(fs));
        auto rec = [&](auto &rec, int L, int R) -> Mat {
            if (R - L == 1) {
                auto [f, g] = fs[L];
                Mat ret(2);
                ret[0][1] = g;
                ret[1][0] = -(g << 2);
                ret[1][1] = g - (g << 1) - (f << 2);
                return ret;
            }
            int mid = (L + R) >> 1;
            return rec(rec, L, mid) * rec(rec, mid, R);
        };
        auto A = rec(rec, 0, SZ(fs));
        auto f = A[0][1], g = A[1][1];
        f.shrink();
        g.shrink();
        // cerr << "compress: " << v << '\n';
        // for (auto &v : f)
        //     cerr << v.v << ' ';
        // cerr << '\n';
        // for (auto &v : g)
        //     cerr << v.v << ' ';
        // cerr << "\n\n";
        return P{f, g};
    };

    vector<P> fs;
    auto dfs2 = [&](auto &dfs2, int v, int p) -> bool {
        deque<P> deq;
        deq.push_back(P{Poly<Fp>({0}), Poly<Fp>({1})});
        int onedge = -1;
        rep(i, 0, SZ(g[v])) if (g[v][i] != p) {
            if (dfs2(dfs2, g[v][i], v)) {
                onedge = g[v][i];
            }
        }
        if (v == S)
            onedge = n;
        if (onedge != -1) {
            rep(i, 0, SZ(g[v])) if (g[v][i] != p) {
                if (onedge == n or g[v][i] != onedge) {
                    deq.push_back(compress(g[v][i], v));
                }
            }
            while (deq.size() > 1) {
                auto [A1, A2] = deq.front();
                deq.pop_front();
                auto [B1, B2] = deq.front();
                deq.pop_front();
                deq.push_back(P{A1 * B2 + A2 * B1, A2 * B2});
            }
            auto ret = deq.front();

            // cerr << "onedge: " << v << '\n';
            // for (auto &v : ret.first)
            //     cerr << v.v << ' ';
            // cerr << '\n';
            // for (auto &v : ret.second)
            //     cerr << v.v << ' ';
            // cerr << "\n\n";

            fs.push_back(ret);
        }
        return onedge != -1;
    };
    dfs2(dfs2, T, -1);
    reverse(ALL(fs));

    auto rec = [&](auto &rec, int L, int R) -> Mat {
        if (R - L == 1) {
            auto [f, g] = fs[L];
            Mat ret(2);
            ret[0][1] = g;
            ret[1][0] = -(g << 2);
            ret[1][1] = g - (g << 1) - (f << 2);
            return ret;
        }
        int mid = (L + R) >> 1;
        return rec(rec, L, mid) * rec(rec, mid, R);
    };
    auto A = rec(rec, 0, SZ(fs));
    Poly<Fp> num, den;
    den = A[1][1];
    den.shrink();

    deque<Poly<Fp>> deq;
    for (auto &[f, g] : fs)
        deq.push_back(g);
    while (deq.size() > 1) {
        auto A = deq.front();
        deq.pop_front();
        auto B = deq.front();
        deq.pop_front();
        deq.push_back(A * B);
    }
    num = deq.front();
    num.shrink();

    // for (auto &v : num)
    //     cerr << v.v << ' ';
    // cerr << '\n';
    // for (auto &v : den)
    //     cerr << v.v << ' ';
    // cerr << "\n\n";

    if (m < (SZ(fs) - 1)) {
        io.write(0);
        return 0;
    }
    Fp ret = nth(num, den, m - (SZ(fs) - 1));
    io.write(ret.v);
    return 0;
}
0