import sys
from collections import defaultdict

# Fast line reader over the raw byte stream; int() accepts bytes tokens.
input = sys.stdin.buffer.readline
sys.setrecursionlimit(10**7)

con = 998244353  # modulus applied to the running count
INF = float("inf")


def getlist():
    """Read one line from stdin and return its whitespace-separated integers."""
    return [int(tok) for tok in input().split()]


def main():
    """Read N and M from one line of stdin and print the DP count for M.

    For N >= 2, ways[i] is the number of ordered sequences of parts of size 1
    and size N summing to i: the last part is either a 1 (ways[i - 1]) or an
    N (ways[i - N], only once i >= N). Values are kept modulo ``con``.
    When N == 1 the program prints 1 without running the recurrence
    (presumably because both part sizes coincide there -- confirm against
    the problem statement).
    """
    n, m = getlist()
    ways = [0] * (m + 1)
    ways[0] = 1  # the empty sum
    if n == 1:
        print(1)
        return
    for i in range(1, m + 1):
        prev = ways[i - 1]
        # The right-hand side is fully evaluated before ways[i] is written.
        ways[i] = (prev + ways[i - n]) % con if i >= n else prev
    print(ways[-1])


if __name__ == '__main__':
    main()