"""Print m! / ((q!)**(n - r) * ((q + 1)!)**r) modulo 998244353.

Reads two integers ``n m`` from stdin; q = m // n and r = m % n.
"""

mod = 998244353


def com(n, mod):
    """Return factorial and inverse-factorial tables modulo a prime.

    Args:
        n: Largest index required. The tables always cover at least
            indices 0 and 1. Indices must stay below ``mod``; otherwise the
            inverse recurrence reads inv[0] == 0 and produces wrong values.
        mod: Prime modulus.

    Returns:
        (fact, factinv): fact[i] == i! % mod and
        factinv[i] == (i!)^-1 % mod for 0 <= i <= max(n, 1).
    """
    fact = [1, 1]
    factinv = [1, 1]
    inv = [0, 1]
    for i in range(2, n + 1):
        fact.append(fact[-1] * i % mod)
        # Linear-time modular inverse: inv[i] = -(mod // i) * inv[mod % i].
        inv.append(-inv[mod % i] * (mod // i) % mod)
        factinv.append(factinv[-1] * inv[-1] % mod)
    return fact, factinv


def solve(n, m, mod=998244353):
    """Return m! / ((q!)**(n - r) * ((q + 1)!)**r) modulo ``mod``.

    Here q = m // n and r = m % n. This is the multinomial coefficient for
    splitting m into r parts of size q + 1 and n - r parts of size q.
    n must be positive; n == 0 raises ZeroDivisionError.
    """
    q, rem = divmod(m, n)
    # Only indices m and q + 1 (<= m + 1) are read, so size the tables by m.
    # Sizing by n + m wastes time and memory when n is much larger than m.
    fact, factinv = com(m + 1, mod)
    ans = fact[m]
    ans = ans * pow(factinv[q], n - rem, mod) % mod
    ans = ans * pow(factinv[q + 1], rem, mod) % mod
    return ans


def main():
    """Read ``n m`` from stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()