"""Accumulate, over steps 1..K, the probability (mod 998244353) that a walk
started at position 0 lands exactly on position N at that step.

The recurrence below moves mass from position i < N uniformly onto
positions i+1..N (weight 1/(N-i) each).  Position N never propagates
further, so dp[N] at a given step is the mass arriving there at that step.

Input on stdin: two integers ``N K``.  Output: the accumulated sum mod ``mod``.
"""
from functools import lru_cache

mod = 998244353


@lru_cache(maxsize=None)
def div(x):
    """Return the inverse of *x* modulo ``mod`` via Fermat's little theorem."""
    return pow(x, mod - 2, mod)


def solve(n, k):
    """Return sum_{step=1..k} P(walk is first at n after `step` steps) mod ``mod``.

    Args:
        n: target position (walk starts at 0).
        k: number of steps to simulate.

    Returns:
        The accumulated probability as an integer modulo ``mod``.
    """
    dp = [0] * (n + 1)
    dp[0] = 1
    ans = 0
    # Every step advances the walk by at least one position, so after n
    # steps no mass remains below n and later steps contribute nothing.
    # Capping here avoids a useless O(n * k) loop when k > n.
    for j in range(min(k, n)):
        ndp = [0] * (n + 1)
        # After j steps the walk is at position >= j, so dp[:j] is all zero.
        for i in range(j, n):
            ndp[i + 1] = dp[i] * div(n - i) % mod
        # Prefix sum: ndp[t] = sum_{i < t} dp[i] / (n - i), i.e. the mass
        # that lands on t from every earlier position.
        for i in range(j, n):
            ndp[i + 1] = (ndp[i + 1] + ndp[i]) % mod
        dp = ndp
        ans = (ans + dp[n]) % mod
    return ans


def main():
    """Read ``N K`` from stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()