# First input line holds the two integers N and M; every result below is
# reported modulo the prime MOD.
N, M = [int(token) for token in input().split()]
MOD = 998_244_353


def factorial(x):
    """Return x! reduced modulo MOD.

    A non-positive ``x`` gives the empty product, 1. The running product is
    reduced after every multiplication, so it never grows past MOD * x.
    """
    result = 1
    # Starting at 2 is equivalent to starting at 1: multiplying by 1 is a no-op.
    for k in range(2, x + 1):
        result = result * k % MOD
    return result


# Split M into N near-equal parts: m parts of size d + 1 and N - m parts of
# size d, since (N - m) * d + m * (d + 1) == M.
d, m = M // N, M % N

fac_d = factorial(d)
fac_m = factorial(M)

# Multinomial coefficient M! / ((d!)^(N-m) * ((d+1)!)^m) mod MOD.
# Negative exponents with a modulus make pow() return a modular inverse
# (Python 3.8+); note fac_d * (d + 1) == (d + 1)!.
inv_small_parts = pow(fac_d, -(N - m), MOD)
inv_large_parts = pow(fac_d * (d + 1), -m, MOD)
ans = fac_m * inv_small_parts * inv_large_parts % MOD
print(ans)