"""Read two integers ``n m`` from stdin and print one sequence term mod 998244353.

The printed value is ``solve(n, m)``; see that function for the exact rule.
"""

MOD = 998244353


def solve(n: int, m: int) -> int:
    """Return the answer for the pair ``(n, m)``, reduced modulo ``MOD``.

    Special cases, checked in this order:
      * ``n == 1``  -> 1
      * ``n > m``   -> 1
      * ``n == m``  -> 2

    Otherwise (``n < m``) the result is the ``m``-th term (1-indexed) of the
    sequence whose first ``n - 1`` terms are 1, whose ``n``-th term is 2, and
    whose later terms satisfy ``a[i] = a[i - 1] + a[i - n]  (mod MOD)``.
    """
    if n == 1 or n > m:
        return 1
    if n == m:
        return 2

    # Seed values a[0..n-1]; everything after follows the recurrence.
    terms = [1] * (n - 1) + [2]
    # Extend only as far as index m - 1 instead of a fixed 10**6 + 1 steps:
    # this avoids wasted work for small m and an IndexError for large m.
    for _ in range(m - len(terms)):
        terms.append((terms[-1] + terms[-n]) % MOD)
    return terms[m - 1]


def main() -> None:
    """Parse ``n m`` from a single input line and print the result."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()