import sys

MOD = 998244353


def count_ways(n, m, mod=MOD):
    """Return the number of ordered ways to write ``m`` as a sum of parts 1 and ``n``, modulo ``mod``.

    Uses the recurrence ``dp[i] = dp[i - 1] + dp[i - n]`` (the second term only
    when ``i >= n``) with ``dp[0] = 1``.

    Args:
        n: Size of the "large" part.
        m: Target total.
        mod: Modulus applied after every step (defaults to 998244353).

    Returns:
        ``dp[m] % mod``. Two special cases return 1 directly:
        * ``n > m`` -- only the all-ones composition fits. This matches what
          the recurrence would produce anyway.
        * ``n == 1`` -- kept from the original script. The recurrence would
          instead give ``2**m``.
          NOTE(review): confirm this against the problem statement.
    """
    if n > m:
        return 1
    if n == 1:
        return 1

    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(1, m + 1):
        total = dp[i - 1]
        if i >= n:
            # Last part placed has size n.
            total += dp[i - n]
        dp[i] = total % mod
    return dp[m]


def main():
    """Read ``N M`` from stdin and print the count modulo 998244353."""
    n, m = map(int, sys.stdin.readline().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()