結果
問題 | No.1516 simple 門松列 problem Re:MASTER |
ユーザー |
![]() |
提出日時 | 2025-03-20 20:32:00 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 293 ms / 6,000 ms |
コード長 | 4,373 bytes |
コンパイル時間 | 202 ms |
コンパイル使用メモリ | 81,888 KB |
実行使用メモリ | 82,188 KB |
最終ジャッジ日時 | 2025-03-20 20:32:59 |
合計ジャッジ時間 | 3,244 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 19 |
ソースコード
MOD = 998244353


def _is_kadomatsu(x, y, z):
    """Return True if (x, y, z) is a Kadomatsu triple.

    A Kadomatsu triple has three pairwise distinct values whose middle
    element is either the strict maximum or the strict minimum.
    """
    if x == y or y == z or x == z:
        return False
    return (y > x and y > z) or (y < x and y < z)


def _vec_times_mat(vec, mat, size):
    """Multiply a (count, sum) row vector by a (count, sum) matrix, mod MOD.

    ``vec`` is a pair of length-``size`` lists ``(counts, sums)``; ``mat`` is
    a pair of ``size`` x ``size`` matrices ``(counts, sums)``.  Counts
    multiply, while sums combine by the product rule ``s1 * c2 + c1 * s2`` so
    the sum component stays equal to the total weight of all combined paths.
    """
    v_cnt, v_sum = vec
    m_cnt, m_sum = mat
    out_cnt = [0] * size
    out_sum = [0] * size
    for i in range(size):
        c1 = v_cnt[i]
        s1 = v_sum[i]
        # Skip only when BOTH components vanish: a count that is 0 modulo
        # MOD can still carry a non-zero sum contribution (s1 * c2).
        if c1 == 0 and s1 == 0:
            continue
        row_cnt = m_cnt[i]
        row_sum = m_sum[i]
        for j in range(size):
            c2 = row_cnt[j]
            out_cnt[j] = (out_cnt[j] + c1 * c2) % MOD
            out_sum[j] = (out_sum[j] + s1 * c2 + c1 * row_sum[j]) % MOD
    return out_cnt, out_sum


def _mat_times_mat(a, b, size):
    """Multiply two (count, sum) matrices, mod MOD, one row of ``a`` at a time."""
    a_cnt, a_sum = a
    rows = [_vec_times_mat((a_cnt[i], a_sum[i]), b, size) for i in range(size)]
    return [r[0] for r in rows], [r[1] for r in rows]


def _solve(n, k):
    """Count length-``n`` sequences over ``range(k)`` and sum their elements.

    A sequence qualifies when every window of three consecutive elements is
    a Kadomatsu triple (see ``_is_kadomatsu``).  Returns the pair
    ``(number of qualifying sequences, sum of all their elements)``, both
    reduced modulo ``MOD``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("sequence length must be non-negative")
    if n < 3:
        # No window of three exists, so all k**n sequences qualify
        # vacuously.  Each of the n positions takes every value exactly
        # k**(n-1) times, and the values sum to k*(k-1)/2.
        count = pow(k, n, MOD)
        total = n * pow(k, n - 1, MOD) * (k * (k - 1) // 2) % MOD if n else 0
        return count, total

    # A state is the ordered pair of the last two elements; both must differ
    # for any further Kadomatsu window to be possible.
    states = [(a, b) for a in range(k) for b in range(k) if a != b]
    size = len(states)
    index = {state: i for i, state in enumerate(states)}

    # Transition matrix: appending c to a sequence ending in (a, b) moves to
    # state (b, c) and adds c to the element sum.
    t_cnt = [[0] * size for _ in range(size)]
    t_sum = [[0] * size for _ in range(size)]
    # Initial vector: every valid first triple (a, b, c), bucketed by its
    # last two values (b, c), weighted by the full triple sum.
    v_cnt = [0] * size
    v_sum = [0] * size
    for i, (a, b) in enumerate(states):
        for c in range(k):
            if _is_kadomatsu(a, b, c):
                j = index[(b, c)]
                t_cnt[i][j] += 1
                t_sum[i][j] += c
                v_cnt[j] += 1
                v_sum[j] += a + b + c

    # Compute vec * T**(n-3) by binary exponentiation.  Powers of the same
    # matrix commute, so the vector can absorb them bit by bit without ever
    # forming the full power.
    vec = (v_cnt, v_sum)
    mat = (t_cnt, t_sum)
    power = n - 3
    while power:
        if power & 1:
            vec = _vec_times_mat(vec, mat, size)
        power >>= 1
        if power:
            mat = _mat_times_mat(mat, mat, size)
    return sum(vec[0]) % MOD, sum(vec[1]) % MOD


def main():
    """Read ``N K`` from stdin and print the count and element total mod MOD."""
    import sys

    n, k = map(int, sys.stdin.readline().split())
    count, total = _solve(n, k)
    print(count, total)


if __name__ == "__main__":
    main()