MOD = 998244353


def count_ways(n, m, mod=MOD):
    """Return the number of ordered sequences of parts of size 1 and n summing to m, modulo mod.

    The count obeys the recurrence ``f(j) = f(j - 1) + f(j - n)``, with
    ``f(0) = 1`` and ``f(j) = 0`` for ``j < 0``.

    Args:
        n: Size of the "long" part; must be >= 1.
        m: Target total; must be >= 0.
        mod: Modulus applied to the result (defaults to 998244353).

    Returns:
        The count modulo ``mod``. When ``n == 1`` the result is always 1.

    Raises:
        ValueError: If ``n < 1`` or ``m < 0``.
    """
    if n < 1 or m < 0:
        raise ValueError(f"expected n >= 1 and m >= 0, got n={n}, m={m}")
    if n == 1:
        # The script deliberately answers 1 here instead of 2**m. Presumably
        # this is because both part sizes are then the same object and are
        # not counted twice. NOTE(review): confirm against the problem statement.
        return 1
    dp = [0] * (m + 1)
    dp[0] = 1
    for j in range(1, m + 1):
        # Either the last part has size 1, or (if it fits) it has size n.
        total = dp[j - 1]
        if j >= n:
            total += dp[j - n]
        dp[j] = total % mod
    return dp[m]


def main():
    """Read "N M" from stdin and print the count for those values."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()