# Reads "N M" from one line of stdin and prints the number of ordered ways to
# reach M by summing steps of size 1 and size N, modulo 998244353.

MOD = 998244353


def count_ways(n, m, mod=MOD):
    """Return the number of ordered step sequences from {1, n} summing to m, mod *mod*.

    dp[i] is the number of ways to reach total i; each total is reached either
    by a 1-step from i - 1 or by an n-step from i - n.

    When n == 1 the two step kinds are the same step, so there is exactly one
    way (counting them separately would double every transition and give 2**m).

    Args:
        n: size of the long step; must be >= 1.
        m: target total; m == 0 yields 1 (the empty sequence).
        mod: modulus applied to the result.

    Raises:
        ValueError: if n < 1 (the recurrence is meaningless there).
    """
    if n < 1:
        raise ValueError(f"step size n must be >= 1, got {n}")
    if n == 1:
        return 1 % mod

    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(1, m + 1):
        total = dp[i - 1]
        if i >= n:
            total += dp[i - n]
        dp[i] = total % mod
    return dp[m]


def main():
    """Read "N M" from stdin and print count_ways(N, M)."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()