from itertools import accumulate

MOD = 998244353  # prime modulus; fractional answers are reported modulo MOD


def _prefix_mod(diff):
    """Return the running sums of the difference array *diff*, reduced mod MOD.

    The first element is passed through unchanged; in this script it is
    always 0 or 1 because index 0 is only ever set by the initial state.
    """
    return list(accumulate(diff, lambda a, b: (a + b) % MOD))


def solve(n: int, k: int) -> int:
    """Return ``1 - P(state in 1..n-1 after k steps)`` modulo MOD.

    The walk starts in state 0. From a state ``j < n`` it moves to a state
    chosen uniformly from ``j+1 .. n`` (weight ``1/(n-j)`` each). Mass that
    sits in state ``n`` is not carried into the next step. For ``k >= 1``
    state 0 is always empty, so the result is the probability that the walk
    has reached ``n`` within ``k`` steps.

    Args:
        n: Largest (target) state.
        k: Number of steps.

    Returns:
        The result as an integer in ``[0, MOD)``.
    """
    # Modular inverses via Fermat's little theorem (MOD is prime).
    inv = [pow(i, MOD - 2, MOD) for i in range(n + 1)]

    # Difference array of the current distribution over states 0..n; index
    # n + 1 only receives the end markers of range updates.
    diff = [0] * (n + 2)
    diff[0] = 1
    diff[1] = -1

    # Each step only reads the previous distribution, so one rolling row
    # suffices (the original kept all k + 1 rows).
    for _ in range(k):
        prob = _prefix_mod(diff)
        nxt = [0] * (n + 2)
        # State n is terminal: its mass is not propagated (in the original
        # this was achieved by multiplying with inv[0] == 0).
        for j in range(n):
            p = prob[j]
            if p:  # skip empty states
                w = inv[n - j] * p % MOD
                # Range-add w to states j+1..n via the difference array.
                nxt[j + 1] = (nxt[j + 1] + w) % MOD
                nxt[n + 1] = (nxt[n + 1] - w) % MOD
        diff = nxt

    # NOTE(review): for k == 0 the mass is still in state 0, which is not
    # counted below, so the result is 1 (same as the original script).
    # Confirm whether k == 0 can occur in the intended input.
    stuck = sum(_prefix_mod(diff)[1:n])
    return (1 - stuck) % MOD


def main() -> None:
    """Read ``N K`` from standard input and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()