MOD = 998244353


def count_ways(n, m, mod=MOD):
    """Count step sequences from 0 to ``m`` using steps of size 1 or ``n``, modulo ``mod``.

    The count follows the recurrence ``ways(i) = ways(i - 1) + ways(i - n)``
    with ``ways(0) = 1``. Order matters: ``1 + n`` and ``n + 1`` are counted
    separately.

    When ``n == 1`` the answer is 1. Presumably this is because a step of size
    ``n`` is then indistinguishable from a step of size 1, so there is only one
    sequence.

    Args:
        n: The long step size. Assumes ``n >= 1``; TODO confirm against the
            problem constraints.
        m: The target total. Assumes ``m >= 0``.
        mod: The modulus applied to the count. Defaults to ``MOD``.

    Returns:
        The number of sequences reduced modulo ``mod``. The ``m == 0`` and
        ``n == 1`` cases return the literal 1.
    """
    if n == 1:
        return 1
    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(1, m + 1):
        # The last step taken to reach i was either a 1-step or an n-step.
        total = dp[i - 1]
        if i >= n:
            total += dp[i - n]
        dp[i] = total % mod
    return dp[m]


def main():
    """Read ``n m`` from one line of stdin and print the count."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()