"""Expected-value DP over a random walk on 0..n, reported modulo 998244353.

Reads ``n k`` from stdin and prints ``solve(n, k)``.
"""

MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n`` plus the expected number of downward moves in ``k`` steps, mod ``mod``.

    The walk starts at state 0. From state ``j`` it moves to ``j - 1`` with
    probability ``j / n`` and to ``j + 1`` with probability ``(n - j) / n``.
    ``dp[j]`` is the probability (as a residue mod ``mod``) of being in state
    ``j``. Each step adds ``P(move down)`` to the answer.

    Only the previous DP row is needed, so a single list of length ``n + 1``
    is rolled forward instead of materialising a (k+1) x (n+1) table.

    Args:
        n: Number of states above 0 (the walk lives on 0..n).
        k: Number of steps.
        mod: Prime modulus. Fermat's little theorem is used for the inverse.

    Returns:
        The answer reduced modulo ``mod``.
    """
    # Modular inverse of n via Fermat (mod must be prime). It is 0 when n == 0.
    ninv = pow(n, mod - 2, mod)
    ans = n % mod
    dp = [0] * (n + 1)
    dp[0] = 1
    for _ in range(k):
        nxt = [0] * (n + 1)
        for j, p in enumerate(dp):
            if not p:
                # Unreachable state this step: every contribution would be 0.
                continue
            if j > 0:
                # P(state j) * P(move down | state j); also counted into the answer.
                down = j * ninv % mod * p % mod
                nxt[j - 1] = (nxt[j - 1] + down) % mod
                ans += down
            if j < n:
                nxt[j + 1] = (nxt[j + 1] + (n - j) * ninv % mod * p) % mod
        dp = nxt
        ans %= mod
    return ans % mod


def main() -> None:
    """Read ``n k`` from stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()