結果

問題 No.2345 max(l,r)
ユーザー chineristAC
提出日時 2023-06-09 23:49:31
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 229 ms / 2,000 ms
コード長 3,659 bytes
コンパイル時間 537 ms
コンパイル使用メモリ 82,252 KB
実行使用メモリ 124,412 KB
最終ジャッジ日時 2025-01-02 06:04:58
合計ジャッジ時間 11,620 ms
ジャッジサーバーID (参考情報) judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 68
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from itertools import permutations

def input():
    """Return the next stdin line with trailing whitespace (incl. newline) removed."""
    return sys.stdin.readline().rstrip()


def mi():
    """Return a lazy ``map`` of ints parsed from the next whitespace-separated line."""
    return map(int, input().split())


def li():
    """Return the ints on the next input line as a list."""
    return list(mi())

import random

def cmb(n, r, mod):
    """Return the binomial coefficient C(n, r) reduced modulo ``mod``.

    Returns 0 when ``r`` lies outside ``0..n``.  Reads the module-level
    factorial tables ``g1`` / ``g2`` (built below for the prime ``mod``),
    so ``n`` must not exceed ``N``.
    """
    if not 0 <= r <= n:
        return 0
    numerator = g1[n]
    return numerator * g2[r] % mod * g2[n - r] % mod


def _build_factorial_tables(limit, p):
    """Return ``(fact, inv_fact, inv)`` covering 0..limit modulo the prime ``p``.

    ``fact[i] = i!``, ``inv_fact[i] = (i!)^-1`` and ``inv[i] = i^-1`` (mod p).
    ``inv[0]`` is stored as 0 because 0 has no inverse.
    """
    fact = [1] * (limit + 1)
    inv_fact = [1] * (limit + 1)
    inv = [1] * (limit + 1)
    for x in range(2, limit + 1):
        fact[x] = fact[x - 1] * x % p
        # x^-1 = -(p // x) * (p % x)^-1 (mod p); (p - p // x) is -(p // x) mod p,
        # and p % x < x, so its inverse is already in the table.
        inv[x] = (p - p // x) * inv[p % x] % p
        inv_fact[x] = inv_fact[x - 1] * inv[x] % p
    inv[0] = 0
    return fact, inv_fact, inv


mod = 998244353
N = 2 * 10**5
# g1: factorials, g2: inverse factorials, inverse: modular inverses of 0..N.
g1, g2, inverse = _build_factorial_tables(N, mod)

def brute(N, M, A):
    """Count, by exhaustive search, value lists matching ``A``.

    Enumerates every list ``vals`` of length ``N`` with entries in ``1..M``.
    A list is accepted when, for every index ``j``,
    ``max(#{k: vals[k] < vals[j]}, #{k: vals[k] > vals[j]}) == A[j]``.
    Cost is O(M**N * N**2), so this is only for tiny cross-checks.
    """
    vals = [-1] * N

    def matches():
        # First index whose (below, above) maximum disagrees with A rejects the list.
        for j, v in enumerate(vals):
            below = sum(1 for w in vals if w < v)
            above = sum(1 for w in vals if w > v)
            if max(below, above) != A[j]:
                return False
        return True

    def count_from(pos):
        if pos == N:
            return 1 if matches() else 0
        total = 0
        for value in range(1, M + 1):
            vals[pos] = value
            total += count_from(pos + 1)
        vals[pos] = -1
        return total

    return count_from(0)


            

def _place_groups(N, groups):
    """Greedily place equal-``A`` groups from the outside in.

    ``groups`` is a list of ``(a, t)`` pairs (an ``A`` value and how many
    people have it), ordered by ``a`` descending.  ``L`` / ``R`` count the
    people already placed on the low / high side.  A group of size ``t``
    placed next on the low side has ``N - L - t`` people above it, so it
    fits there whole exactly when ``t == N - a - L``; symmetrically,
    ``t == N - a - R`` for the high side.  A group that fits neither side
    whole may be split into a low part of size ``N - a - L`` and a high
    part of size ``N - a - R`` when those sizes are positive and sum to
    ``t``.

    Returns ``(ways, c, L, R)``, where ``ways`` is the multiplicity modulo
    ``mod`` and ``c`` is the number of value classes created.  Returns
    ``None`` when some group cannot be placed; the caller then answers 0.
    """
    L = R = 0
    c = 0
    ways = 1
    for a, t in groups:
        need_l = N - a - L
        need_r = N - a - R
        if need_l == t and need_r == t:
            # Fits either side equally: the two mirror choices are distinct.
            c += 1
            ways = 2 * ways % mod
            L += t
        elif need_l == t:
            c += 1
            L += t
        elif need_r == t:
            c += 1
            R += t
        elif need_l + need_r == t and need_l > 0 and need_r > 0:
            # Split into two classes; choose which people go to the low side.
            c += 2
            L += need_l
            R += need_r
            ways = ways * cmb(t, need_l, mod) % mod  # was missing "% mod"
        else:
            return None
    return ways, c, L, R


def solve(N, M, A):
    """Count value assignments consistent with ``A``, modulo ``mod``.

    Intended to agree with ``brute``: count the ways to give each of the
    ``N`` people a value in ``1..M`` so that, for every person, the larger
    of (#people with a smaller value, #people with a larger value) equals
    that person's entry of ``A``.  Only the multiset of ``A`` matters, and
    the caller's list is not modified.
    """
    A = sorted(A)
    counts = {}
    for a in A:
        counts[a] = counts.get(a, 0) + 1
    k = N // 2

    if N & 1 or A[0] != k:
        # There is a single middle class: every entry <= k must be equal,
        # and that class sees L people below it and R above it.
        if k < A[0]:
            return 0
        small = [a for a in A if a <= k]
        if small[0] != small[-1]:
            return 0
        groups = sorted(((a, t) for a, t in counts.items() if a > k), reverse=True)
        placed = _place_groups(N, groups)
        if placed is None:
            return 0
        ways, c, L, R = placed
        if max(L, R) != small[0]:
            return 0
        # "+ 1" accounts for the middle class; C(M, classes) picks their values.
        return ways * cmb(M, c + 1, mod) % mod

    # N even and min(A) == N/2: every class is placed from the outside in,
    # and the two sides must end with exactly N/2 people each.
    placed = _place_groups(N, sorted(counts.items(), reverse=True))
    if placed is None:
        return 0
    ways, c, L, R = placed
    if L != k or R != k:
        return 0
    return ways * cmb(M, c, mod) % mod
        


            
def stress_test(trials=100):
    """Randomly cross-check ``solve`` against ``brute`` on tiny inputs.

    This replaces a former ``while False:`` loop that could never execute.
    It is not called automatically.  Each trial prints "OK N" on agreement.
    On the first mismatch it prints the case and both answers, then returns
    ``(N, M, A)``.  Returns ``None`` if all ``trials`` agree; passing
    ``trials=None`` loops until a mismatch is found.
    """
    done = 0
    while trials is None or done < trials:
        done += 1
        N, M = random.randint(4, 5), random.randint(3, 10)
        A = [random.randint(0, N - 1) for _ in range(N)]
        expected = brute(N, M, A)
        got = solve(N, M, list(A))  # copy: never let solve reorder the case
        if got != expected:
            print(N, M, A)
            print(got, expected)
            return N, M, A
        print("OK", N)
    return None





def main():
    """Read T test cases ("N M" line, then the A line) and print one answer each.

    Answers are collected and written in a single call rather than with one
    ``print`` per case.  ``solve`` sorts internally, so no pre-sort is done.
    """
    answers = []
    for _ in range(int(input())):
        n, m = mi()
        a = li()
        answers.append(solve(n, m, a))
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()