MOD = 998244353


def count_ways(n, m):
    """Return the m-th term (1-based) of the recurrence f(k) = f(k-1) + f(k-n) mod MOD.

    Initial terms are f(1) = ... = f(n-1) = 1 and f(n) = 2. For n >= 2 this
    presumably counts tilings of a length-m strip with pieces of length 1 and
    n (NOTE(review): combinatorial meaning inferred from the recurrence —
    confirm against the problem statement).

    Special cases, kept exactly as in the original script:
      * n == 1      -> 1
      * m < n       -> 1
      * m == n      -> 2

    Only the m - n terms actually needed are computed. The previous version
    always built ~10**6 terms, which wasted time for small m and raised
    IndexError for m larger than about 10**6 + n.
    """
    if n == 1:
        return 1
    if m < n:
        return 1
    if m == n:
        return 2
    # terms[i] holds f(i + 1); the first n terms are the initial values.
    terms = [1] * (n - 1) + [2]
    for _ in range(m - n):
        # terms[-n] is f(k - n) relative to the term being appended.
        terms.append((terms[-1] + terms[-n]) % MOD)
    return terms[m - 1]


def main():
    """Read "n m" from stdin and print count_ways(n, m)."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()