n, m = map(int, input().split()) mod = 998244353 # 1 * nで # n * mを埋める # nの使い方 # まずたて * m # そのうちm - n, m - n - nを横にする。 # たてnなので横の列は横におくしかない # 横の種類数はm // nこ # n == 1 の時 if n == 1: print(1) exit() fact = [1] * (m + 5) for i in range(1, m + 5): fact[i] = fact[i - 1] * i % mod def comb(n, k): return fact[n] * pow(fact[n - k] * fact[k] % mod, mod - 2, mod) % mod ans = 0 for i in range(m // n + 1): tot = m - n * i + i ans = (ans + comb(tot, i)) % mod print(ans)