"""Evaluate a two-term linear recurrence modulo 998244353.

Reads two integers ``n`` and ``m`` from one line of standard input and
prints ``f(m)``, where ``f(0) = 1`` and
``f(i) = f(i - 1) + (f(i - n) if i >= n else 0)``, reduced modulo ``MOD``.
When ``n == 1`` the program prints ``1`` without evaluating the recurrence.
"""

MOD = 998244353


def count_ways(n: int, m: int) -> int:
    """Return ``f(m)`` for the recurrence with step ``n``, modulo ``MOD``.

    Args:
        n: Offset of the second recurrence term.
        m: Index of the term to compute.
            NOTE(review): presumably ``n >= 1`` and ``m >= 0``; a negative
            ``m`` raises ``IndexError`` -- confirm against the problem
            statement.

    Returns:
        ``f(m) % MOD``, or ``1`` when ``n == 1``.
    """
    # Special case kept as in the original script: n == 1 yields 1 directly
    # rather than the value the recurrence below would produce.
    if n == 1:
        return 1

    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(1, m + 1):
        total = dp[i - 1]
        if i >= n:
            total += dp[i - n]
        # Reduce at every step so the stored values stay below MOD.
        dp[i] = total % MOD
    return dp[m]


def main() -> None:
    """Read ``n`` and ``m`` from standard input and print the result."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()