MOD = 998244353


def solve(n, m, mod=MOD):
    """Return term m of the sequence f defined below, reduced modulo *mod*.

    The sequence (1-indexed) is:
        f(k) = 1                      for 1 <= k < n
        f(n) = 2
        f(k) = f(k-1) + f(k-n)        for k > n

    This presumably counts arrangements of pieces of length 1 and n in a
    row of length m -- confirm against the original problem statement.

    Args:
        n: recurrence lag (length of the "long" piece).
        m: index of the term to return.
        mod: modulus applied to every computed term.

    Returns:
        f(m) modulo *mod*, with the special cases below.
    """
    if n == 1:
        # Kept from the original script. NOTE(review): the general
        # recurrence would give 2**m here; confirm the problem really
        # expects 1 when n == 1.
        return 1
    if n > m:
        return 1
    if n == m:
        return 2
    # ways[i] holds f(i + 1); seed f(1..n).
    ways = [1] * (n - 1) + [2]
    # Extend only as far as needed: after m - n appends, len(ways) == m,
    # so ways[m - 1] == f(m). (The original always ran 10**6 + 1
    # iterations and raised IndexError for very large m.)
    for _ in range(m - n):
        ways.append((ways[-1] + ways[-n]) % mod)
    return ways[m - 1]


def main():
    """Read "n m" from stdin and print solve(n, m)."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()