"""Read two integers ``N M`` from stdin and print one DP value mod 998244353.

The printed value is ``dp[M]`` for the linear recurrence

    dp[0] = 1
    dp[i] = dp[i - 1] + (dp[i - N] if i >= N else 0)      (mod 998244353)

with the special case that ``N == 1`` always prints ``1``.
NOTE(review): the combinatorial meaning of the recurrence depends on the
problem statement, which is not part of this file.
"""
import sys

MOD = 998244353


def input():
    """Return the next line of standard input, stripped of outer whitespace.

    NOTE: this intentionally keeps the script's original helper name, which
    shadows the builtin ``input`` inside this module.
    """
    return sys.stdin.readline().strip()


def mapint():
    """Return an iterator of the ints on the next whitespace-separated line."""
    return map(int, input().split())


def solve(n, m):
    """Return ``dp[m]`` of the recurrence described in the module docstring.

    Args:
        n: Step size of the second recurrence term (``dp[i - n]``).
        m: Index of the value to return; expected to be non-negative.

    Returns:
        ``1`` when ``n == 1``; otherwise ``dp[m]`` reduced modulo ``MOD``.
    """
    # Special case handled first so no table is allocated just to be dropped.
    if n == 1:
        return 1

    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(1, m + 1):
        dp[i] = dp[i - 1]
        if i >= n:
            dp[i] += dp[i - n]
        # Reduce at every step so the stored values stay below MOD.
        dp[i] %= MOD
    return dp[m]


def main():
    """Parse ``N M`` from stdin and print the answer."""
    # Kept from the original script; nothing in this file recurses.
    sys.setrecursionlimit(10**9)
    n, m = mapint()
    print(solve(n, m))


if __name__ == "__main__":
    main()