MOD = 998244353


def factorial_mod(n, mod):
    """Return n! reduced modulo ``mod`` (1 for n <= 0)."""
    ret = 1
    for i in range(1, n + 1):
        ret = ret * i % mod
    return ret


def comb_mod(n, r, mod):
    """Return the binomial coefficient C(n, r) modulo ``mod``.

    Returns 0 when r < 0 or r > n. Inverses are taken via Fermat's little
    theorem, so ``mod`` must be prime and n must be smaller than ``mod``.
    Each call is O(n); use precomputed tables (as in ``solve``) for many
    queries.
    """
    if r > n or r < 0:
        return 0
    fact_n = factorial_mod(n, mod)
    fact_r = factorial_mod(r, mod)
    fact_nr = factorial_mod(n - r, mod)
    return fact_n * pow(fact_r, mod - 2, mod) * pow(fact_nr, mod - 2, mod) % mod


def solve(n, m, mod=MOD):
    """Return (1 + sum over i = 1..m//n of term(i)) modulo ``mod``.

    Here k = m - n*i, and term(i) = 1 if i > k + 1, else C(k + 1, i).

    Factorials and inverse factorials up to m - n + 1 (the largest k + 1,
    reached at i = 1) are built once, so the whole sum is O(m) rather than
    O(m^2 / n) as with repeated ``comb_mod`` calls. Assumes ``mod`` is prime,
    m - n + 1 < mod, and n >= 1 (n == 0 raises ZeroDivisionError, as the
    original script did).
    """
    if m < n:
        return 1
    top = m - n + 1
    fact = [1] * (top + 1)
    for x in range(1, top + 1):
        fact[x] = fact[x - 1] * x % mod
    inv_fact = [1] * (top + 1)
    inv_fact[top] = pow(fact[top], mod - 2, mod)
    for x in range(top, 0, -1):
        # 1/(x-1)! = x * 1/x!
        inv_fact[x - 1] = inv_fact[x] * x % mod

    ans = 1
    for i in range(1, m // n + 1):
        slots = m - n * i + 1  # k + 1 in the original formulation
        if i > slots:
            # NOTE(review): C(slots, i) would be 0 here, but the original
            # script adds 1 in this case; preserved as-is. Confirm against
            # the problem statement.
            ans += 1
        else:
            ans += fact[slots] * inv_fact[i] % mod * inv_fact[slots - i]
        # Reduce every iteration so the result always lies in [0, mod).
        ans %= mod
    return ans


def main():
    """Read "n m" from stdin and print the count modulo 998244353."""
    n, m = map(int, input().split())
    print(solve(n, m, MOD))


if __name__ == "__main__":
    main()