結果

問題 No.2345 max(l,r)
ユーザー chineristACchineristAC
提出日時 2023-06-09 23:49:31
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 333 ms / 2,000 ms
コード長 3,659 bytes
コンパイル時間 1,552 ms
コンパイル使用メモリ 86,588 KB
実行使用メモリ 136,340 KB
最終ジャッジ日時 2023-08-30 14:57:28
合計ジャッジ時間 17,585 ms
ジャッジサーバーID
(参考情報)
judge14 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 109 ms
81,516 KB
testcase_01 AC 276 ms
84,876 KB
testcase_02 AC 196 ms
84,580 KB
testcase_03 AC 195 ms
84,092 KB
testcase_04 AC 196 ms
84,280 KB
testcase_05 AC 193 ms
84,372 KB
testcase_06 AC 194 ms
84,528 KB
testcase_07 AC 194 ms
84,620 KB
testcase_08 AC 196 ms
84,464 KB
testcase_09 AC 195 ms
84,136 KB
testcase_10 AC 195 ms
84,528 KB
testcase_11 AC 192 ms
84,520 KB
testcase_12 AC 236 ms
84,892 KB
testcase_13 AC 239 ms
84,948 KB
testcase_14 AC 240 ms
85,044 KB
testcase_15 AC 236 ms
84,724 KB
testcase_16 AC 235 ms
84,588 KB
testcase_17 AC 235 ms
85,160 KB
testcase_18 AC 237 ms
84,620 KB
testcase_19 AC 242 ms
84,696 KB
testcase_20 AC 236 ms
84,544 KB
testcase_21 AC 232 ms
84,568 KB
testcase_22 AC 167 ms
108,468 KB
testcase_23 AC 168 ms
109,088 KB
testcase_24 AC 127 ms
88,156 KB
testcase_25 AC 135 ms
82,744 KB
testcase_26 AC 136 ms
82,916 KB
testcase_27 AC 133 ms
82,852 KB
testcase_28 AC 134 ms
104,228 KB
testcase_29 AC 160 ms
108,636 KB
testcase_30 AC 129 ms
89,484 KB
testcase_31 AC 179 ms
122,160 KB
testcase_32 AC 138 ms
97,060 KB
testcase_33 AC 151 ms
102,756 KB
testcase_34 AC 108 ms
81,368 KB
testcase_35 AC 121 ms
90,216 KB
testcase_36 AC 133 ms
91,296 KB
testcase_37 AC 140 ms
97,160 KB
testcase_38 AC 193 ms
129,084 KB
testcase_39 AC 210 ms
130,272 KB
testcase_40 AC 202 ms
129,128 KB
testcase_41 AC 195 ms
129,460 KB
testcase_42 AC 198 ms
129,316 KB
testcase_43 AC 194 ms
129,312 KB
testcase_44 AC 210 ms
129,520 KB
testcase_45 AC 231 ms
136,340 KB
testcase_46 AC 223 ms
133,744 KB
testcase_47 AC 195 ms
129,076 KB
testcase_48 AC 162 ms
83,316 KB
testcase_49 AC 160 ms
83,084 KB
testcase_50 AC 165 ms
83,548 KB
testcase_51 AC 165 ms
83,344 KB
testcase_52 AC 165 ms
83,340 KB
testcase_53 AC 166 ms
83,356 KB
testcase_54 AC 168 ms
83,424 KB
testcase_55 AC 161 ms
83,312 KB
testcase_56 AC 167 ms
83,388 KB
testcase_57 AC 169 ms
84,028 KB
testcase_58 AC 199 ms
83,408 KB
testcase_59 AC 197 ms
83,764 KB
testcase_60 AC 196 ms
83,580 KB
testcase_61 AC 198 ms
83,984 KB
testcase_62 AC 194 ms
83,420 KB
testcase_63 AC 196 ms
83,696 KB
testcase_64 AC 199 ms
84,072 KB
testcase_65 AC 196 ms
83,760 KB
testcase_66 AC 197 ms
83,876 KB
testcase_67 AC 192 ms
83,836 KB
testcase_68 AC 333 ms
85,552 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from itertools import permutations

input = lambda :sys.stdin.readline().rstrip()
mi = lambda :map(int,input().split())
li = lambda :list(mi())

import random

def cmb(n, r, mod):
    if ( r<0 or r>n ):
        return 0
    return (g1[n] * g2[r] % mod) * g2[n-r] % mod


mod = 998244353
N = 2*10**5
g1 = [1]*(N+1)
g2 = [1]*(N+1)
inverse = [1]*(N+1)

for i in range( 2, N + 1 ):
    g1[i]=( ( g1[i-1] * i ) % mod )
    inverse[i]=( ( -inverse[mod % i] * (mod//i) ) % mod )
    g2[i]=( (g2[i-1] * inverse[i]) % mod )
inverse[0]=0

def brute(N,M,A):
    tmp = [-1] * N
    def dfs(i):
        if i == N:
            for j in range(N):
                s,b = 0,0
                for k in range(N):
                    if tmp[k] < tmp[j]:
                        s+= 1
                    elif tmp[j] < tmp[k]:
                        b += 1
                if max(s,b)!=A[j]:
                    return 0
            #print(tmp)
            return 1
        
        res = 0
        for a in range(1,M+1):
            tmp[i] = a
            res += dfs(i+1)
            tmp[i] = -1
        return res
    return dfs(0)


            

def solve(N,M,A):
    A.sort()

    if N & 1 or A[0]!=N//2:
        k = N//2
        if k < A[0]:
            return 0

        small = [a for a in A if a <= k]
        if small[0]!=small[-1]:
            return 0
        
        big = [a for a in A if k < a]
        dic = {a:0 for a in big}
        for a in big:
            dic[a] += 1
        big = [(a,dic[a]) for a in dic]
        big.sort()

        L,R = 0,0
        c = 0
        res = 1
        for a,t in big[::-1]:
            tmp_l = N-a-L
            tmp_r = N-a-R
            #print(tmp_l,tmp_r,t)
            if tmp_l == t and tmp_r == t:
                c += 1
                res = 2 * res % mod
                L += t
            elif tmp_l == t:
                c += 1
                L += t
            elif tmp_r == t:
                c += 1
                R += t
            elif tmp_l+tmp_r == t and 0 < tmp_l and 0 < tmp_r:
                c += 2
                L += tmp_l
                R += tmp_r
                res = res * cmb(t,tmp_l,mod)
            else:
                res = 0
        
        c += 1
        res = res * cmb(M,c,mod) % mod
        if max(L,R) == small[0]:
            return res
        else:
            return 0
        
    else:
        dic = {a:0 for a in A}
        for a in A:
            dic[a] += 1
        big = [(a,dic[a]) for a in dic]
        big.sort()

        L,R = 0,0
        c = 0
        res = 1
        for a,t in big[::-1]:
            tmp_l = N-a-L
            tmp_r = N-a-R
            if tmp_l == t and tmp_r == t:
                c += 1
                res = 2 * res % mod
                L += t
            elif tmp_l == t:
                c += 1
                L += t
            elif tmp_r == t:
                c += 1
                R += t
            elif tmp_l+tmp_r == t and 0 < tmp_l and 0 < tmp_r:
                c += 2
                L += tmp_l
                R += tmp_r
                res = res * cmb(t,tmp_l,mod)
            else:
                res = 0
        
        res = res * cmb(M,c,mod) % mod
        if L == N//2 and R == N//2:
            return res
        else:
            return 0
        


            
while False:
    N,M = random.randint(4,5),random.randint(3,10)
    A = [random.randint(0,N-1) for i in range(N)]
    
    if solve(N,M,A) != brute(N,M,A):
        print(N,M,A)
        print(solve(N,M,A),brute(N,M,A))
        break
    print("OK",N)





for _ in range(int(input())):
    N,M = mi()
    A = li()
    A.sort()
    print(solve(N,M,A))
0