MOD = 998244353


def count_ways(n: int, m: int, mod: int = MOD) -> int:
    """Return f(m) modulo ``mod`` for the recurrence f(k) = f(k-1) + f(k-n).

    Base values are f(k) = 1 for 1 <= k < n and f(n) = 2. For n >= 2 this
    equals the number of ordered ways to write m as a sum of parts equal to
    1 or n (e.g. n=2 gives the Fibonacci-like sequence 1, 2, 3, 5, 8, ...).
    n == 1 is special-cased to 1.

    Only the first m terms are built, so time and memory are O(m) rather
    than a fixed ~10**6 iterations regardless of m.

    Args:
        n: Size of the large part; must be >= 1.
        m: Target length; must be >= 1.
        mod: Modulus applied to the result.

    Returns:
        f(m) reduced modulo ``mod``.
    """
    if n == 1:
        # NOTE(review): kept from the original script; presumably the problem
        # treats a part of size 1 == n as the same part, giving one way.
        return 1 % mod
    if n > m:
        # Only parts of size 1 fit.
        return 1 % mod
    if n == m:
        # All ones, or a single part of size n.
        return 2 % mod
    # terms[i] holds f(i + 1): index i corresponds to length i + 1.
    terms = [1] * (n - 1) + [2]
    for k in range(n, m):
        terms.append((terms[k - 1] + terms[k - n]) % mod)
    return terms[m - 1]


def main() -> None:
    """Read "n m" from stdin, validate the constraints and print the answer."""
    n, m = map(int, input().split())
    if not (1 <= n <= 10 ** 9 and 1 <= m <= 10 ** 6):
        # Fail loudly (non-zero exit) on input outside the stated limits.
        raise ValueError(f"constraints violated: n={n}, m={m}")
    print(count_ways(n, m))


if __name__ == "__main__":
    main()