結果

問題 No.1856 Mex Sum 2
ユーザー lam6er
提出日時 2025-03-26 15:59:17
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,390 bytes
コンパイル時間 197 ms
コンパイル使用メモリ 82,168 KB
実行使用メモリ 66,836 KB
最終ジャッジ日時 2025-03-26 16:00:20
合計ジャッジ時間 7,678 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 26 TLE * 1 -- * 37
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

n, M = map(int, input().split())

max_k = min(M + 1, n)

# Precompute combinations (n choose m)
comb = [[0] * (n + 1) for _ in range(n + 1)]
comb[0][0] = 1
for i in range(1, n + 1):
    comb[i][0] = 1
    for j in range(1, i + 1):
        comb[i][j] = (comb[i-1][j-1] + comb[i-1][j]) % MOD

ans = 0

for k in range(1, max_k + 1):
    if k > M:
        valid = 0
    else:
        valid = 1
    if k <= M:
        allowed = M - k  # number of elements > k-1 and !=k (which are k+1..M)
        elements = k + allowed  # 0..k-1 and k+1..M
    else:
        elements = M + 1  # k = M+1, so allowed elements are 0..M

    # Compute sum_terms for each m from 1 to n
    total = 0
    for m in range(1, n + 1):
        if k <= M:
            if elements == 0:
                continue
            # sum_{t=0}^k (-1)^t * C(k, t) * ( (k - t) + allowed )^m
            sum_terms = 0
            for t in range(0, k+1):
                sign = (-1) ** t
                c = comb[k][t]
                base = (k - t) + allowed
                term = pow(base, m, MOD)
                term = term * c % MOD
                if sign == -1:
                    term = (-term) % MOD
                sum_terms = (sum_terms + term) % MOD
        else:
            # k == M+1, elements = M+1 (0..M)
            # sum_{t=0}^{k} -> sum_{t=0}^{M+1} (-1)^t * C(M+1, t) * (M+1 - t)^m
            # but k = M+1 could be larger than n, but since max_k = min(M+1, n), k <= n
            sum_terms = 0
            for t in range(0, k + 1):
                if t > M + 1:
                    continue
                sign = (-1) ** t
                c = comb[k][t] if t <= k else 0
                base = (M + 1 - t) % MOD
                term = pow(base, m, MOD)
                term = term * c % MOD
                if sign == -1:
                    term = (-term) % MOD
                sum_terms = (sum_terms + term) % MOD

        # Compute the term for this m
        term = comb[n][m] * sum_terms % MOD
        if k <= M:
            # other elements can be anything (including k)
            # so (M+1)^(n - m)
            term = term * pow(M + 1, n - m, MOD) % MOD
        else:
            # k == M+1, other elements can be anything (but k is not present)
            term = term * pow(M + 1, n - m, MOD) % MOD

        total = (total + term) % MOD

    ans = (ans + k * total) % MOD

print(ans)
0