"""Count zigzag sequences modulo 998244353.

Reads two integers ``n`` and ``k`` from stdin and prints ``v w`` where
``v`` is the number of length-``n`` sequences over ``{0, ..., k-1}`` in which
consecutive values differ, every step reverses the direction of the previous
step (up, down, up, ... or down, up, down, ...), and no value equals the one
two positions earlier; and ``w = v * (k - 1) * n / 2`` modulo MOD (division
by 2 done via the modular inverse).
"""

MOD = 998244353
# MOD is odd, so (MOD + 1) // 2 is the modular inverse of 2.
INV2 = (MOD + 1) // 2


def solve(n, k):
    """Return ``(v, w)`` for sequence length ``n`` and alphabet size ``k``.

    ``dp[p][c]`` counts valid prefixes whose last two values are ``p`` then
    ``c``. Each extension step scatters counts into a difference array over
    the next value and then prefix-sums it, for O(n * k^2) time overall.

    NOTE(review): the extension loop runs ``n - 2`` times, so any ``n < 2``
    yields the ``n == 2`` count ``k * (k - 1)``; presumably the input
    guarantees ``n >= 2`` -- confirm against the problem statement.
    """
    # Length-2 prefixes: every ordered pair of distinct values.
    dp = [[int(p != c) for c in range(k)] for p in range(k)]
    for _ in range(n - 2):
        # Difference rows with one spare column (index k) to close ranges
        # that run to the end; it prefix-sums back to 0.
        ndp = [[0] * (k + 1) for _ in range(k)]
        for p in range(k):
            row = dp[p]
            # Last step went up (p < c): next value in [0, c) except p.
            for c in range(p + 1, k):
                cnt = row[c]
                diff = ndp[c]
                diff[0] += cnt
                diff[p] -= cnt
                diff[p + 1] += cnt
                diff[c] -= cnt
            # Last step went down (c < p): next value in (c, k) except p.
            for c in range(p):
                cnt = row[c]
                diff = ndp[c]
                diff[c + 1] += cnt
                diff[p] -= cnt
                diff[p + 1] += cnt
                diff[k] -= cnt
        # Turn each difference row into actual (reduced) counts.
        for diff in ndp:
            diff[0] %= MOD
            for m in range(1, k + 1):
                diff[m] = (diff[m] + diff[m - 1]) % MOD
        dp = ndp
    v = sum(map(sum, dp)) % MOD
    w = v * (k - 1) % MOD * n % MOD * INV2 % MOD
    return v, w


def main():
    """Read ``n k`` from stdin and print ``v w``."""
    n, k = map(int, input().split())
    v, w = solve(n, k)
    print(v, w)


if __name__ == "__main__":
    main()