MOD = 998244353


def balanced_multinomial(n, m, mod=MOD):
    """Return m! / prod(s_i!) modulo `mod`, where m is split into n balanced parts.

    With ``d, r = divmod(m, n)``, the n part sizes are ``d + 1`` (for r parts)
    and ``d`` (for the remaining ``n - r`` parts).

    Args:
        n: Number of parts; must be positive.
        m: Total size to split; must be non-negative.
        mod: Prime modulus. The denominator is inverted via Fermat's little
            theorem, so `mod` must be prime.

    Returns:
        The multinomial coefficient reduced modulo `mod`.

    Raises:
        ZeroDivisionError: if ``n == 0`` (propagated from ``divmod``).
    """
    # Sized m + 2 so fact[d + 1] is always a valid index, including when m == 0.
    fact = [1] * (m + 2)
    for i in range(2, m + 1):
        fact[i] = fact[i - 1] * i % mod

    d, r = divmod(m, n)
    # r parts have size d + 1 and n - r parts have size d. Exponentiating
    # instead of looping over all n parts keeps this O(log n), so a very
    # large n stays cheap.
    denominator = pow(fact[d + 1], r, mod) * pow(fact[d], n - r, mod) % mod
    return fact[m] * pow(denominator, mod - 2, mod) % mod


def main():
    """Read "n m" from stdin and print the balanced multinomial modulo MOD."""
    n, m = map(int, input().split())
    print(balanced_multinomial(n, m))


if __name__ == "__main__":
    main()