結果

問題 No.1516 simple 門松列 problem Re:MASTER
ユーザー lam6er
提出日時 2025-03-20 20:32:00
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 293 ms / 6,000 ms
コード長 4,373 bytes
コンパイル時間 202 ms
コンパイル使用メモリ 81,888 KB
実行使用メモリ 82,188 KB
最終ジャッジ日時 2025-03-20 20:32:59
合計ジャッジ時間 3,244 ms
ジャッジサーバーID (参考情報): judge2 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 19
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353


def _is_kadomatsu(a, b, c):
    """Return True when a, b, c are pairwise distinct and b is the strict max or min."""
    if a == b or b == c or a == c:
        return False
    return (b > a and b > c) or (b < a and b < c)


def _mat_mul(x, y):
    """Multiply two square matrices whose entries are (count, value_sum) pairs mod MOD.

    An entry (c, s) means: c ways to go from the row state to the column state,
    and s is the total of the values appended along all of those ways.
    Chaining (c1, s1) then (c2, s2) yields (c1 * c2, c1 * s2 + s1 * c2), because
    every way of the first leg is combined with every way of the second.
    """
    size = len(x)
    res = []
    for x_row in x:
        acc_c = [0] * size
        acc_s = [0] * size
        for (xc, xs), y_row in zip(x_row, y):
            # A count can be 0 (mod MOD) while its sum is not, so only an exact
            # (0, 0) entry is guaranteed to contribute nothing.
            if xc == 0 and xs == 0:
                continue
            for j, (yc, ys) in enumerate(y_row):
                # Reduce every step so intermediates stay within machine words.
                acc_c[j] = (acc_c[j] + xc * yc) % MOD
                acc_s[j] = (acc_s[j] + xc * ys + xs * yc) % MOD
        res.append(list(zip(acc_c, acc_s)))
    return res


def _mat_pow(mat, power):
    """Raise a (count, value_sum) matrix to a non-negative integer power by squaring."""
    size = len(mat)
    # (1, 0) on the diagonal is the identity for the product in _mat_mul.
    result = [[(1 if i == j else 0, 0) for j in range(size)] for i in range(size)]
    while power:
        if power & 1:
            result = _mat_mul(result, mat)
        power >>= 1
        if power:  # skip the final, unused squaring
            mat = _mat_mul(mat, mat)
    return result


def _solve(n, k):
    """Return (count, total) mod MOD for sequences of length n over range(k).

    A sequence qualifies when every window of three consecutive terms satisfies
    _is_kadomatsu.  count is the number of qualifying sequences; total is the
    sum of all terms over all qualifying sequences.
    """
    if n < 3:
        # No window of three exists, so all k**n sequences qualify vacuously.
        # Each of the n positions takes every value in k**(n-1) sequences.
        count = pow(k, n, MOD)
        total = n * pow(k, n - 1, MOD) * (k * (k - 1) // 2) % MOD if n > 0 else 0
        return count, total

    # State = last two terms (a, b); they must differ to extend to any valid triple.
    states = [(a, b) for a in range(k) for b in range(k) if a != b]
    index = {state: i for i, state in enumerate(states)}
    size = len(states)

    # One transition (a, b) -> (b, c) per valid c, contributing value c.
    trans = [[(0, 0)] * size for _ in range(size)]
    for i, (a, b) in enumerate(states):
        for c in range(k):
            if _is_kadomatsu(a, b, c):
                trans[i][index[(b, c)]] = (1, c)

    # Seed with the first full triple (a1, a2, a3), keyed by its ending state (a2, a3).
    init_count = [0] * size
    init_sum = [0] * size
    for a1 in range(k):
        for a2 in range(k):
            for a3 in range(k):
                if _is_kadomatsu(a1, a2, a3):
                    j = index[(a2, a3)]
                    init_count[j] += 1
                    init_sum[j] += a1 + a2 + a3

    # The remaining n - 3 terms are appended one transition at a time.
    power = _mat_pow(trans, n - 3)
    count = total = 0
    for ci, si, row in zip(init_count, init_sum, power):
        ci %= MOD
        si %= MOD
        for tc, ts in row:
            count = (count + ci * tc) % MOD
            total = (total + si * tc + ci * ts) % MOD
    return count, total


def main():
    """Read "N K" from one line of stdin and print the count and the term sum.

    Both values are the ones computed by _solve(N, K), printed space-separated.
    """
    import sys

    n, k = map(int, sys.stdin.readline().split())
    count, total = _solve(n, k)
    print(count, total)

if __name__ == "__main__":
    # Solve only when executed as a script; importing the module has no side effects.
    main()
0