結果

問題 No.1693 Invasion
ユーザー lyulu
提出日時 2021-10-01 22:35:14
言語 C++14
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 1,727 bytes
コンパイル時間 686 ms
コンパイル使用メモリ 65,076 KB
実行使用メモリ 58,148 KB
最終ジャッジ日時 2024-07-19 13:03:02
合計ジャッジ時間 4,478 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2 WA * 1
other WA * 6 TLE * 1 -- * 14
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <stdio.h>
#include <algorithm>
#include <vector>
#include <set>

long long inv(long long a, long long mod){
    long long s = 1, tmp = a, t = mod - 2;
    while(t > 0){
        if(t % 2 == 1){
            s *= tmp;
            s %= mod;
        }
        t /= 2;
        tmp *= tmp;
        tmp %= mod;
    }
    return s;
}

int main() {
    int n, m; scanf("%d %d", &n, &m);
    std::vector<long long> v(n);
    for(int i = 0; i < n; ++i) {
        scanf("%ld", &v[i]);
    }
    long long ans = 1, MOD = 998244353;
    std::vector<std::set<int>> s(n+1); s[0].insert(0);
    std::vector<long long> fac(m+1); fac[0] = 1;
    for(int i = 1; i <= m; ++i) {
        fac[i] = fac[i-1] * i % MOD;
    }
    for(int i = 1; i <= n; ++i) {
        for(int tmp: s[i-1]) {
            for(int j = 0; j < n; ++j) {
                if(tmp + v[j] <= m) s[i].insert(tmp + v[j]);
            }
        }
        for(int tmp: s[i]) {
            ans += fac[m-i] * inv(fac[tmp-i], MOD) % MOD * inv(fac[m-tmp], MOD) % MOD;
            ans %= MOD;
            // printf("%ld ", fac[m-i] * inv(fac[tmp-i], MOD) % MOD * inv(fac[m-tmp], MOD) % MOD);
        }
        // printf("%ld\n", ans);
    }
    for(int i = 1; i < n; ++i) {
        for(int j = i+1; j <= n; ++j) {
            std::set<int> mul;
            std::set_intersection(s[i].begin(), s[i].end(),
                                  s[j].begin(), s[j].end(),
                                  std::inserter(mul, mul.end()));
            for(int tmp: mul) {
                ans += MOD;
                ans -= fac[m-j] * inv(fac[tmp-j], MOD) % MOD * inv(fac[m-tmp], MOD) % MOD;
                ans %= MOD;
            }
        }
    }
    printf("%ld\n", ans);
    return 0;
}
0