結果

問題 No.3540 Arise
コンテスト
ユーザー nauclhlt
提出日時 2026-03-07 01:23:33
言語 C++17
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++17 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
RE  
(最新)
AC  
(最初)
実行時間 -
コード長 8,478 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 1,707 ms
コンパイル使用メモリ 227,940 KB
実行使用メモリ 7,976 KB
最終ジャッジ日時 2026-05-08 20:54:27
合計ジャッジ時間 12,207 ms
ジャッジサーバーID
(参考情報)
judge1_1 / judge3_0
このコードへのチャレンジ
(要ログイン)
サブタスク 配点 結果
サブタスク1 30 % AC * 19
サブタスク2 70 % AC * 14 RE * 8
合計 3.5 * 30% = 105 点
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <bits/stdc++.h>

using namespace std;

using ll = long long;

#define CONST_MOD 998244353LL
// #define CONST_MOD 1000000007L
/// Integer modulo CONST_MOD (assumed prime, so non-zero values are invertible).
///
/// `Value` always holds the canonical residue in [0, CONST_MOD): every
/// constructor and operator reduces its result, so `Value` can be printed
/// directly. Operators that take a raw `long long` reduce that operand first,
/// which keeps every intermediate product below CONST_MOD^2 < 2^63.
struct ModInt
{
    long long Value;

public:
    /// Constructs zero.
    ModInt()
    {
        Value = 0L;
    }

    /// Constructs the residue of `value`. Negative and oversized inputs are
    /// reduced instead of being stored verbatim (which previously let
    /// non-canonical values leak into products and output).
    ModInt(long long value)
    {
        Value = SafeMod(value);
    }

    /// Returns this^exp. A negative exponent uses the modular inverse, so 0 to a
    /// negative power yields 0 (Inv() of 0 is 0).
    ModInt Power(long long exp) const
    {
        if (exp < 0)
        {
            // x^exp == inv(x)^(-(exp + 1)) * inv(x); written this way so that
            // negating exp cannot overflow when exp == LLONG_MIN.
            const ModInt inv = Inv();
            return inv.Power(-(exp + 1)) * inv;
        }

        // Iterative square-and-multiply: O(log exp) multiplications, no recursion.
        ModInt result = 1L;
        ModInt base = *this;
        while (exp > 0)
        {
            if (exp & 1L)
            {
                result *= base;
            }
            base *= base;
            exp >>= 1;
        }
        return result;
    }

    /// Modular inverse via Fermat's little theorem (requires CONST_MOD prime).
    /// Returns 0 for 0.
    ModInt Inv() const
    {
        return Power(CONST_MOD - 2L);
    }

    ModInt operator+() const
    {
        return *this;
    }

    ModInt operator-() const
    {
        // The constructor normalizes, so the result is canonical, not negative.
        return ModInt(-Value);
    }

    friend ModInt operator+(const ModInt& left, const ModInt& right)
    {
        return ModInt(left.Value + right.Value);
    }

    friend ModInt operator+(const ModInt& left, const long long& right)
    {
        return left + ModInt(right);
    }

    friend ModInt operator+(const long long& left, const ModInt& right)
    {
        return ModInt(left) + right;
    }

    ModInt& operator+=(const ModInt& x)
    {
        return *this = *this + x;
    }

    ModInt& operator+=(const long long& x)
    {
        return *this = *this + x;
    }

    friend ModInt operator-(const ModInt& left, const ModInt& right)
    {
        return ModInt(left.Value - right.Value);
    }

    friend ModInt operator-(const ModInt& left, const long long& right)
    {
        return left - ModInt(right);
    }

    friend ModInt operator-(const long long& left, const ModInt& right)
    {
        return ModInt(left) - right;
    }

    ModInt& operator-=(const ModInt& x)
    {
        return *this = *this - x;
    }

    ModInt& operator-=(const long long& x)
    {
        return *this = *this - x;
    }

    friend ModInt operator*(const ModInt& left, const ModInt& right)
    {
        // Both operands are canonical, so the product is < CONST_MOD^2 and fits.
        return ModInt(left.Value * right.Value);
    }

    friend ModInt operator*(const ModInt& left, const long long& right)
    {
        // Reduce the raw operand first so the product cannot overflow.
        return left * ModInt(right);
    }

    friend ModInt operator*(const long long& left, const ModInt& right)
    {
        return ModInt(left) * right;
    }

    ModInt& operator*=(const ModInt& x)
    {
        return *this = *this * x;
    }

    ModInt& operator*=(const long long& x)
    {
        return *this = *this * x;
    }

    friend ModInt operator/(const ModInt& left, const ModInt& right)
    {
        return left * right.Inv();
    }

    friend ModInt operator/(const ModInt& left, const long long& right)
    {
        return left * ModInt(right).Inv();
    }

    friend ModInt operator/(const long long& left, const ModInt& right)
    {
        return ModInt(left) * right.Inv();
    }

    ModInt& operator/=(const ModInt& x)
    {
        return *this = *this / x;
    }

    ModInt& operator/=(const long long& x)
    {
        return *this = *this / x;
    }

    ModInt& operator++()
    {
        return *this += 1LL;
    }

    ModInt operator++(int)
    {
        ModInt before = *this;
        ++*this;
        return before;
    }

    ModInt& operator--()
    {
        return *this -= 1LL;
    }

    ModInt operator--(int)
    {
        ModInt before = *this;
        --*this;
        return before;
    }

    inline static ModInt One()
    {
        return ModInt(1L);
    }

    /// C(n, r) modulo CONST_MOD by the multiplicative formula, O(r log MOD).
    /// Returns 0 for r < 0 (matching ModCache::Combination). Assumes
    /// r < CONST_MOD so that 1..r are all invertible.
    static ModInt Combination(long long n, long long r)
    {
        if (r < 0)
        {
            return ModInt();
        }
        ModInt c = 1L;
        // A plain integer counter: the old ModInt counter wrapped at CONST_MOD
        // and never terminated for r >= CONST_MOD.
        for (long long i = 1; i <= r; i++)
        {
            // c *= (n - i + 1) / i
            c = c * (ModInt(n) - i + 1LL) / ModInt(i);
        }
        return c;
    }

private:
    /// Canonical residue of `a` in [0, CONST_MOD); C++ `%` keeps the sign of `a`.
    inline static long long SafeMod(long long a)
    {
        a %= CONST_MOD;
        if (a < 0)
        {
            a += CONST_MOD;
        }
        return a;
    }
};

class ModCache
{
private:
    vector<ModInt> _factorial;
    vector<ModInt> _inverseFactorial;
    vector<ModInt> _inverse;

public:
    ModCache(int max)
    {
        _factorial.resize(max + 1);
        _inverseFactorial.resize(max + 1);
        _inverse.resize(max + 1);

        _factorial[0] = 1;
        _inverseFactorial[0] = 1LL;
        _inverse[1] = 1LL;

        for (long p = 1; p <= max; p++)
        {
            _factorial[p] = _factorial[p - 1] * p;
            if (p > 1)
            {
                _inverse[p] = -(CONST_MOD / p) * _inverse[CONST_MOD % p];
            }
            _inverseFactorial[p] = _inverseFactorial[p - 1] * _inverse[p];
        }
    }

    ModInt Combination(int n, int r)
    {
        if (n < 0 || r < 0 || n < r) return 0;
        return _factorial[n] * (_inverseFactorial[n - r] * _inverseFactorial[r]);
    }

    ModInt Permutation(int n, int r)
    {
        return _factorial[n] * _inverseFactorial[n - r];
    }

    ModInt Factorial(int n)
    {
        return _factorial[n];
    }

    ModInt InverseFactorial(int n)
    {
        return _inverseFactorial[n];
    }

    ModInt Inverse(int n)
    {
        return _inverse[n];
    }
};


/// Directly evaluates the answer; used when the minimum A[0] is small.
///
/// answer = sum_i (A[i] + 1) / 2 + sum_{m=1}^{A[0]} g(m), where
/// g(m) = sum_k (A[k] - m) / A[k]
///              * prod_{i<k} (A[i] - m + 1) / A[i]
///              * prod_{i>k} (A[i] - m) / A[i].
///
/// Preconditions: A is sorted ascending with distinct values, and invA[i] is
/// the modular inverse of A[i]. With those, A[k + 1] > A[0] >= m, so the
/// divisor in the `right /=` step is never zero.
///
/// Cost: O(A[0] * N) steps, each division costing one modular inverse.
ModInt solveNaive(int N, const vector<ll>& A, const vector<ModInt>& invA)
{
    // Loop-invariant: compute 1/2 once instead of once per element.
    const ModInt inv2 = ModInt(2LL).Inv();
    ModInt ans = 0LL;
    for (int i = 0; i < N; i++)
    {
        ans += (A[i] + 1) * inv2;
    }

    for (ll m = 1; m <= A[0]; m++)
    {
        // left  = prod_{i<k}  (A[i] - m + 1) / A[i]
        // right = prod_{i>k}  (A[i] - m)     / A[i]
        // right starts as the product over i = 1..N-1 (the k = 0 case).
        ModInt left = 1LL;
        ModInt right = 1LL;
        for (int i = 1; i < N; i++)
        {
            right *= ModInt(A[i] - m) * invA[i];
        }

        for (int k = 0; k < N; k++)
        {
            ans += (A[k] - m) * left * right * invA[k];
            left *= (A[k] - m + 1) * invA[k];
            if (k < N - 1)
            {
                // Move factor k + 1 out of `right` before advancing k.
                right /= ModInt(A[k + 1] - m) * invA[k + 1];
            }
        }
    }

    return ans;
}



int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    cin >> N;
    vector<ll> A(N);
    for (int i = 0; i < N; i++)
    {
        cin >> A[i];
    }

    vector<ModInt> invA(N);
    for (int i = 0; i < N; i++)
    {
        invA[i] = ((ModInt)A[i]).Inv();
    }

    sort(A.begin(), A.end());
    for (int i = 0; i < N; i++)
    {
        assert(1 <= A[i] && A[i] < 998244353);
        if (i < N - 1)
        {
            assert(A[i] != A[i + 1]);
        }
    }

    ModInt ans = 0LL;
    for (int i = 0; i < N; i++)
    {
        ans += (A[i] + 1) / (ModInt)2LL;
    }

    if (A[0] <= N + 2)
    {
        cout << solveNaive(N, A, invA).Value << endl;
        return 0;
    }

    vector<ModInt> F(N + 3);
    
    for (int x = 1; x <= N + 2; x++)
    {
        ModInt left = 1LL;
        ModInt right = 1LL;
        for (int i = 1; i < N; i++)
        {
            right *= (ModInt)(A[i] - x) * invA[i];
        }

        for (int k = 0; k < N; k++)
        {
            F[x] += (A[k] - x) * left * right * invA[k];
            left *= (A[k] - x + 1) * invA[k];
            if (k < N - 1)
            {
                right /= (ModInt)(A[k + 1] - x) * invA[k + 1];
            }
        }
    }
    
    for (int j = 2; j <= N + 2; j++)
    {
        F[j] += F[j - 1];
    }

    ModCache cache(N + 100);

    vector<ModInt> prefix(N + 3);
    vector<ModInt> suffix(N + 3);

    prefix[0] = 1LL;
    suffix[N + 2] = 1LL;
    for (int i = 1; i <= N + 2; i++)
    {
        prefix[i] = prefix[i - 1] * (A[0] - i);
    }

    for (int i = N + 1; i >= 0; i--)
    {
        suffix[i] = suffix[i + 1] * (A[0] - i - 1);
    }

    for (int i = 1; i <= N + 2; i++)
    {
        long sign = (i - N + 2) % 2 == 0 ? 1 : -1;
        ans += F[i] * (prefix[i - 1] * suffix[i]) * (cache.InverseFactorial(i - 1) * sign * cache.InverseFactorial(N + 2 - i));
    }

    cout << ans.Value << endl;
}
0