"""Count the ways to build a total of M from parts of size 1 and size N.

Reads ``N M`` from a single line of stdin and prints the count modulo
998244353.
"""

MOD = 998244353


def count_ways(n, m):
    """Return the number of ordered sequences of parts of size 1 and n summing to m, mod MOD.

    Uses the recurrence ``dp[i] = dp[i-1] + dp[i-n]``, appending a final part
    of size 1 or size n. The base case is ``dp[0] = 1``, the empty sequence.
    For ``i < n`` only size-1 parts fit, so ``dp[i] = dp[i-1]``.

    For ``n == 1`` the answer is defined as 1 rather than the recurrence's
    ``2**m``. Presumably a part of size n is then indistinguishable from a
    part of size 1. NOTE(review): confirm against the problem statement.

    Args:
        n: Size of the "large" part.
        m: Target total; expected to be >= 0.

    Returns:
        The count modulo 998244353.
    """
    if n == 1:
        return 1

    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(1, m + 1):
        if i < n:
            # A size-n part does not fit yet; only a trailing size-1 part is possible.
            dp[i] = dp[i - 1]
        else:
            dp[i] = (dp[i - 1] + dp[i - n]) % MOD
    return dp[m]


def main():
    """Read ``N M`` from stdin and print ``count_ways(N, M)``."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()