MOD = 998244353


def count_ways(n: int, m: int, mod: int = MOD) -> int:
    """Return f(m) modulo ``mod`` for the recurrence f(i) = f(i-1) + f(i-n).

    The base values are f(i) = 1 for 0 <= i < n.  Equivalently, f(m) counts
    the ordered sequences of parts summing to m, where each part is either a
    "short" part of size 1 or a "long" part of size n (the two kinds are
    distinct, so n == 1 gives 2**m).

    Args:
        n: Size of the long part, i.e. the recurrence lag; must be >= 1.
        m: Index to evaluate.  Any m < n (including negative m) yields
            ``1 % mod``.
        mod: Modulus applied to every intermediate value and the result.

    Returns:
        f(m) % mod.

    Raises:
        ValueError: if ``n < 1``; the recurrence has no meaningful lag there.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if n > m:
        return 1 % mod

    # The original two-state DP kept dp0[i] = f(i-1) and dp1[i] = the sum of
    # dp0 over earlier indices congruent to i (mod n).  Since
    # dp1[i] = dp0[i-n] + dp1[i-n] = f(i-n), only the last n values of f are
    # ever needed.  ring[i % n] holds f(i) for the most recent n indices;
    # just before it is overwritten at step i it still holds f(i - n).
    ring = [1 % mod] * n  # f(0) .. f(n-1) are all 1
    latest = 1 % mod  # f(n - 1), i.e. f(i - 1) entering the first step
    for i in range(n, m + 1):
        slot = i % n
        latest = (latest + ring[slot]) % mod
        ring[slot] = latest
    return latest


def main() -> None:
    """Read ``N M`` from one line of stdin and print f(M) modulo MOD."""
    n, m = map(int, input().split())
    print(count_ways(n, m))


if __name__ == "__main__":
    main()