"""Evaluate a linear recurrence modulo 998244353.

Reads two integers ``n`` and ``m`` from one line of stdin and prints dp[m],
where

    dp[0] = 1
    dp[i] = dp[i - 1]               for 0 < i < n
    dp[i] = dp[i - 1] + dp[i - n]   for i >= n

with every term reduced modulo 998244353. When ``n == 1`` the answer is
special-cased to 1.

NOTE(review): this presumably counts ways to fill a length-m row with pieces
of length 1 and length n. Confirm against the original problem statement.
"""

MOD = 998244353
mod = MOD  # kept for backward compatibility with the original module name


def count_ways(n: int, m: int, modulus: int = MOD) -> int:
    """Return dp[m] of the recurrence above, reduced modulo *modulus*.

    Args:
        n: Lag of the second term (dp[i - n]) in the recurrence.
        m: Index of the term to return; expected to be >= 0.
        modulus: Modulus applied to every term (default 998244353).

    Returns:
        dp[m] % modulus, or 1 % modulus when ``n == 1``.
    """
    if n == 1:
        # Special-cased in the original script: the recurrence itself would
        # yield 2**m here. NOTE(review): presumably the two piece lengths
        # coincide, leaving a single arrangement. Confirm against the problem.
        return 1 % modulus

    dp = [0] * (m + 1)
    dp[0] = 1 % modulus
    for i in range(1, m + 1):
        total = dp[i - 1]
        if i >= n:
            # Entries at index >= i are still 0 here, matching the original
            # evaluation order for degenerate n <= 0.
            total += dp[i - n]
        dp[i] = total % modulus
    return dp[m]


def main() -> None:
    """Read ``n m`` from stdin and print dp[m] modulo 998244353."""
    n, m = map(int, input().split())
    # Returning from a function replaces the original site-module exit(),
    # which is meant for the interactive shell only.
    print(count_ways(n, m))


if __name__ == "__main__":
    main()