# Input: n (number of parts) and m (quantity to split), space-separated.
n, m = map(int, input().split())
mod = 998244353

# Lookup tables for indices 0..max(m, 1):
#   fact[k]    = k! mod `mod`
#   inv[k]     = modular inverse of k (inv[0] is an unused placeholder)
#   factinv[k] = modular inverse of k!
# inv uses the linear-time identity
#   inv[k] = -(mod // k) * inv[mod % k]   (mod `mod`),
# written with the non-negative multiplier (mod - mod // k).
fact = [1, 1]
inv = [0, 1]
factinv = [1, 1]
for k in range(2, m + 1):
    fact.append(fact[k - 1] * k % mod)
    inv.append((mod - mod // k) * inv[mod % k] % mod)
    factinv.append(factinv[k - 1] * inv[k] % mod)
def cmb(n, r):
    """Return the binomial coefficient C(n, r) modulo `mod`.

    Yields 0 whenever r lies outside the range [0, n]. Reads the
    module-level tables `fact` and `factinv`, so n must not exceed the
    largest index those tables were built for.
    """
    if 0 <= r <= n:
        return fact[n] * factinv[r] % mod * factinv[n - r] % mod
    return 0
# Split m into n parts as evenly as possible: r = m % n parts of size q + 1
# and n - r parts of size q, where q = m // n. Choosing the members of each
# part in turn gives a product of binomials that telescopes to the
# multinomial coefficient
#     m! / ((q + 1)!^r * q!^(n - r))   (mod `mod`).
# Evaluating it with modular pow() costs O(log n) rather than one cmb() call
# per part, so a large n (e.g. n > m, where most parts are empty) no longer
# makes the script loop n times.
# Assumes n >= 1 and m >= 0, as the original formulation does.
q, r = divmod(m, n)
ans = fact[m]
if r:
    # Only index factinv[q + 1] when there are parts of that size: if r == 0,
    # q + 1 may exceed m, which is past the end of the tables.
    ans = ans * pow(factinv[q + 1], r, mod) % mod
ans = ans * pow(factinv[q], n - r, mod) % mod
print(ans)