結果
| 問題 | No.1516 simple 門松列 problem Re:MASTER |
|---|---|
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-20 20:32:00 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 293 ms / 6,000 ms |
| コード長 | 4,373 bytes |
| コンパイル時間 | 202 ms |
| コンパイル使用メモリ | 81,888 KB |
| 実行使用メモリ | 82,188 KB |
| 最終ジャッジ日時 | 2025-03-20 20:32:59 |
| 合計ジャッジ時間 | 3,244 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge3 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 19 |
ソースコード
MOD = 998244353


def _is_kadomatsu(a, b, c):
    """Return True when (a, b, c) is a kadomatsu triple.

    The three values must be pairwise distinct and the middle value ``b``
    must be strictly greater than both neighbours or strictly smaller than
    both of them.
    """
    # Strict comparisons against b already force b != a and b != c, so the
    # only remaining distinctness requirement is a != c.
    if a == c:
        return False
    return (b > a and b > c) or (b < a and b < c)


def _build_system(k):
    """Build the transition matrix and the length-3 start vector over 0..k-1.

    States are ordered pairs ``(x, y)`` with ``x != y`` holding the last two
    elements of a sequence. Both results are ``(counts, sums)`` pairs:

    * transition matrix -- ``counts[i][j]`` is the number of ways to append
      one element that moves state ``i`` to state ``j`` while keeping the
      newest triple kadomatsu; ``sums[i][j]`` is the total of the appended
      values over those ways.
    * start vector -- ``counts[j]`` is the number of valid length-3
      sequences ending in state ``j``; ``sums[j]`` is the total of all
      their elements.

    Returns:
        ``((trans_counts, trans_sums), (start_counts, start_sums))``.
    """
    states = [(x, y) for x in range(k) for y in range(k) if x != y]
    index = {state: i for i, state in enumerate(states)}
    size = len(states)
    trans_counts = [[0] * size for _ in range(size)]
    trans_sums = [[0] * size for _ in range(size)]
    start_counts = [0] * size
    start_sums = [0] * size
    for i, (a, b) in enumerate(states):
        for c in range(k):
            if not _is_kadomatsu(a, b, c):
                continue
            j = index[(b, c)]
            # For a fixed source state, c alone determines j, so each
            # matrix cell receives at most one transition.
            trans_counts[i][j] += 1
            trans_sums[i][j] += c
            # Every valid triple (a, b, c) is also a length-3 starting
            # sequence that ends in state (b, c).
            start_counts[j] += 1
            start_sums[j] += a + b + c
    return (trans_counts, trans_sums), (start_counts, start_sums)


def _mat_mul(x, y):
    """Compose two ``(counts, sums)`` transition matrices modulo ``MOD``.

    Entries combine as ``(c1, s1) * (c2, s2) = (c1*c2, c1*s2 + s1*c2)``.
    Counts multiply, and each half's accumulated sum is weighted by the
    number of ways to complete the other half.
    """
    xc, xs = x
    yc, ys = y
    size = len(xc)
    rc = [[0] * size for _ in range(size)]
    rs = [[0] * size for _ in range(size)]
    for i in range(size):
        out_c = rc[i]
        out_s = rs[i]
        row_c = xc[i]
        row_s = xs[i]
        for m in range(size):
            c = row_c[m]
            s = row_s[m]
            # Skip only a cell that is zero in BOTH parts.  A count that is
            # 0 mod MOD can still carry a non-zero sum, so testing the count
            # alone would drop real contributions.
            if c == 0 and s == 0:
                continue
            yc_m = yc[m]
            ys_m = ys[m]
            for j in range(size):
                out_c[j] = (out_c[j] + c * yc_m[j]) % MOD
                out_s[j] = (out_s[j] + c * ys_m[j] + s * yc_m[j]) % MOD
    return rc, rs


def _vec_mul(vec, mat):
    """Advance a ``(counts, sums)`` state vector by one ``(counts, sums)`` matrix.

    Uses the same composition rule as ``_mat_mul``; all results are reduced
    modulo ``MOD``.
    """
    vc, vs = vec
    mc, ms = mat
    size = len(vc)
    rc = [0] * size
    rs = [0] * size
    for i in range(size):
        c = vc[i]
        s = vs[i]
        if c == 0 and s == 0:
            continue
        row_c = mc[i]
        row_s = ms[i]
        for j in range(size):
            rc[j] = (rc[j] + c * row_c[j]) % MOD
            rs[j] = (rs[j] + s * row_c[j] + c * row_s[j]) % MOD
    return rc, rs


def main():
    """Solve yukicoder No.1516 for one ``N K`` line read from stdin.

    Prints two numbers modulo ``MOD``:
    1. The number of length-``N`` sequences over ``0..K-1`` in which every
       window of three consecutive elements is a kadomatsu triple.
    2. The total of all elements across those sequences.
    """
    import sys

    n, k = map(int, sys.stdin.readline().split())
    trans, vec = _build_system(k)
    # NOTE(review): for N < 3 this clamps to zero transitions and therefore
    # reports the N = 3 answer, exactly as the original did.  Presumably the
    # problem guarantees N >= 3 -- confirm against the statement.
    power = max(n - 3, 0)
    # Binary exponentiation applied directly to the vector.  Powers of one
    # matrix commute, so folding each set bit into the vector as we go
    # yields start * trans**power without building the full matrix power.
    while power:
        if power & 1:
            vec = _vec_mul(vec, trans)
        power >>= 1
        if power:
            trans = _mat_mul(trans, trans)
    counts, sums = vec
    print(sum(counts) % MOD, sum(sums) % MOD)


if __name__ == "__main__":
    main()
lam6er