#!/usr/bin/env python3
#region Header
# from typing import *
import sys
import io
import math
import collections
import decimal
import itertools
from queue import PriorityQueue
import bisect
import heapq


def input():
    """Read one line from stdin without its trailing line terminator.

    ``readline()[:-1]`` would drop a real character when the last line of
    input has no trailing newline, and would leave a stray ``\\r`` on CRLF
    input; stripping only CR/LF characters avoids both problems.
    """
    return sys.stdin.readline().rstrip("\r\n")


sys.setrecursionlimit(1000000)
#endregion

# _INPUT = """# paste here...
# """
# sys.stdin = io.StringIO(_INPUT)

MOD = 998244353


def solve(N, M):
    """Count ordered sequences of steps of size 1 and size N summing to M.

    The result is taken modulo MOD. When N == 1 the two step kinds are still
    counted separately, so the answer is 2**M mod MOD.

    NOTE(review): presumably N >= 1 per the problem constraints — confirm;
    N == 0 does not give a meaningful count with this recurrence.

    Args:
        N: Size of the long step.
        M: Target total.

    Returns:
        The number of step sequences reaching exactly M, modulo MOD.
    """
    if M < N:
        # Only 1-steps fit, so there is exactly one sequence.
        return 1
    if M == N:
        # Either M single steps or one N-step.
        return 2

    # dp[i] = number of step sequences that sum to exactly i.
    dp = [0] * (M + 1)
    dp[0] = 1
    for i in range(M):
        # Push the count at position i forward by a 1-step ...
        dp[i + 1] = (dp[i + 1] + dp[i]) % MOD
        # ... and by an N-step when it stays within the target.
        if i + N <= M:
            dp[i + N] = (dp[i + N] + dp[i]) % MOD
    return dp[M]


def main():
    """Read ``N M`` from one stdin line and print ``solve(N, M)``."""
    N, M = map(int, input().split())
    print(solve(N, M))


if __name__ == '__main__':
    main()