結果

問題 No.2345 max(l,r)
ユーザー chineristACchineristAC
提出日時 2023-06-09 23:41:25
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,699 bytes
コンパイル時間 426 ms
コンパイル使用メモリ 82,432 KB
実行使用メモリ 123,924 KB
最終ジャッジ日時 2024-06-10 14:45:52
合計ジャッジ時間 11,577 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 58 ms
66,288 KB
testcase_01 WA -
testcase_02 AC 132 ms
83,192 KB
testcase_03 AC 144 ms
83,332 KB
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 AC 186 ms
84,148 KB
testcase_13 AC 189 ms
84,132 KB
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 AC 107 ms
105,420 KB
testcase_23 AC 114 ms
107,328 KB
testcase_24 AC 69 ms
78,680 KB
testcase_25 AC 77 ms
81,412 KB
testcase_26 AC 79 ms
80,968 KB
testcase_27 AC 82 ms
81,164 KB
testcase_28 AC 79 ms
98,492 KB
testcase_29 WA -
testcase_30 WA -
testcase_31 WA -
testcase_32 WA -
testcase_33 WA -
testcase_34 WA -
testcase_35 AC 62 ms
80,940 KB
testcase_36 AC 71 ms
86,376 KB
testcase_37 WA -
testcase_38 WA -
testcase_39 WA -
testcase_40 WA -
testcase_41 WA -
testcase_42 WA -
testcase_43 WA -
testcase_44 WA -
testcase_45 WA -
testcase_46 WA -
testcase_47 WA -
testcase_48 WA -
testcase_49 WA -
testcase_50 WA -
testcase_51 WA -
testcase_52 WA -
testcase_53 WA -
testcase_54 WA -
testcase_55 WA -
testcase_56 WA -
testcase_57 WA -
testcase_58 WA -
testcase_59 WA -
testcase_60 WA -
testcase_61 WA -
testcase_62 WA -
testcase_63 WA -
testcase_64 WA -
testcase_65 WA -
testcase_66 WA -
testcase_67 WA -
testcase_68 AC 241 ms
84,040 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from itertools import permutations

input = lambda :sys.stdin.readline().rstrip()
mi = lambda :map(int,input().split())
li = lambda :list(mi())

import random

def cmb(n, r, mod):
    if ( r<0 or r>n ):
        return 0
    return (g1[n] * g2[r] % mod) * g2[n-r] % mod


mod = 998244353
N = 2*10**5
g1 = [1]*(N+1)
g2 = [1]*(N+1)
inverse = [1]*(N+1)

for i in range( 2, N + 1 ):
    g1[i]=( ( g1[i-1] * i ) % mod )
    inverse[i]=( ( -inverse[mod % i] * (mod//i) ) % mod )
    g2[i]=( (g2[i-1] * inverse[i]) % mod )
inverse[0]=0

for _ in range(int(input())):
    N,M = mi()
    A = li()
    A.sort()

    if N & 1 or A[0]!=N//2:
        k = N//2
        if k < A[0]:
            print(0)
            continue

        small = [a for a in A if a <= k]
        if small[0]!=small[-1]:
            print(0)
            continue
        big = [a for a in A if k < a]
        dic = {a:0 for a in big}
        for a in big:
            dic[a] += 1
        big = [(a,dic[a]) for a in dic]
        big.sort()

        L,R = 0,0
        c = 0
        res = 1
        for a,t in big[::-1]:
            tmp_l = N-a-L
            tmp_r = N-a-R
            #print(tmp_l,tmp_r,t)
            if tmp_l == t and tmp_r == t:
                c += 1
                res = 2 * res % mod
                L += t
            elif tmp_l == t or tmp_r == t:
                c += 1
                L += t
            elif tmp_l+tmp_r == t and 0 < tmp_l and 0 < tmp_r:
                c += 2
                L += tmp_l
                R += tmp_r
                res = res * cmb(t,tmp_l,mod)
            else:
                res = 0
        
        c += 1
        res = res * cmb(M,c,mod) % mod
        if max(L,R) == small[0]:
            print(res)
        else:
            print(0)
        continue
    else:
        dic = {a:0 for a in A}
        for a in A:
            dic[a] += 1
        big = [(a,dic[a]) for a in dic]
        big.sort()

        L,R = 0,0
        c = 0
        res = 1
        for a,t in big[::-1]:
            tmp_l = N-a-L
            tmp_r = N-a-R
            if tmp_l == t and tmp_r == t:
                c += 1
                res = 2 * res % mod
                L += t
            elif tmp_l == t or tmp_r == t:
                c += 1
                L += t
            elif tmp_l+tmp_r == t and 0 < tmp_l and 0 < tmp_r:
                c += 2
                L += tmp_l
                R += tmp_r
                res = res * cmb(t,tmp_l,mod)
            else:
                res = 0
        
        res = res * cmb(M,c,mod) % mod
        if L == N//2 and R == N//2:
            print(res)
        else:
            print(0)
        continue


            
                





0