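N, K = map(int, input().split())
MOD = 998244353

# dp[i][a][b] = [count, total] for the sequences of length i + 1 (values in
# 0..K-1) that the transitions below admit and whose last two elements are a, b:
#   count = number of such sequences,
#   total = sum of all elements over those sequences.
dp = [[[[0, 0] for _ in range(K)] for _ in range(K)] for _ in range(N)]

# Base case: length-2 sequences (i, j) with i != j; each contributes i + j.
for i in range(K):
    for j in range(K):
        if i != j:
            dp[1][i][j] = [1, i + j]

for i in range(2, N):
    # acc0[j][m] / acc1[j][m] = prefix sums (mod MOD) over x < m of
    # dp[i-1][x][j][0] / dp[i-1][x][j][1], i.e. over previous states ending in j.
    acc0 = [[0] for _ in range(K)]
    acc1 = [[0] for _ in range(K)]
    for j in range(K):
        for k in range(K):
            acc0[j].append((acc0[j][-1] + dp[i-1][k][j][0]) % MOD)
            acc1[j].append((acc1[j][-1] + dp[i-1][k][j][1]) % MOD)
    for j in range(K):
        for k in range(K):
            if j > k:
                # Append k after every previous state (x, j) with x < j and x != k
                # (prefix sum over x < j, minus the x == k term), so j is a local maximum.
                dp[i][j][k][0] += acc0[j][j] - dp[i-1][k][j][0]
                # Carry over the previous totals and add k once per extended sequence.
                dp[i][j][k][1] += acc1[j][j] - dp[i-1][k][j][1] + k * dp[i][j][k][0]
            elif j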