結果

問題 No.2345 max(l,r)
ユーザー chineristACchineristAC
提出日時 2023-06-09 23:49:31
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 229 ms / 2,000 ms
コード長 3,659 bytes
コンパイル時間 388 ms
コンパイル使用メモリ 82,124 KB
実行使用メモリ 124,112 KB
最終ジャッジ日時 2024-06-10 14:54:01
合計ジャッジ時間 11,569 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 54 ms
65,788 KB
testcase_01 AC 190 ms
83,668 KB
testcase_02 AC 129 ms
82,852 KB
testcase_03 AC 128 ms
82,832 KB
testcase_04 AC 126 ms
82,832 KB
testcase_05 AC 126 ms
82,896 KB
testcase_06 AC 126 ms
82,956 KB
testcase_07 AC 127 ms
82,896 KB
testcase_08 AC 131 ms
82,824 KB
testcase_09 AC 137 ms
83,116 KB
testcase_10 AC 129 ms
82,796 KB
testcase_11 AC 136 ms
82,988 KB
testcase_12 AC 185 ms
83,368 KB
testcase_13 AC 164 ms
83,444 KB
testcase_14 AC 160 ms
83,572 KB
testcase_15 AC 163 ms
83,640 KB
testcase_16 AC 160 ms
83,448 KB
testcase_17 AC 162 ms
83,380 KB
testcase_18 AC 165 ms
83,404 KB
testcase_19 AC 163 ms
83,572 KB
testcase_20 AC 162 ms
83,440 KB
testcase_21 AC 164 ms
83,432 KB
testcase_22 AC 113 ms
105,388 KB
testcase_23 AC 106 ms
105,720 KB
testcase_24 AC 68 ms
79,452 KB
testcase_25 AC 80 ms
81,164 KB
testcase_26 AC 80 ms
80,976 KB
testcase_27 AC 77 ms
81,512 KB
testcase_28 AC 75 ms
97,092 KB
testcase_29 AC 103 ms
99,844 KB
testcase_30 AC 73 ms
83,320 KB
testcase_31 AC 121 ms
109,520 KB
testcase_32 AC 81 ms
93,432 KB
testcase_33 AC 99 ms
101,304 KB
testcase_34 AC 54 ms
67,176 KB
testcase_35 AC 64 ms
80,352 KB
testcase_36 AC 75 ms
87,792 KB
testcase_37 AC 84 ms
93,936 KB
testcase_38 AC 142 ms
117,988 KB
testcase_39 AC 168 ms
118,588 KB
testcase_40 AC 155 ms
117,996 KB
testcase_41 AC 142 ms
118,068 KB
testcase_42 AC 142 ms
118,300 KB
testcase_43 AC 143 ms
118,348 KB
testcase_44 AC 152 ms
117,932 KB
testcase_45 AC 169 ms
124,112 KB
testcase_46 AC 169 ms
122,044 KB
testcase_47 AC 141 ms
117,880 KB
testcase_48 AC 102 ms
82,044 KB
testcase_49 AC 99 ms
81,972 KB
testcase_50 AC 105 ms
81,864 KB
testcase_51 AC 101 ms
81,916 KB
testcase_52 AC 103 ms
82,020 KB
testcase_53 AC 104 ms
82,132 KB
testcase_54 AC 103 ms
81,992 KB
testcase_55 AC 98 ms
81,716 KB
testcase_56 AC 108 ms
81,992 KB
testcase_57 AC 115 ms
82,032 KB
testcase_58 AC 139 ms
82,196 KB
testcase_59 AC 128 ms
82,204 KB
testcase_60 AC 132 ms
82,344 KB
testcase_61 AC 129 ms
81,908 KB
testcase_62 AC 138 ms
82,104 KB
testcase_63 AC 134 ms
82,304 KB
testcase_64 AC 131 ms
81,924 KB
testcase_65 AC 133 ms
82,012 KB
testcase_66 AC 147 ms
82,056 KB
testcase_67 AC 128 ms
82,112 KB
testcase_68 AC 229 ms
83,988 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from itertools import permutations

input = lambda :sys.stdin.readline().rstrip()
mi = lambda :map(int,input().split())
li = lambda :list(mi())

import random

def cmb(n, r, mod):
    if ( r<0 or r>n ):
        return 0
    return (g1[n] * g2[r] % mod) * g2[n-r] % mod


mod = 998244353
N = 2*10**5
g1 = [1]*(N+1)
g2 = [1]*(N+1)
inverse = [1]*(N+1)

for i in range( 2, N + 1 ):
    g1[i]=( ( g1[i-1] * i ) % mod )
    inverse[i]=( ( -inverse[mod % i] * (mod//i) ) % mod )
    g2[i]=( (g2[i-1] * inverse[i]) % mod )
inverse[0]=0

def brute(N,M,A):
    tmp = [-1] * N
    def dfs(i):
        if i == N:
            for j in range(N):
                s,b = 0,0
                for k in range(N):
                    if tmp[k] < tmp[j]:
                        s+= 1
                    elif tmp[j] < tmp[k]:
                        b += 1
                if max(s,b)!=A[j]:
                    return 0
            #print(tmp)
            return 1
        
        res = 0
        for a in range(1,M+1):
            tmp[i] = a
            res += dfs(i+1)
            tmp[i] = -1
        return res
    return dfs(0)


            

def solve(N,M,A):
    A.sort()

    if N & 1 or A[0]!=N//2:
        k = N//2
        if k < A[0]:
            return 0

        small = [a for a in A if a <= k]
        if small[0]!=small[-1]:
            return 0
        
        big = [a for a in A if k < a]
        dic = {a:0 for a in big}
        for a in big:
            dic[a] += 1
        big = [(a,dic[a]) for a in dic]
        big.sort()

        L,R = 0,0
        c = 0
        res = 1
        for a,t in big[::-1]:
            tmp_l = N-a-L
            tmp_r = N-a-R
            #print(tmp_l,tmp_r,t)
            if tmp_l == t and tmp_r == t:
                c += 1
                res = 2 * res % mod
                L += t
            elif tmp_l == t:
                c += 1
                L += t
            elif tmp_r == t:
                c += 1
                R += t
            elif tmp_l+tmp_r == t and 0 < tmp_l and 0 < tmp_r:
                c += 2
                L += tmp_l
                R += tmp_r
                res = res * cmb(t,tmp_l,mod)
            else:
                res = 0
        
        c += 1
        res = res * cmb(M,c,mod) % mod
        if max(L,R) == small[0]:
            return res
        else:
            return 0
        
    else:
        dic = {a:0 for a in A}
        for a in A:
            dic[a] += 1
        big = [(a,dic[a]) for a in dic]
        big.sort()

        L,R = 0,0
        c = 0
        res = 1
        for a,t in big[::-1]:
            tmp_l = N-a-L
            tmp_r = N-a-R
            if tmp_l == t and tmp_r == t:
                c += 1
                res = 2 * res % mod
                L += t
            elif tmp_l == t:
                c += 1
                L += t
            elif tmp_r == t:
                c += 1
                R += t
            elif tmp_l+tmp_r == t and 0 < tmp_l and 0 < tmp_r:
                c += 2
                L += tmp_l
                R += tmp_r
                res = res * cmb(t,tmp_l,mod)
            else:
                res = 0
        
        res = res * cmb(M,c,mod) % mod
        if L == N//2 and R == N//2:
            return res
        else:
            return 0
        


            
while False:
    N,M = random.randint(4,5),random.randint(3,10)
    A = [random.randint(0,N-1) for i in range(N)]
    
    if solve(N,M,A) != brute(N,M,A):
        print(N,M,A)
        print(solve(N,M,A),brute(N,M,A))
        break
    print("OK",N)





for _ in range(int(input())):
    N,M = mi()
    A = li()
    A.sort()
    print(solve(N,M,A))
0