結果

問題 No.1691 Badugi
ユーザー chineristACchineristAC
提出日時 2021-09-24 23:16:18
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 138 ms / 2,000 ms
コード長 3,182 bytes
コンパイル時間 518 ms
コンパイル使用メモリ 87,260 KB
実行使用メモリ 88,664 KB
最終ジャッジ日時 2023-09-18 22:20:54
合計ジャッジ時間 4,332 ms
ジャッジサーバーID
(参考情報)
judge15 / judge13
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 138 ms
88,632 KB
testcase_01 AC 132 ms
88,420 KB
testcase_02 AC 130 ms
88,264 KB
testcase_03 AC 129 ms
88,308 KB
testcase_04 AC 134 ms
88,272 KB
testcase_05 AC 130 ms
88,148 KB
testcase_06 AC 129 ms
88,276 KB
testcase_07 AC 131 ms
88,272 KB
testcase_08 AC 129 ms
88,136 KB
testcase_09 AC 128 ms
88,612 KB
testcase_10 AC 131 ms
88,660 KB
testcase_11 AC 128 ms
88,312 KB
testcase_12 AC 130 ms
88,548 KB
testcase_13 AC 128 ms
88,520 KB
testcase_14 AC 130 ms
88,496 KB
testcase_15 AC 126 ms
88,300 KB
testcase_16 AC 129 ms
88,664 KB
testcase_17 AC 129 ms
88,632 KB
testcase_18 AC 131 ms
88,428 KB
testcase_19 AC 128 ms
88,560 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys,random

input = lambda :sys.stdin.readline().rstrip()
mi = lambda :map(int,input().split())
li = lambda :list(mi())

def cmb(n, r, mod):
    if ( r<0 or r>n ):
        return 0
    return (g1[n] * g2[r] % mod) * g2[n-r] % mod

mod = 998244353
N = 5*10**5 + 100
g1 = [1]*(N+1)
g2 = [1]*(N+1)
inverse = [1]*(N+1)

for i in range( 2, N + 1 ):
    g1[i]=( ( g1[i-1] * i ) % mod )
    inverse[i]=( ( -inverse[mod % i] * (mod//i) ) % mod )
    g2[i]=( (g2[i-1] * inverse[i]) % mod )
inverse[0]=0

N,M,K = mi()

res = 0

"""
行-列-行
    |
    行

行... K-3 + 3
列... K-3 + 1
"""

res += cmb(N,K,mod) * cmb(K,3,mod) *  cmb(M,K-2,mod) * (K-2) * g1[K-3]
res %= mod

"""
行-列-行
    |
    行-列

行... K-4 + 3
列... K-4 + 2
"""

if 4<=K:
    res += cmb(N,K-1,mod) * cmb(K-1,3,mod) * 3 *  cmb(M,K-2,mod) * cmb(K-2,2,mod) * 2 * g1[K-4]
    res %= mod

"""
行-列-行-列
    |
    行-列

行... K-5 + 3
列... K-5 + 3
"""

if 5 <= K:
    res += cmb(N,K-2,mod) * cmb(K-2,3,mod) * 3 * cmb(M,K-2,mod) * cmb(K-2,3,mod) * 6 * g1[K-5]
    res %= mod


"""
列-行-列
    |
    列

行... K-3 + 1
列... K-3 + 3
"""

res += cmb(N,K-2,mod) * (K-2) * cmb(M,K,mod) * cmb(K,3,mod) * g1[K-3]
res %= mod

"""
列-行-列
    |
    列-行

列... K-4 + 3
行... K-4 + 2
"""

if 4<=K:
    res += cmb(M,K-1,mod) * cmb(K-1,3,mod) * 3 *  cmb(N,K-2,mod) * cmb(K-2,2,mod) * 2 * g1[K-4]
    res %= mod

"""
列-行-列-行
    |
    列-行

列... K-5 + 3
行... K-5 + 3
"""

if 5 <= K:
    res += cmb(N,K-2,mod) * cmb(K-2,3,mod) * 3 * cmb(M,K-2,mod) * cmb(K-2,3,mod) * 6 * g1[K-5]
    #print(cmb(N,K-2,mod) * cmb(K-2,3,mod) * 3 * cmb(M,K-2,mod) * cmb(K-2,3,mod) * 6 * g1[K-5])
    res %= mod



"""
・列-行-列 1 *a
・行-列-行 1 *b
・行-列-行-列 2 *c

a+b+c=2
行-列 * (K-2-a-b-2*c)
"""

coef = [1,1,3]
coef2 = [1,2,12]
coef3 = [1,1,6]
for a in range(3):
    for b in range(3):
        for c in range(3):
            if a+b+c!=2 or (K-2-a-b-2*c) < 0 or (2*a+2*b+3*c+K-2-a-b-2*c)!=K:
                continue
            
            gyo = a + 2 * b + 2 * c + (K-2-a-b-2*c)
            ret = 2*a + b + 2 * c + (K-2-a-b-2*c)

            tmp_g = cmb(N,gyo,mod) * cmb(gyo,a,mod) * cmb(gyo-a,2*b,mod) * coef[b] * cmb(gyo-a-2*b,2*c,mod) * coef2[c]
            tmp_r = cmb(M,ret,mod) * cmb(ret,2*a,mod) * coef3[a] * cmb(ret-2*a,b,mod) * g1[b] * cmb(ret-2*a-b,2*c,mod) * g1[2*c] * g1[ret-2*a-b-2*c]
            res += tmp_g * tmp_r % mod

            res %= mod

"""
行-列-行-列-行
列-行-列-行-列

行-列
| |
列-行

行-列 (K-4)
"""

if 4 <= K:
    tmp_1 = cmb(N,K-1,mod) * cmb(K-1,3,mod) * 3 * cmb(M,K-2,mod) * cmb(K-2,2,mod) * 2 * g1[K-4]
    tmp_2 = cmb(N,K-2,mod) * cmb(K-2,2,mod) * cmb(M,K-1,mod) * cmb(K-1,3,mod) * 6 * g1[K-4]
    tmp_3 = cmb(N,K-2,mod) * cmb(K-2,2,mod) * cmb(M,K-2,mod) * cmb(K-2,2,mod) * g1[K-4]
    res += tmp_1 + tmp_2 + tmp_3
    res %= mod

"""
行-列-行-列-行-列

行-列 K-5
"""

if 5 <= K:
    res += cmb(N,K-2,mod) * cmb(K-2,3,mod) * g1[3] * cmb(M,K-2,mod) * cmb(K-2,3,mod) * g1[3] * g1[K-5]
    #print(cmb(N,K-2,mod) * cmb(K-2,3,mod) * g1[3] * cmb(M,K-2,mod) * cmb(K-2,3,mod) * g1[3] * g1[K-5])
    res %= mod

print(res)
0