import sys
from collections import defaultdict  # not used by this script

# Read whole lines as bytes for speed; int() accepts bytes directly.
input = sys.stdin.buffer.readline
sys.setrecursionlimit(10**7)

mod = 998244353


def getlist():
    """Read one line from stdin and return its whitespace-separated integers."""
    return list(map(int, input().split()))


def mul(A, B):
    """Return the matrix product A*B with every entry reduced modulo ``mod``.

    A is an n x m matrix and B an m x p matrix, both given as lists of rows.
    Neither argument is modified.
    """
    cols = len(B[0]) if B else 0
    C = []
    for row in A:
        acc = [0] * cols
        for a_ik, b_row in zip(row, B):
            # The transition matrices built by solve() are very sparse, so
            # skipping zero coefficients avoids most of the inner-loop work.
            if a_ik == 0:
                continue
            for j, b_kj in enumerate(b_row):
                acc[j] += a_ik * b_kj
        # Python ints do not overflow, so one reduction per entry suffices.
        C.append([v % mod for v in acc])
    return C


def matrixPow(A, n, N):
    """Return A**n (entries modulo ``mod``) by binary exponentiation.

    A: N x N matrix (not modified); n: non-negative exponent; N: side length of A.
    """
    # Identity matrix: the neutral element for ordinary (+, *) matrix products.
    # A different operation (e.g. min-plus) would need a different identity.
    result = [[int(i == j) for j in range(N)] for i in range(N)]
    base = A
    while n > 0:
        if n & 1:
            result = mul(base, result)
        n >>= 1
        # Skip the final squaring: its result would never be used.
        if n:
            base = mul(base, base)
    return result


def solve(N, K):
    """Count move sequences modulo ``mod``.

    The state is (x, y, z) with each coordinate taken modulo K, starting at
    (0, 0, 0). Each step applies exactly one of the moves x += 1, y += x,
    or z += y. Return the number of length-N sequences whose final state
    has z == 0, modulo ``mod``.
    """
    size = K ** 3

    def encode(x, y, z):
        # Flatten (x, y, z) to one index; states with z == 0 occupy [0, K*K).
        return x + y * K + z * K * K

    # trans[dst][src] = number of single moves leading from state src to dst.
    trans = [[0] * size for _ in range(size)]
    for x in range(K):
        for y in range(K):
            for z in range(K):
                src = encode(x, y, z)
                trans[encode((x + 1) % K, y, z)][src] += 1
                trans[encode(x, (y + x) % K, z)][src] += 1
                trans[encode(x, y, (z + y) % K)][src] += 1

    walks = matrixPow(trans, N, size)
    # Starting from state index 0, the counts after N moves are column 0 of
    # trans**N, so no extra matrix-vector product is needed.
    return sum(walks[i][0] for i in range(K * K)) % mod


def main():
    """Read N and K from one stdin line and print the answer."""
    N, K = getlist()
    print(solve(N, K))


if __name__ == '__main__':
    main()