結果
| 問題 | No.3540 Arise |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2026-05-08 21:39:29 |
| 言語 | PyPy3 (7.3.17) |
| 結果 |
AC
|
| 実行時間 | 174 ms / 2,500 ms |
| コード長 | 4,850 bytes |
| 記録 | |
| コンパイル時間 | 197 ms |
| コンパイル使用メモリ | 85,376 KB |
| 実行使用メモリ | 82,176 KB |
| 最終ジャッジ日時 | 2026-05-08 21:39:38 |
| 合計ジャッジ時間 | 6,233 ms |
|
ジャッジサーバーID (参考情報) |
judge1_1 / judge3_0 |
(要ログイン)
| サブタスク | 配点 | 結果 |
|---|---|---|
| サブタスク1 | 30 % | AC * 19 |
| サブタスク2 | 70 % | AC * 22 |
| 合計 | 3.5 * 100% = 350 点 |
ソースコード
# input
import sys
input = sys.stdin.readline
II = lambda : int(input())
MI = lambda : map(int, input().split())
LI = lambda : [int(a) for a in input().split()]
SI = lambda : input().rstrip()
LLI = lambda n : [[int(a) for a in input().split()] for _ in range(n)]
LSI = lambda n : [input().rstrip() for _ in range(n)]
MI_1 = lambda : map(lambda x:int(x)-1, input().split())
LI_1 = lambda : [int(a)-1 for a in input().split()]
mod = 998244353
inf = 1001001001001001001
ordalp = lambda s : ord(s)-65 if s.isupper() else ord(s)-97
ordallalp = lambda s : ord(s)-39 if s.isupper() else ord(s)-97
yes = lambda : print("Yes")
no = lambda : print("No")
yn = lambda flag : print("Yes" if flag else "No")
prinf = lambda ans : print(ans if ans < 1000001001001001001 else -1)
alplow = "abcdefghijklmnopqrstuvwxyz"
alpup = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
alpall = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
URDL = {'U':(-1,0), 'R':(0,1), 'D':(1,0), 'L':(0,-1)}
DIR_4 = [[-1,0],[0,1],[1,0],[0,-1]]
DIR_8 = [[-1,0],[-1,1],[0,1],[1,1],[1,0],[1,-1],[0,-1],[-1,-1]]
DIR_BISHOP = [[-1,1],[1,1],[1,-1],[-1,-1]]
prime60 = [2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59]
sys.set_int_max_str_digits(0)
# sys.setrecursionlimit(10**6)
# import pypyjit
# pypyjit.set_param('max_unroll_recursion=-1')
from collections import defaultdict,deque
from heapq import heappop,heappush
from bisect import bisect_left,bisect_right
DD = defaultdict
BSL = bisect_left
BSR = bisect_right
class Comb:
__slots__ = ["fac", "finv", "mod"]
def __init__(self, lim:int, mod:int = mod):
"""
mod : prime
"""
self.fac = [1]*(lim+1)
self.finv = [1]*(lim+1)
self.mod = mod
for i in range(2,lim+1):
self.fac[i] = self.fac[i-1]*i%self.mod
self.finv[lim] = pow(self.fac[lim],-1,mod)
for i in range(lim,2,-1):
self.finv[i-1] = self.finv[i]*i%self.mod
def C(self, a, b):
if b < 0 or a < b: return 0
if a < 0: return 0
return self.fac[a]*self.finv[b]%self.mod*self.finv[a-b]%self.mod
def __call__(self, a, b):
if b < 0 or a < b: return 0
if a < 0: return 0
return self.fac[a]*self.finv[b]%self.mod*self.finv[a-b]%self.mod
def P(self, a, b):
if b < 0 or a < b: return 0
if a < 0: return 0
return self.fac[a]*self.finv[a-b]%self.mod
def M(self, *k):
n = sum(k)
if n < 0: return 0
res = self.fac[n]
for ki in k:
if ki < 0: return 0
res = res * self.finv[ki] % self.mod
return res
def H(self, a, b): return self.C(a+b-1,b)
def F(self, a): return self.fac[a]
def Fi(self, a): return self.finv[a]
n = II()
a = LI()
"""
問題文読めない
もっとも小さい出目をいい感じのに変更すると言われている?
変更するものを j として その値を v とすると式が立つ
これを適当に解けばよい?
x のとき
(a_i - x) を求めておく?
ある v につうての期待値を求めておけばそれの補完で可能ですか ok
"""
def lagrange(f:list, t:int, mod:int = mod):
"""
k次多項式の
f(0)~f(k)を与えられた時
f(t)を求める
"""
k = len(f) - 1
if t <= k:
return f[t]
top1 = [1]*(k+1)
top2 = [1]*(k+1)
for i in range(k):
top1[i+1] = top1[i]*(t-i)%mod
for i in range(k,0,-1):
top2[i-1] = top2[i]*(i-t)%mod
finv = [0]*(k+1)
inv = 1
for i in range(2,k+1):
inv *= i
inv %= mod
finv[k] = pow(inv,-1,mod)
for i in range(k,0,-1):
finv[i-1] = finv[i]*i%mod
res = 0
for i in range(k+1):
tmp = f[i]*top1[i]%mod*top2[i]%mod*finv[i]%mod*finv[k-i]%mod
res += tmp
res %= mod
return res
comb = Comb(4000)
a.sort()
mi = a[0] # これ以下ではある
def calc(x):
# あるスコア x をとる確率
r = [1] * (n + 1)
for i in reversed(range(n)):
r[i] = r[i + 1] * (a[i] - x) % mod
l = [1] * (n + 1)
for i in range(n):
l[i+1] = l[i] * (a[i] - x + 1) % mod
s = 0
for i in range(n):
s = (s + l[i] * r[i]) % mod
return s
# そもそもの期待値
inv2 = mod + 1 >> 1
tmp = 0
for i in range(n):
tmp += (a[i] + 1) * inv2 % mod
tmp %= mod
# 全体の分母
inv = 1
for i in range(n):
inv = inv * a[i] % mod
inv = pow(inv, -1, mod)
if a[0] <= n + 1:
t = 0
for i in range(1, a[0] + 1):
t += calc(i)
t %= mod
ans = (tmp + t * inv) % mod
print(ans)
exit()
# そうではないときは多項式補完
# n 次式ではある?
ss = [0] * (n + 2)
for i in range(n + 1):
ss[i+1] = (ss[i] + calc(i + 1)) % mod
ans = lagrange(ss, a[0])
print((tmp + ans * inv) % mod)