"""Read two integers ``n k`` from stdin and print an alternating binomial sum.

The printed value is, modulo 998244353::

    sum over i in [-n, n] of
        [C(2n, n + i(k+2)) - C(2n, n - k - 1 + i(k+2))]
      - [C(2n, n + i(k+1)) - C(2n, n - k     + i(k+1))]

NOTE(review): each bracket has the shape of a reflection-principle count of
balanced lattice paths with bounded height, so the result is presumably the
number of such paths whose maximum height is exactly ``k`` -- confirm
against the problem statement.
"""

mod = 998244353

# fac[i] = i! % mod and finv[i] = (i!)^-1 % mod.  Both lists are grown in
# place, on demand, by _grow_tables, so small inputs do not pay for a large
# fixed-size table and large inputs do not run off the end of one.
fac = [1, 1]
finv = [1, 1]


def _grow_tables(top):
    """Extend ``fac`` and ``finv`` in place so indices ``0..top`` are valid.

    Does nothing when the tables already reach ``top``.  ``pow(x, -1, mod)``
    raises ``ValueError`` if ``top! % mod == 0`` (i.e. ``top >= mod``).
    """
    old = len(fac)
    if top < old:
        return
    for i in range(old, top + 1):
        fac.append(fac[-1] * i % mod)
    finv.extend([0] * (top + 1 - old))
    # One modular inversion at the top, then walk down using
    # (i-1)!^-1 = i!^-1 * i until reaching the entries that already existed.
    finv[top] = pow(fac[top], -1, mod)
    for i in range(top, old, -1):
        finv[i - 1] = finv[i] * i % mod


def C(a, b):
    """Return the binomial coefficient ``a choose b`` modulo ``mod``.

    Returns 0 unless ``0 <= b <= a`` holds.
    """
    if not 0 <= b <= a:
        return 0
    if a >= len(fac):
        _grow_tables(a)
    return fac[a] * finv[b] % mod * finv[a - b] % mod


def solve(n, k):
    """Return the alternating binomial sum for ``n`` and ``k``, in ``[0, mod)``.

    See the module docstring for the exact formula.
    """
    # Build everything C(2n, .) needs in one pass up front.
    _grow_tables(2 * n)
    total = 0
    for i in range(-n, n + 1):
        total += C(2 * n, n + i * (k + 2)) - C(2 * n, n - k - 1 + i * (k + 2))
        total -= C(2 * n, n + i * (k + 1)) - C(2 * n, n - k + i * (k + 1))
        total %= mod
    return total


def main():
    """Read ``n k`` from one line of stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()