結果

問題 No.2485 Add to Variables (Another)
ユーザー rgnerdplayer
提出日時 2025-05-07 20:22:41
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 127 ms / 2,000 ms
コード長 12,499 bytes
コンパイル時間 3,542 ms
コンパイル使用メモリ 289,060 KB
実行使用メモリ 7,848 KB
最終ジャッジ日時 2025-05-07 20:22:48
合計ジャッジ時間 6,613 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 5
other AC * 39
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

using i64 = long long;

// Integer modulo the compile-time modulus P, stored as its canonical residue.
// P is assumed to be prime: inv() relies on Fermat's little theorem
// (a^(P-2) == a^(-1)) and performs no check for a zero value.
template <int P>
struct ModInt {
    int v;  // canonical residue, always in [0, P)

    constexpr ModInt() : v(0) {}
    // Implicit on purpose: lets integer literals mix freely with ModInt.
    constexpr ModInt(long long v_) : v(v_ % P) {
        // C++ '%' keeps the sign of the dividend; fold negatives into [0, P).
        if (v < 0) {
            v += P;
        }
    }
    constexpr explicit operator int() const { return v; }

    // Stream operators are never usable in constant expressions, so they are
    // deliberately not constexpr.
    friend std::ostream& operator<<(std::ostream &out, ModInt n) {
        return out << int(n);
    }
    friend std::istream& operator>>(std::istream &in, ModInt &n) {
        long long x = 0;
        in >> x;
        n = ModInt(x);
        return in;
    }

    constexpr friend bool operator==(ModInt a, ModInt b) {
        return a.v == b.v;
    }
    constexpr friend bool operator!=(ModInt a, ModInt b) {
        return a.v != b.v;
    }

    // Unary minus; const so it also works on const / constexpr values.
    constexpr ModInt operator-() const {
        ModInt res;
        res.v = v ? P - v : 0;
        return res;
    }

    constexpr ModInt& operator++() {
        v++;
        if (v == P) v = 0;
        return *this;
    }
    constexpr ModInt& operator--() {
        if (v == 0) v = P;
        v--;
        return *this;
    }
    constexpr ModInt& operator+=(ModInt o) {
        // v + o.v - P computed without exceeding int range; add P back if negative.
        v -= P - o.v;
        v = (v < 0) ? v + P : v;
        return *this;
    }
    constexpr ModInt& operator-=(ModInt o) {
        v -= o.v;
        v = (v < 0) ? v + P : v;
        return *this;
    }
    constexpr ModInt& operator*=(ModInt o) {
        // Widen before multiplying: the product of two residues can overflow int.
        v = 1LL * v * o.v % P;
        return *this;
    }
    constexpr ModInt& operator/=(ModInt o) { return *this *= o.inv(); }

    constexpr friend ModInt operator++(ModInt &a, int) {
        ModInt r = a;
        ++a;
        return r;
    }
    constexpr friend ModInt operator--(ModInt &a, int) {
        ModInt r = a;
        --a;
        return r;
    }

    constexpr friend ModInt operator+(ModInt a, ModInt b) {
        return ModInt(a) += b;
    }
    constexpr friend ModInt operator-(ModInt a, ModInt b) {
        return ModInt(a) -= b;
    }
    constexpr friend ModInt operator*(ModInt a, ModInt b) {
        return ModInt(a) *= b;
    }
    constexpr friend ModInt operator/(ModInt a, ModInt b) {
        return ModInt(a) /= b;
    }

    // this^p by binary exponentiation; returns 1 for p <= 0.
    constexpr ModInt qpow(long long p) const {
        ModInt res = 1, x = *this;
        while (p > 0) {
            if (p & 1) { res *= x; }
            x *= x;
            p >>= 1;
        }
        return res;
    }
    // Multiplicative inverse via Fermat (P prime); meaningless for v == 0.
    constexpr ModInt inv() const {
        return qpow(P - 2);
    }
};

// Global modulus. P - 1 = 119 * 2^23, so every power-of-two transform length
// up to 2^23 divides P - 1, as the power-of-two NTT in this file requires.
constexpr int P = 998'244'353;
using Mint = ModInt<P>;

// < 0 return 0 ?
struct Combinatorial {
    int n;
    vector<Mint> _fact;
    vector<Mint> _ifact;
    vector<Mint> _inv;
    
    Combinatorial() : n{0}, _fact{1}, _ifact{1}, _inv{0} {}
    Combinatorial(int n) : Combinatorial() {
        init(n);
    }
    
    void init(int m) {
        if (m <= n) return;
        _fact.resize(m + 1);
        _ifact.resize(m + 1);
        _inv.resize(m + 1);
        
        for (int i = n + 1; i <= m; i++) {
            _fact[i] = _fact[i - 1] * i;
        }
        _ifact[m] = _fact[m].inv();
        for (int i = m; i > n; i--) {
            _ifact[i - 1] = _ifact[i] * i;
            _inv[i] = _ifact[i] * _fact[i - 1];
        }
        n = m;
    }
    
    Mint fact(int m) {
        if (m < 0) return 0;
        if (m > n) init(2 * m);
        return _fact[m];
    }
    Mint ifact(int m) {
        if (m < 0) return 0;
        if (m > n) init(2 * m);
        return _ifact[m];
    }
    Mint inv(int m) {
        if (m < 0) return 0;
        if (m > n) init(2 * m);
        return _inv[m];
    }
    Mint binom(int n, int m) {
        if (n < m || m < 0) return 0;
        return fact(n) * ifact(m) * ifact(n - m);
    }
} comb;

// Returns the smallest i >= 2 with i^((P-1)/2) != 1, i.e. the smallest
// quadratic non-residue modulo the (assumed prime) P.
// NOTE: despite the name, this is not guaranteed to be a primitive root of P.
// It is, however, an element whose multiplicative order contains the whole
// power-of-two part of P - 1. That is all the power-of-two NTT needs:
// raising it to (P - 1) >> (k + 1) yields a root of unity of order 2^(k+1).
template<int P>
constexpr ModInt<P> findPrimitiveRoot() {
    ModInt<P> i = 2;
    while (i.qpow((P - 1) / 2) == 1) {
        i = i + 1;
    }
    return i;
}

// Generator from which dft() derives its power-of-two roots of unity; a
// constexpr variable template, so it is computed at compile time per modulus.
template<int P>
constexpr ModInt<P> primitiveRoot = findPrimitiveRoot<P>();

// Bit-reversal permutation for the most recently used transform length.
// It depends only on the length, so one table is shared by every modulus.
vector<int> rev;

// Per-modulus twiddle-factor table; starts as {0, 1} and is grown by dft().
template<int P>
vector<ModInt<P>> roots{ModInt<P>(0), ModInt<P>(1)};

// In-place forward NTT over Z/PZ; the output is in natural order (the input
// is bit-reverse permuted first, then radix-2 butterflies are applied).
// Preconditions (not checked): a.size() is a power of two that divides P - 1.
// Side effects on globals:
// - rebuilds `rev` whenever the length changes;
// - lazily extends `roots<P>`, where roots<P>[k + j] (k a power of two,
//   0 <= j < k) is w^j for the 2k-th root of unity
//   w = primitiveRoot<P>^((P - 1) / (2k)).
template<int P>
void dft(vector<ModInt<P>> &a) {
    int n = a.size();
    // Lengths 0 and 1 are their own transform. Returning here for n == 0 also
    // avoids __builtin_ctz(0) below, whose result is undefined.
    if (n <= 1) { return; }

    // rev[i] = i with its log2(n) low bits reversed, built from rev[i >> 1].
    if (int(rev.size()) != n) {
        int k = __builtin_ctz(n) - 1;
        rev.resize(n);
        for (int i = 0; i < n; i++) {
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
        }
    }

    for (int i = 0; i < n; i++) {
        if (rev[i] < i) {
            swap(a[i], a[rev[i]]);
        }
    }
    if (int(roots<P>.size()) < n) {
        int k = __builtin_ctz(roots<P>.size());
        roots<P>.resize(n);
        while ((1 << k) < n) {
            // e has order 2^(k+1); level k+1 of the table interleaves level k
            // with its multiples by e. primitiveRoot<P> is a const object, so
            // it is copied before calling qpow.
            auto e = ModInt<P>(primitiveRoot<P>).qpow((P - 1) >> (k + 1));
            for (int i = 1 << (k - 1); i < (1 << k); i++) {
                roots<P>[2 * i] = roots<P>[i];
                roots<P>[2 * i + 1] = roots<P>[i] * e;
            }
            k++;
        }
    }
    // Butterflies: merge pairs of length-k transforms into length-2k ones.
    for (int k = 1; k < n; k *= 2) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; j++) {
                ModInt<P> u = a[i + j];
                ModInt<P> v = a[i + j + k] * roots<P>[k + j];
                a[i + j] = u + v;
                a[i + j + k] = u - v;
            }
        }
    }
}

// In-place inverse NTT. Reversing a[1..n-1] turns the forward transform into
// the inverse one up to a factor of n, which is then divided out.
// Preconditions (not checked): a.size() is a power of two that divides P - 1.
template <int P>
void idft(vector<ModInt<P>> &a) {
    int n = a.size();
    // Nothing to do for an empty vector. This also avoids a.begin() + 1
    // running past end() and the division by n == 0 below.
    if (n == 0) { return; }
    reverse(a.begin() + 1, a.end());
    dft(a);
    // n divides P - 1, so (1 - P) / n is the exact integer -(P - 1) / n.
    // Since n * (1 - P) / n == 1 - P == 1 (mod P), this is 1/n mod P without
    // any modular inversion.
    ModInt<P> x = (1 - P) / n;
    for (int i = 0; i < n; i++) {
        a[i] *= x;
    }
}

// Linear convolution of two coefficient vectors over Z/PZ via NTT.
// Returns a.size() + b.size() - 1 coefficients, or an empty vector when
// either input is empty. Both inputs are taken by value and zero-padded to
// the next power of two before the forward transforms.
template <int P>
vector<ModInt<P>> operator*(vector<ModInt<P>> a, vector<ModInt<P>> b) {
    if (a.empty() || b.empty()) {
        return {};
    }
    const int len = a.size() + b.size() - 1;
    int cap = 1;
    for (; cap < len; cap <<= 1) {
    }
    a.resize(cap);
    b.resize(cap);
    dft(a);
    dft(b);
    for (int k = 0; k < cap; ++k) {
        a[k] *= b[k];
    }
    idft(a);
    a.resize(len);
    return a;
}

// Dense polynomial / truncated formal power series over Z/PZ: element i is the
// coefficient of x^i. Inherits std::vector's storage and constructors.
// Truncation helpers never pad: modxk(k) keeps at most k coefficients.
template <int P>
struct Poly : vector<ModInt<P>> {
    using Mint = ModInt<P>;
    using vector<Mint>::vector;

    // Builds the polynomial whose coefficient of x^i is f(i), for 0 <= i < n.
    template <typename F>
    Poly(int n, F f) {
        this->resize(n);
        for (int i = 0; i < n; i++) {
            this->at(i) = f(i);
        }
    }

    // this * x^k: prepends k zero coefficients.
    Poly mulxk(int k) const {
        auto b = *this;
        b.insert(b.begin(), k, 0);
        return b;
    }
    // this mod x^k: the first min(k, size()) coefficients; empty when k <= 0.
    Poly modxk(int k) const {
        // Clamp to [0, size()]: a negative k would otherwise form begin() + k.
        k = max(0, min(k, int(this->size())));
        return Poly(this->begin(), this->begin() + k);
    }
    // floor(this / x^k): drops the first k coefficients; empty when k < 0 or
    // k >= size().
    Poly divxk(int k) const {
        if (k < 0 || k >= int(this->size())) { return Poly(); }
        return Poly(this->begin() + k, this->end());
    }
    friend Poly operator+(const Poly &a, const Poly &b) {
        Poly res(max(a.size(), b.size()));
        for (int i = 0; i < int(a.size()); i++) { res[i] += a[i]; }
        for (int i = 0; i < int(b.size()); i++) { res[i] += b[i]; }
        return res;
    }
    friend Poly operator-(const Poly &a, const Poly &b) {
        Poly res(max(a.size(), b.size()));
        for (int i = 0; i < int(a.size()); i++) { res[i] += a[i]; }
        for (int i = 0; i < int(b.size()); i++) { res[i] -= b[i]; }
        return res;
    }
    // Product via NTT: a.size() + b.size() - 1 coefficients (empty if either
    // factor is empty).
    friend Poly operator*(Poly a, Poly b) {
        if (a.empty() || b.empty()) { return Poly(); }
        int sz = 1, tot = a.size() + b.size() - 1;
        while (sz < tot) { sz *= 2; }
        a.resize(sz);
        b.resize(sz);
        dft(a);
        dft(b);
        for (int i = 0; i < sz; i++) { a[i] *= b[i]; }
        idft(a);
        a.resize(tot);
        return a;
    }
    friend Poly operator*(Mint a, Poly b) {
        for (int i = 0; i < int(b.size()); i++) { b[i] *= a; }
        return b;
    }
    friend Poly operator*(Poly a, Mint b) {
        for (int i = 0; i < int(a.size()); i++) { a[i] *= b; }
        return a;
    }
    Poly& operator+=(Poly b) { return (*this) = (*this) + b; }
    Poly& operator-=(Poly b) { return (*this) = (*this) - b; }
    Poly& operator*=(Poly b) { return (*this) = (*this) * b; }
    // Formal derivative (size() - 1 coefficients; empty for an empty input).
    Poly derivative() const {
        if (this->empty()) { return Poly(); }
        int n = this->size();
        Poly res(n - 1);
        for (int i = 0; i < n - 1; ++i) { res[i] = (i + 1) * (*this)[i + 1]; }
        return res;
    }
    // Formal antiderivative with zero constant term (size() + 1 coefficients).
    Poly integral() const {
        int n = this->size();
        Poly res(n + 1);
        for (int i = 0; i < n; ++i) { res[i + 1] = (*this)[i] / (i + 1); }
        return res;
    }
    // First m coefficients of 1 / this. Requires (*this)[0] != 0 (not checked).
    // Newton iteration x <- x (2 - this * x), doubling the precision each round.
    // Constants are spelled Poly{Mint(c)} so that only the initializer_list
    // constructor is viable; with Poly({c}) on a bare integer literal, the
    // inherited vector(size_type) constructor is also a candidate.
    Poly inv(int m) const {
        Poly x{Mint((*this)[0]).inv()};
        int k = 1;
        while (k < m) {
            k *= 2;
            x = (x * (Poly{Mint(2)} - modxk(k) * x)).modxk(k);
        }
        return x.modxk(m);
    }
    // First m coefficients of the integral of this' / this, i.e. ln(this)
    // when (*this)[0] == 1.
    Poly log(int m) const {
        return (derivative() * inv(m)).integral().modxk(m);
    }
    // First m coefficients of exp(this) by Newton iteration
    // x <- x (1 - ln x + this); meaningful when (*this)[0] == 0.
    Poly exp(int m) const {
        Poly x{Mint(1)};
        int k = 1;
        while (k < m) {
            k *= 2;
            x = (x * (Poly{Mint(1)} - x.log(k) + modxk(k))).modxk(k);
        }
        return x.modxk(m);
    }
    // First m coefficients of this^k. Factors out the lowest nonzero term
    // v * x^i, then uses exp(k * log(.)) on the remaining series (constant 1).
    Poly pow(i64 k, int m) const {
        if (k == 0) { return Poly(m, [&](int i) { return i == 0; }); }
        const int sz = this->size();
        int i = 0;
        while (i < sz && (*this)[i] == 0) { i++; }
        if (i == sz || __int128(i) * k >= m) { return Poly(m); }
        Mint v = (*this)[i];
        auto f = divxk(i) * v.inv();
        return (f.log(m - i * k) * k).exp(m - i * k).mulxk(i * k) * v.qpow(k);
    }
    // First m coefficients of sqrt(this) by Newton iteration
    // x <- (x + this / x) / 2 starting from x = 1, so it assumes
    // (*this)[0] == 1 (other constant terms would need a modular square root).
    Poly sqrt(int m) const {
        Poly x{Mint(1)};
        int k = 1;
        while (k < m) {
            k *= 2;
            // (P + 1) / 2 is the inverse of 2 modulo the odd prime P.
            x = (x + (modxk(k) * x.inv(k)).modxk(k)) * ((P + 1) / 2);
        }
        return x.modxk(m);
    }
    // Transposed multiplication: (this * reverse(b)) / x^(b.size() - 1), which
    // keeps size() coefficients. Empty when b is empty.
    Poly mulT(Poly b) const {
        if (b.empty()) { return Poly(); }
        int n = b.size();
        reverse(b.begin(), b.end());
        return (*this * b).divxk(n - 1);
    }
    // Multipoint evaluation: returns {this(x[0]), ..., this(x[x.size() - 1])}.
    // x is zero-padded to n = max(x.size(), size()) points. A subproduct tree q
    // of the factors (1 - x_i t) is built, and mulT(q[1].inv(n)) is pushed down
    // it with transposed multiplications; only the first x.size() leaves are
    // kept.
    vector<Mint> evaluate(vector<Mint> x) const {
        if (this->empty()) { return vector<Mint>(x.size()); }
        int n = max(x.size(), this->size());
        vector<Poly> q(4 * n);
        vector<Mint> ans(x.size());
        x.resize(n);
        auto build = [&](auto build, int id, int l, int r) -> void {
            if (r - l == 1) {
                q[id] = Poly({1, -x[l]});
            } else {
                int m = (l + r) / 2;
                build(build, 2 * id, l, m);
                build(build, 2 * id + 1, m, r);
                q[id] = q[2 * id] * q[2 * id + 1];
            }
        };
        build(build, 1, 0, n);
        // num is taken by const reference: every call passes a temporary,
        // which cannot bind to a non-const Poly& (the old signature failed to
        // compile as soon as evaluate was instantiated).
        auto work = [&](auto work, int id, int l, int r, const Poly &num) -> void {
            if (r - l == 1) {
                if (l < int(ans.size())) {
                    ans[l] = num[0];
                }
            } else {
                int m = (l + r) / 2;
                work(work, 2 * id, l, m, num.mulT(q[2 * id + 1]).modxk(m - l));
                work(work, 2 * id + 1, m, r, num.mulT(q[2 * id]).modxk(r - m));
            }
        };
        work(work, 1, 0, n, mulT(q[1].inv(n)));
        return ans;
    }
};

// Lagrange interpolation: returns f with deg f < n and f(x[i]) == y[i], where
// n = x.size(). The steps are:
//   1. build the subproduct tree p (p[id] = product of (t - x_i) over the
//      node's index range);
//   2. evaluate p[1]' at every x_i;
//   3. combine q[id] = q[left] * p[right] + q[right] * p[left] bottom-up.
// Assumes the x[i] are pairwise distinct and y.size() >= x.size() (not checked).
template <int P>
Poly<P> interpolate(vector<ModInt<P>> x, vector<ModInt<P>> y) {
    int n = x.size();
    // With no points, dfs1(..., 0, -1) would never reach its l == r base case.
    if (n == 0) { return Poly<P>(); }
    vector<Poly<P>> p(4 * n), q(4 * n);
    auto dfs1 = [&](auto dfs1, int id, int l, int r) -> void {
        if (l == r) {
            p[id] = Poly<P>({-x[l], 1});
            return;
        }
        int m = l + r >> 1;
        dfs1(dfs1, id << 1, l, m);
        dfs1(dfs1, id << 1 | 1, m + 1, r);
        p[id] = p[id << 1] * p[id << 1 | 1];
    };
    dfs1(dfs1, 1, 0, n - 1);
    // f[i] = p[1]'(x[i]) = prod_{j != i} (x[i] - x[j]). Kept as a plain
    // vector: Poly inherits vector's constructors, but inherited copy/move
    // constructors are not candidates, so Poly<P>(vector) is ill-formed.
    vector<ModInt<P>> f = p[1].derivative().evaluate(x);
    auto dfs2 = [&](auto dfs2, int id, int l, int r) -> void {
        if (l == r) {
            q[id] = Poly<P>({y[l] / f[l]});
            return;
        }
        int m = l + r >> 1;
        dfs2(dfs2, id << 1, l, m);
        dfs2(dfs2, id << 1 | 1, m + 1, r);
        q[id] = q[id << 1] * p[id << 1 | 1] + q[id << 1 | 1] * p[id << 1];
    };
    dfs2(dfs2, 1, 0, n - 1);
    return q[1];
}

using FPS = Poly<P>;

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);

    auto solve = [&]() {
        int n, m;
        cin >> n >> m;

        vector<int> a(n), b(n - 1);
        for (int i = 0; i < n; i++) {
            cin >> a[i];
            if (i > 0) {
                b[i - 1] = a[i] - a[i - 1];
            }
        }

        FPS dp(m + 1);
        dp[0] = 1;

        for (auto d : b) {
            dp = (dp * FPS(m + 1, [&](int i) {
                return comb.ifact(i) * comb.ifact(i + d);
            })).modxk(m + 1);
        }

        Mint ans = 0;

        for (int i = 0; i <= a[0]; i++) {
            ans += dp[i] * comb.ifact(a[0] - i) * comb.ifact(m - a[n - 1] - i);
        }

        ans *= comb.fact(m);

        cout << ans << '\n';
    };
    
    solve();
    
    return 0;
}
0