"""Count ordered ways to build a total of M from steps of size 1 and N.

Reads two integers ``N M`` from one line of standard input and prints the
count modulo 998244353.
"""

MOD = 998244353


def count_compositions(n: int, m: int) -> int:
    """Return the number of ordered sums of 1s and ``n``s that total ``m``.

    The result is reduced modulo ``MOD``.  When ``n == 1`` the two step
    sizes coincide, so only the single step of size 1 is counted.

    NOTE(review): the recurrence presumably expects ``n >= 1`` and
    ``m >= 0`` -- confirm against the problem constraints.
    """
    ways = [0] * (m + 1)
    ways[0] = 1  # exactly one way to reach a total of 0: take no steps

    # A step of size n is only a distinct option when it differs from the
    # step of size 1; otherwise every sequence would be counted twice.
    use_long_step = n != 1

    for total in range(m):
        current = ways[total]

        # Extend every sequence reaching `total` by a step of size 1.
        ways[total + 1] = (ways[total + 1] + current) % MOD

        # Extend by a step of size n when it still fits within m.
        if use_long_step and total + n <= m:
            ways[total + n] = (ways[total + n] + current) % MOD

    return ways[-1]


def main() -> None:
    """Read ``N M`` from stdin and print the count."""
    n, m = map(int, input().split())
    print(count_compositions(n, m))


if __name__ == "__main__":
    main()