結果

問題 No.2345 max(l,r)
ユーザー chineristAC
提出日時 2023-06-09 23:41:25
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,699 bytes
コンパイル時間 282 ms
コンパイル使用メモリ 82,228 KB
実行使用メモリ 124,464 KB
最終ジャッジ日時 2025-01-02 05:51:19
合計ジャッジ時間 11,874 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 14 WA * 54
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from itertools import permutations

input = lambda :sys.stdin.readline().rstrip()
mi = lambda :map(int,input().split())
li = lambda :list(mi())

import random

def cmb(n, r, mod):
    if ( r<0 or r>n ):
        return 0
    return (g1[n] * g2[r] % mod) * g2[n-r] % mod


mod = 998244353
N = 2*10**5
g1 = [1]*(N+1)
g2 = [1]*(N+1)
inverse = [1]*(N+1)

for i in range( 2, N + 1 ):
    g1[i]=( ( g1[i-1] * i ) % mod )
    inverse[i]=( ( -inverse[mod % i] * (mod//i) ) % mod )
    g2[i]=( (g2[i-1] * inverse[i]) % mod )
inverse[0]=0

for _ in range(int(input())):
    N,M = mi()
    A = li()
    A.sort()

    if N & 1 or A[0]!=N//2:
        k = N//2
        if k < A[0]:
            print(0)
            continue

        small = [a for a in A if a <= k]
        if small[0]!=small[-1]:
            print(0)
            continue
        big = [a for a in A if k < a]
        dic = {a:0 for a in big}
        for a in big:
            dic[a] += 1
        big = [(a,dic[a]) for a in dic]
        big.sort()

        L,R = 0,0
        c = 0
        res = 1
        for a,t in big[::-1]:
            tmp_l = N-a-L
            tmp_r = N-a-R
            #print(tmp_l,tmp_r,t)
            if tmp_l == t and tmp_r == t:
                c += 1
                res = 2 * res % mod
                L += t
            elif tmp_l == t or tmp_r == t:
                c += 1
                L += t
            elif tmp_l+tmp_r == t and 0 < tmp_l and 0 < tmp_r:
                c += 2
                L += tmp_l
                R += tmp_r
                res = res * cmb(t,tmp_l,mod)
            else:
                res = 0
        
        c += 1
        res = res * cmb(M,c,mod) % mod
        if max(L,R) == small[0]:
            print(res)
        else:
            print(0)
        continue
    else:
        dic = {a:0 for a in A}
        for a in A:
            dic[a] += 1
        big = [(a,dic[a]) for a in dic]
        big.sort()

        L,R = 0,0
        c = 0
        res = 1
        for a,t in big[::-1]:
            tmp_l = N-a-L
            tmp_r = N-a-R
            if tmp_l == t and tmp_r == t:
                c += 1
                res = 2 * res % mod
                L += t
            elif tmp_l == t or tmp_r == t:
                c += 1
                L += t
            elif tmp_l+tmp_r == t and 0 < tmp_l and 0 < tmp_r:
                c += 2
                L += tmp_l
                R += tmp_r
                res = res * cmb(t,tmp_l,mod)
            else:
                res = 0
        
        res = res * cmb(M,c,mod) % mod
        if L == N//2 and R == N//2:
            print(res)
        else:
            print(0)
        continue


            
                





0