結果
| 問題 | No.2345 max(l,r) |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2023-06-09 23:49:31 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 229 ms / 2,000 ms |
| コード長 | 3,659 bytes |
| コンパイル時間 | 537 ms |
| コンパイル使用メモリ | 82,252 KB |
| 実行使用メモリ | 124,412 KB |
| 最終ジャッジ日時 | 2025-01-02 06:04:58 |
| 合計ジャッジ時間 | 11,620 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 68 |
ソースコード
import sys
from itertools import permutations
def input():
    """Read one line from stdin with all trailing whitespace removed.

    Deliberately shadows the builtin ``input`` for the rest of the module.
    Returns "" at end of input.
    """
    return sys.stdin.readline().rstrip()


def mi():
    """Read one line and return a lazy ``map`` of its whitespace-separated ints."""
    return map(int, input().split())


def li():
    """Read one line and return its whitespace-separated ints as a list."""
    return list(mi())
import random
def cmb(n, r, mod):
    """Return the binomial coefficient C(n, r) modulo ``mod``.

    Returns 0 when ``r`` lies outside ``[0, n]``.  Otherwise it reads the
    module-level tables ``g1`` (factorials) and ``g2`` (inverse factorials).
    Those tables are built modulo 998244353 and have length 2*10**5 + 1, so
    ``mod`` should be that prime and ``n`` must be at most 2*10**5.
    """
    if not 0 <= r <= n:
        return 0
    top = g1[n]
    return top * g2[r] % mod * g2[n - r] % mod
mod = 998244353
N = 2 * 10**5

# Lookup tables for O(1) binomial coefficients modulo the prime ``mod``:
#   g1[n]      = n!           (mod mod)
#   g2[n]      = (n!)^-1      (mod mod)
#   inverse[n] = n^-1         (mod mod); inverse[0] is set to 0 as a placeholder
g1 = [1] * (N + 1)
g2 = [1] * (N + 1)
inverse = [1] * (N + 1)
for i in range(2, N + 1):
    # Linear-time inverse recurrence: i^-1 = -(mod // i) * (mod % i)^-1 (mod mod).
    inverse[i] = -(mod // i) * inverse[mod % i] % mod
    g1[i] = i * g1[i - 1] % mod
    g2[i] = inverse[i] * g2[i - 1] % mod
inverse[0] = 0
def brute(N, M, A):
    """Count X in [1, M]^N such that, for every j,
    max(#{k: X_k < X_j}, #{k: X_k > X_j}) == A[j].

    Exhaustive enumeration of all M**N sequences; only for tiny inputs
    (used as a reference to cross-check ``solve``).
    """
    values = [-1] * N

    def valid():
        # Stop at the first position whose requirement is violated.
        for j, v in enumerate(values):
            smaller = sum(1 for w in values if w < v)
            larger = sum(1 for w in values if w > v)
            if max(smaller, larger) != A[j]:
                return False
        return True

    def count_from(pos):
        if pos == N:
            return 1 if valid() else 0
        total = 0
        for v in range(1, M + 1):
            values[pos] = v
            total += count_from(pos + 1)
        values[pos] = -1
        return total

    return count_from(0)
def _assign_groups(N, groups):
    """Place groups of equal requirement from the outermost values inward.

    ``groups`` holds ``(a, t)`` pairs sorted by ``a`` ascending: ``t`` elements
    all require max(#smaller, #larger) == ``a``.  Groups are consumed from the
    largest ``a`` down; ``low``/``high`` count elements already placed at the
    bottom/top of the value order.  A group placed next at the bottom has
    ``low`` smaller and ``N - low - t`` larger elements, so it fits when
    ``t == N - a - low`` (symmetrically ``t == N - a - high`` for the top).
    NOTE(review): the code assumes the not-yet-filled side is the larger of
    the two counts for every group passed in; this is not re-checked here.

    Returns ``(ways, used, low, high)``.  ``ways`` is the number of placements
    modulo ``mod`` and ``used`` is the number of distinct values they occupy.
    Returns ``None`` as soon as some group cannot be placed (answer is 0).
    """
    low = high = 0
    used = 0
    ways = 1
    for a, t in reversed(groups):
        need_low = N - a - low
        need_high = N - a - high
        if need_low == t and need_high == t:
            # need_low == need_high means low == high: the two placements are
            # mirror images, so count one of them twice.
            used += 1
            ways = ways * 2 % mod
            low += t
        elif need_low == t:
            used += 1
            low += t
        elif need_high == t:
            used += 1
            high += t
        elif need_low + need_high == t and need_low > 0 and need_high > 0:
            # Requirement shared by a bottom group and a top group (two
            # distinct values); choose which elements go to the bottom.
            used += 2
            low += need_low
            high += need_high
            ways = ways * cmb(t, need_low, mod) % mod
        else:
            return None
    return ways, used, low, high


def _choose_values(M, c):
    """Return C(M, c) modulo ``mod``, the number of ways to pick ``c`` distinct values from 1..M.

    Uses the precomputed tables when ``M`` fits in them.  Otherwise it uses
    the falling factorial M*(M-1)*...*(M-c+1) times (c!)^-1.  That fallback
    is valid because c! is invertible modulo the prime for c < mod.
    """
    if c < 0 or c > M:
        return 0
    if M < len(g1):
        return cmb(M, c, mod)
    numerator = 1
    for i in range(c):
        numerator = numerator * ((M - i) % mod) % mod
    return numerator * g2[c] % mod


def solve(N, M, A):
    """Count X in [1, M]^N with max(#{j: X_j < X_i}, #{j: X_j > X_i}) == A[i] for all i.

    The result is taken modulo ``mod``; it should agree with ``brute`` on small
    inputs.  ``A`` is not modified, and only its multiset of values matters.

    Groups of equal values are peeled off from the outside in by decreasing
    requirement (see ``_assign_groups``).  What remains is either one central
    group whose requirement is at most N//2, or, for even N with min(A) == N//2,
    nothing.  The layout count is finally multiplied by C(M, #groups) to pick
    the concrete values.
    """
    from collections import Counter

    if N == 0:
        return 1  # only the empty sequence
    A = sorted(A)
    half = N // 2

    if N % 2 == 0 and A[0] == half:
        # Even N whose smallest requirement is exactly N/2: every group
        # (including the two middle ones, which both need N/2) goes through
        # the peeling step, and the halves must end up exactly balanced.
        placed = _assign_groups(N, sorted(Counter(A).items()))
        if placed is None:
            return 0
        ways, used, low, high = placed
        if low != half or high != half:
            return 0
        return ways * _choose_values(M, used) % mod

    # Otherwise there is exactly one central group: every requirement
    # <= N//2 must belong to it, so all of them must be equal.
    if A[0] > half:
        return 0
    small = [a for a in A if a <= half]
    if small[0] != small[-1]:
        return 0
    placed = _assign_groups(N, sorted(Counter(a for a in A if a > half).items()))
    if placed is None:
        return 0
    ways, used, low, high = placed
    # The central group sees `low` smaller and `high` larger elements.
    if max(low, high) != small[0]:
        return 0
    return ways * _choose_values(M, used + 1) % mod
def _stress_test(trials=1000, seed=None):
    """Cross-check ``solve`` against ``brute`` on small random inputs.

    Replaces a ``while False:`` debugging loop that could never execute.
    Call it manually during development; the judged program never invokes it.
    It uses local names only, so the module-level ``N``/``M``/``A`` are not
    rebound.

    Returns ``None`` if every trial agrees.  Otherwise returns the first
    mismatch as ``(N, M, A, solve_result, brute_result)``.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(4, 5)
        m = rng.randint(3, 10)
        a = [rng.randint(0, n - 1) for _ in range(n)]
        # Pass copies so neither function can affect the other's input.
        got = solve(n, m, list(a))
        expected = brute(n, m, list(a))
        if got != expected:
            return n, m, a, got, expected
    return None
def main():
    """Read all test cases from stdin and write one answer per line.

    Input: a test-case count T, then for each case a line "N M" and a line
    with the N values of A.  Local names keep the module-level ``N`` (the
    factorial table size) from being rebound.  All answers are written in a
    single call.
    """
    answers = []
    for _ in range(int(input())):
        n, m = mi()
        a = li()
        answers.append(solve(n, m, a))
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()