結果

問題 No.1693 Invasion
ユーザー lyululyulu
提出日時 2021-10-01 22:35:14
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 1,727 bytes
コンパイル時間 771 ms
コンパイル使用メモリ 64,296 KB
実行使用メモリ 53,292 KB
最終ジャッジ日時 2023-09-26 19:07:34
合計ジャッジ時間 4,486 ms
ジャッジサーバーID
(参考情報)
judge15 / judge13
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
8,752 KB
testcase_01 AC 2 ms
4,376 KB
testcase_02 WA -
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 TLE -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <stdio.h>
#include <algorithm>
#include <vector>
#include <set>

long long inv(long long a, long long mod){
    long long s = 1, tmp = a, t = mod - 2;
    while(t > 0){
        if(t % 2 == 1){
            s *= tmp;
            s %= mod;
        }
        t /= 2;
        tmp *= tmp;
        tmp %= mod;
    }
    return s;
}

int main() {
    int n, m; scanf("%d %d", &n, &m);
    std::vector<long long> v(n);
    for(int i = 0; i < n; ++i) {
        scanf("%ld", &v[i]);
    }
    long long ans = 1, MOD = 998244353;
    std::vector<std::set<int>> s(n+1); s[0].insert(0);
    std::vector<long long> fac(m+1); fac[0] = 1;
    for(int i = 1; i <= m; ++i) {
        fac[i] = fac[i-1] * i % MOD;
    }
    for(int i = 1; i <= n; ++i) {
        for(int tmp: s[i-1]) {
            for(int j = 0; j < n; ++j) {
                if(tmp + v[j] <= m) s[i].insert(tmp + v[j]);
            }
        }
        for(int tmp: s[i]) {
            ans += fac[m-i] * inv(fac[tmp-i], MOD) % MOD * inv(fac[m-tmp], MOD) % MOD;
            ans %= MOD;
            // printf("%ld ", fac[m-i] * inv(fac[tmp-i], MOD) % MOD * inv(fac[m-tmp], MOD) % MOD);
        }
        // printf("%ld\n", ans);
    }
    for(int i = 1; i < n; ++i) {
        for(int j = i+1; j <= n; ++j) {
            std::set<int> mul;
            std::set_intersection(s[i].begin(), s[i].end(),
                                  s[j].begin(), s[j].end(),
                                  std::inserter(mul, mul.end()));
            for(int tmp: mul) {
                ans += MOD;
                ans -= fac[m-j] * inv(fac[tmp-j], MOD) % MOD * inv(fac[m-tmp], MOD) % MOD;
                ans %= MOD;
            }
        }
    }
    printf("%ld\n", ans);
    return 0;
}
0