結果

問題 No.840 ほむほむほむら
ユーザー 双六双六
提出日時 2020-07-15 22:26:47
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 981 ms / 4,000 ms
コード長 1,446 bytes
コンパイル時間 342 ms
コンパイル使用メモリ 86,612 KB
実行使用メモリ 78,432 KB
最終ジャッジ日時 2023-08-14 16:14:03
合計ジャッジ時間 10,177 ms
ジャッジサーバーID
(参考情報)
judge15 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 104 ms
77,036 KB
testcase_01 AC 110 ms
77,056 KB
testcase_02 AC 129 ms
77,408 KB
testcase_03 AC 225 ms
77,656 KB
testcase_04 AC 108 ms
77,292 KB
testcase_05 AC 107 ms
77,104 KB
testcase_06 AC 111 ms
77,332 KB
testcase_07 AC 146 ms
77,320 KB
testcase_08 AC 331 ms
77,620 KB
testcase_09 AC 111 ms
77,324 KB
testcase_10 AC 112 ms
77,288 KB
testcase_11 AC 116 ms
77,308 KB
testcase_12 AC 169 ms
77,304 KB
testcase_13 AC 670 ms
78,020 KB
testcase_14 AC 183 ms
77,396 KB
testcase_15 AC 109 ms
77,156 KB
testcase_16 AC 118 ms
77,468 KB
testcase_17 AC 243 ms
77,780 KB
testcase_18 AC 827 ms
78,424 KB
testcase_19 AC 981 ms
78,376 KB
testcase_20 AC 105 ms
76,864 KB
testcase_21 AC 107 ms
77,188 KB
testcase_22 AC 117 ms
77,412 KB
testcase_23 AC 938 ms
78,324 KB
testcase_24 AC 120 ms
77,468 KB
testcase_25 AC 108 ms
77,060 KB
testcase_26 AC 126 ms
77,336 KB
testcase_27 AC 942 ms
78,432 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys; input = sys.stdin.buffer.readline
sys.setrecursionlimit(10**7)
from collections import defaultdict
mod = 998244353

def getlist():
	return list(map(int, input().split()))

#A*B
def mul(A, B):
	C = [[0] * len(B[0]) for i in range(len(A))]
	for i in range(len(A)):
		for k in range(len(B)):
			for j in range(len(B[0])):
				C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % mod
	return C

#A**n 繰り返し二乗法の要領で計算する N:行列の縦横の大きさ
def matrixPow(A, n, N):
	#B:単位行列 演算の種類によって初期化方法を変える必要もある
	B = [[0] * N for i in range(N)]
	for i in range(N):
		B[i][i] = 1

	while n > 0:
		if n & 1 == 1:
			B = mul(A, B)
		A = mul(A, A)
		n = n >> 1

	return B

#処理内容
def main():
	N, K = getlist()
	Det = [[0] * (K ** 3) for i in range(K ** 3)]
	for x in range(K):
		for y in range(K):
			for z in range(K):
				itr = x + y * K + z * (K ** 2)
				x_new = (x + 1) % K
				newind1 = (x + 1) % K + y * K + z * (K ** 2)
				newind2 = x + (y + x) % K * K + z * (K ** 2)
				newind3 = x + y * K + (z + y) % K * (K ** 2)
				Det[newind1][itr] += 1
				Det[newind2][itr] += 1
				Det[newind3][itr] += 1

	B = matrixPow(Det, N, K ** 3)
	# print(Det)

	start = [[0] for i in range(K ** 3)]
	start[0][0] = 1
	ans_pre = mul(B, start)
	# print(ans_pre)
	ans = 0
	for i in range(K ** 2):
		ans += ans_pre[i][0]

	ans %= mod
	print(ans)


if __name__ == '__main__':
	main()
0