"""Count length-M compositions under a ``dp[i] = dp[i-1] + dp[i-N]`` recurrence.

Reads two integers ``N M`` from stdin and prints the answer modulo 998244353.
"""

MOD = 998244353


def count_ways(n, m, mod=MOD):
    """Return the number of ways to fill length *m*, modulo *mod*.

    Each step either extends by 1 or places a block of length *n*. The
    recurrence is ``dp[i] = dp[i - 1] + dp[i - n]``, where ``dp[i] = 1``
    for ``i < n`` (only unit steps fit).

    Special cases kept from the original script:
      * ``n == 1`` or ``n > m`` -> 1
      * ``n == m``              -> 2 (all unit steps, or one full block)

    NOTE(review): ``n == 1`` returning 1 is preserved as-is. Under the plain
    recurrence it would be ``2**m``, so this is presumably problem-specific.
    Assumes ``n >= 1`` and ``m >= 0``.
    """
    if n == 1 or n > m:
        return 1
    if n == m:
        return 2
    # Lengths 0..n-1 admit only unit steps, so exactly one way each.
    dp = [1] * n + [0] * (m - n + 1)
    for i in range(n, m + 1):
        # The last piece is either a unit step or a full block of length n.
        dp[i] = (dp[i - 1] + dp[i - n]) % mod
    return dp[m]


def main():
    """Read ``N M`` from stdin and print the count."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()