"""Count operation sequences on a triple of counters taken modulo k.

The state is (q, r, s), every component reduced mod k, starting at (0, 0, 0).
Each of the n steps applies exactly one of three operations:

    q <- q + 1,    r <- r + q,    s <- s + r

Reading "n k" from stdin, the program prints how many of the 3**n operation
sequences finish with s == 0, modulo 998244353.  The k**3 states are encoded
as q*k*k + r*k + s and the count is found by fast exponentiation of the
one-step transition matrix.
"""

from itertools import product

MOD = 998244353


def matmul(A, B):
    """Return the matrix product A @ B with every entry reduced modulo MOD.

    A (m x n) and B (n x p) are lists of rows of ints; the result is a new
    m x p list of rows.
    """
    res = [[0] * len(B[0]) for _ in range(len(A))]
    for resi, ai in zip(res, A):
        for k, aik in enumerate(ai):
            if not aik:
                continue  # transition matrices are sparse; skip zero terms
            for j, bkj in enumerate(B[k]):
                resi[j] += aik * bkj
        # Python ints never overflow, so a single reduction per row suffices.
        resi[:] = [x % MOD for x in resi]
    return res


def matpow(A, p, start=None):
    """Return A**p modulo MOD, or start @ A**p when start is given.

    A is a square matrix (list of rows) and p a non-negative integer; p == 0
    yields the identity (or start itself).  Passing start, e.g. a single-row
    vector, avoids ever forming A**p: the accumulated product keeps start's
    shape, so only the repeated squarings of A are full matrix products.

    Raises:
        ValueError: if p is negative.
    """
    if p < 0:
        raise ValueError("exponent must be non-negative")
    if start is None:
        size = len(A)
        start = [[int(i == j) for j in range(size)] for i in range(size)]
    result = start
    base = A
    while p:
        if p & 1:
            result = matmul(result, base)
        p >>= 1
        if p:  # avoid one useless squaring after the top bit
            base = matmul(base, base)
    return result


def _state_index(q, r, s, k):
    """Encode the state (q, r, s) as the integer q*k*k + r*k + s."""
    return (q * k + r) * k + s


def _transition_matrix(k):
    """Return the k**3 x k**3 matrix T where T[src][dst] counts the
    operations that move state src to state dst in one step."""
    size = k ** 3
    trans = [[0] * size for _ in range(size)]
    for q, r, s in product(range(k), repeat=3):
        src = _state_index(q, r, s, k)
        successors = (
            ((q + 1) % k, r, s),  # q <- q + 1
            (q, (r + q) % k, s),  # r <- r + q
            (q, r, (s + r) % k),  # s <- s + r
        )
        for nq, nr, ns in successors:
            trans[src][_state_index(nq, nr, ns, k)] += 1
    return trans


def solve(n, k):
    """Return, modulo MOD, the number of length-n operation sequences that
    start at (0, 0, 0) and end with s == 0.

    Raises:
        ValueError: if n is negative or k < 1.
    """
    if n < 0 or k < 1:
        raise ValueError("need n >= 0 and k >= 1")
    trans = _transition_matrix(k)
    # Row vector of path counts, all mass on the start state (index 0).
    # Since trans is indexed [src][dst], dist @ trans**n is the count of
    # paths from (0, 0, 0) to each state after n steps.
    dist = [[1] + [0] * (k ** 3 - 1)]
    dist = matpow(trans, n, start=dist)
    # Index q*k*k + r*k + s is a multiple of k exactly when s == 0.
    return sum(dist[0][::k]) % MOD


def main():
    """Read "n k" from stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()