結果

問題 No.3540 Arise
コンテスト
ユーザー nauclhlt
提出日時 2026-03-07 00:57:24
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
WA  
(最新)
AC  
(最初)
実行時間 -
コード長 2,702 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 398 ms
コンパイル使用メモリ 85,376 KB
実行使用メモリ 71,296 KB
最終ジャッジ日時 2026-05-08 20:53:50
合計ジャッジ時間 22,401 ms
ジャッジサーバーID
(参考情報)
judge3_1 / judge2_0
このコードへのチャレンジ
(要ログイン)
サブタスク 配点 結果
サブタスク1 30 % AC * 19
サブタスク2 70 % AC * 21 WA * 1
合計 3.5 * 30% = 105 点
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys

input = sys.stdin.readline

def solve():
    MOD = 998244353
    try:
        line1 = input().split()
        if not line1: return
        N = int(line1[0])
        A = list(map(int, input().split()))
    except ValueError:
        return

    A.sort()
    
    inv2 = pow(2, MOD - 2, MOD)
    ans = 0
    for x in A:
        ans = (ans + (x + 1) * inv2) % MOD

    if A[0] <= N + 2:
        naive_sum = 0
        for m in range(1, A[0] + 1):
            inv_A = [pow(x, MOD - 2, MOD) for x in A]

            right = 1
            for i in range(1, N):
                right = right * (A[i] - m) % MOD * inv_A[i] % MOD
            
            left = 1
            for k in range(N):
                term = (A[k] - m) * left % MOD * right % MOD * inv_A[k] % MOD
                naive_sum = (naive_sum + term) % MOD
                
                left = left * (A[k] - m + 1) % MOD * inv_A[k] % MOD
                if k < N - 1:
                    right = right * A[k+1] % MOD * pow(A[k+1] - m, MOD - 2, MOD) % MOD
        print((ans + naive_sum) % MOD)
        return
    
    limit = N + 3
    fact_inv = [1] * limit
    fact = [1] * limit
    for i in range(1, limit):
        fact[i] = (fact[i - 1] * i) % MOD
    fact_inv[limit - 1] = pow(fact[limit - 1], MOD - 2, MOD)
    for i in range(limit - 2, -1, -1):
        fact_inv[i] = (fact_inv[i + 1] * (i + 1)) % MOD

    G_sums = [0] * limit
    inv_A = [pow(x, MOD - 2, MOD) for x in A]

    for x in range(1, limit):
        right = 1
        for i in range(1, N):
            right = right * (A[i] - x) % MOD * inv_A[i] % MOD
        
        left = 1
        current_f_sum = 0
        for k in range(N):
            term = (A[k] - x) * left % MOD * right % MOD * inv_A[k] % MOD
            current_f_sum = (current_f_sum + term) % MOD
            
            left = left * (A[k] - x + 1) % MOD * inv_A[k] % MOD
            if k < N - 1:
                right = right * A[k+1] % MOD * pow(A[k+1] - x, MOD - 2, MOD) % MOD
        
        G_sums[x] = (G_sums[x-1] + current_f_sum) % MOD

    target = A[0]
    
    pre = [1] * limit
    suf = [1] * limit
    for i in range(1, limit):
        pre[i] = pre[i - 1] * (target - i) % MOD
    for i in range(limit - 1, 0, -1):
        suf[i - 1] = suf[i] * (target - i) % MOD

    interp_ans = 0
    for i in range(1, limit):
        num = pre[i - 1] * suf[i] % MOD
        
        den_inv = fact_inv[i - 1] * fact_inv[limit - 1 - i] % MOD
        if (limit - 1 - i) % 2 == 1:
            den_inv = (MOD - den_inv) % MOD
            
        interp_ans = (interp_ans + G_sums[i] * num % MOD * den_inv) % MOD

    print((ans + interp_ans) % MOD)

if __name__ == '__main__':
    solve()
0