import java.util.Scanner;

/**
 * Reads two integers {@code N} and {@code M} from standard input and prints one value modulo
 * {@value #MOD}.
 *
 * <p>Let {@code cumnoc[i] = C(M - i*N + N, N)} for every {@code i} with {@code i*N <= M}, and 0
 * otherwise. After the "subtract N" adjustment described on {@link #solve(int, int)}, the printed
 * value is {@code (sum_{0 <= i < M} i * (cumnoc[i] - cumnoc[i+1])) / cumnoc[0]}, computed in
 * modular arithmetic.
 *
 * <p>By stars and bars, {@code C(M - i*N + N, N)} counts the {@code N}-tuples of non-negative
 * integers whose entries are all at least {@code i} and whose sum is at most {@code M}. The result
 * therefore reads as an expected minimum entry over such tuples. NOTE(review): the meaning of the
 * "subtract N" adjustment is not stated anywhere in this file — presumably it removes excluded
 * configurations; confirm against the problem statement.
 */
public class Main {
    static final int MOD = 998244353;

    /** Table size used by the no-argument {@link #precompute()}; retained for compatibility. */
    static final int MAX_V = 2000001;

    /** {@code fac[k] = k! mod MOD}; (re)built on demand by {@link #precompute(int)}. */
    static long[] fac = new long[0];
    /** {@code finv[k] = (k!)^-1 mod MOD}. */
    static long[] finv = new long[0];
    /** {@code inv[k] = k^-1 mod MOD} for {@code k >= 1}. */
    static long[] inv = new long[0];

    public static void main(String[] args) {
        int n;
        int m;
        try (Scanner sc = new Scanner(System.in)) {
            n = sc.nextInt();
            m = sc.nextInt();
        }
        System.out.println(solve(n, m));
    }

    /**
     * Computes the value printed by {@link #main(String[])} for the given inputs.
     *
     * @param n the value read as {@code N}; must be non-negative
     * @param m the value read as {@code M}; must be non-negative
     * @return the answer modulo {@value #MOD}, in {@code [0, MOD)}
     * @throws IllegalArgumentException if {@code n} or {@code m} is negative, or {@code n + m >= MOD}
     * @throws ArithmeticException if {@code n + m} overflows {@code int}
     */
    static long solve(int n, int m) {
        if (n < 0 || m < 0) {
            throw new IllegalArgumentException(
                    "N and M must be non-negative: N=" + n + ", M=" + m);
        }
        // The largest factorial index used is M + N (the i = 0 term C(M + N, N)).
        precompute(Math.addExact(m, n));

        // cumnoc[m + 1] stays 0 so that cumnoc[i] - cumnoc[i + 1] is defined for every i < m.
        long[] cumnoc = new long[m + 2];
        // The long product avoids int overflow in i * n before the bound check.
        for (int i = 0; i <= m && (long) i * n <= m; i++) {
            cumnoc[i] = binom(m - i * n + n, n);
        }

        // NOTE(review): for N == 1 the adjustment is applied to every cumulative count up to M;
        // otherwise it is applied only to index 0. Presumably the removed configurations have
        // minimum M when N == 1 and minimum 0 otherwise — confirm against the problem statement.
        if (n == 1) {
            for (int i = 0; i <= m; i++) {
                cumnoc[i] = (cumnoc[i] - n + MOD) % MOD;
            }
        } else {
            cumnoc[0] = (cumnoc[0] - n + MOD) % MOD;
        }

        long ans = 0;
        for (int i = 0; i < m; i++) {
            // Count with value exactly i = cumulative count at i minus cumulative count at i + 1.
            long noc = (cumnoc[i] - cumnoc[i + 1] + MOD) % MOD;
            ans = (ans + i * noc) % MOD;
        }

        long total = cumnoc[0];
        // NOTE(review): if total is 0 modulo MOD, modinv returns 0, so this method returns 0.
        return ans * modinv(total) % MOD;
    }

    /** Returns {@code C(a, b) mod MOD}; requires {@code 0 <= b <= a < fac.length}. */
    private static long binom(int a, int b) {
        return fac[a] * finv[b] % MOD * finv[a - b] % MOD;
    }

    /** Builds the factorial tables for indices {@code 0 .. MAX_V - 1}. */
    static void precompute() {
        precompute(MAX_V - 1);
    }

    /**
     * Ensures {@link #fac}, {@link #finv} and {@link #inv} cover every index in {@code 0 .. limit}.
     * Does nothing if the current tables are already large enough.
     *
     * @param limit the largest index that must be available
     * @throws IllegalArgumentException if {@code limit >= MOD}; the inverse recurrence needs
     *     {@code k < MOD}
     */
    static void precompute(int limit) {
        if (limit >= MOD) {
            throw new IllegalArgumentException("factorial table limit must be below MOD: " + limit);
        }
        if (limit < fac.length) {
            return;
        }
        int size = Math.max(limit + 1, 2);
        fac = new long[size];
        finv = new long[size];
        inv = new long[size];
        fac[0] = fac[1] = 1;
        finv[0] = finv[1] = 1;
        inv[1] = 1;
        for (int i = 2; i < size; i++) {
            fac[i] = fac[i - 1] * i % MOD;
            // Standard O(1) recurrence: inv[i] = -(MOD / i) * inv[MOD % i] (mod MOD).
            inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD;
            finv[i] = finv[i - 1] * inv[i] % MOD;
        }
    }

    /** Returns the inverse of {@code a} modulo the prime {@code MOD} (Fermat); 0 maps to 0. */
    static long modinv(long a) {
        return modpow(a, MOD - 2);
    }

    /** Returns {@code base^exp mod MOD} by binary exponentiation; requires {@code exp >= 0}. */
    static long modpow(long base, long exp) {
        long result = 1;
        base %= MOD;
        while (exp > 0) {
            if ((exp & 1) == 1) {
                result = result * base % MOD;
            }
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }
}