"""Sum of (max - min) * (element sum) over all sequences in [1, M]^N.

Reads ``N M`` from standard input and prints the total modulo 998244353.
"""

MOD = 998244353


def s(x):
    """Return 1 + 2 + ... + x reduced modulo MOD (0 when x == 0)."""
    return (x * (x + 1) // 2) % MOD


def solve(n, m_max):
    """Return the sum over every sequence a of length n with entries in
    [1, m_max] of (max(a) - min(a)) * sum(a), modulo MOD.

    The total splits into sum(max(a) * sum(a)) - sum(min(a) * sum(a)).
    By symmetry, sum(a) summed over a set of sequences equals n times the
    sum of a_1 over that set, so only the first coordinate is tracked and
    the result is scaled by n.

    Args:
        n: sequence length (the original script assumes n >= 1).
        m_max: largest allowed element value.

    Returns:
        The total as an int in [0, MOD).
    """
    total_all = s(m_max)  # sum of values 1..m_max; loop-invariant
    ans = 0
    for m in range(1, m_max + 1):
        # Sum of a_1 over sequences whose maximum is exactly m:
        # (all entries <= m) minus (all entries <= m - 1).
        with_max = (
            s(m) * pow(m, n - 1, MOD)
            - s(m - 1) * pow(m - 1, n - 1, MOD)
        ) % MOD
        # Sum of a_1 over sequences whose minimum is exactly m:
        # (all entries >= m) minus (all entries >= m + 1).
        with_min = (
            (total_all - s(m - 1)) * pow(m_max - m + 1, n - 1, MOD)
            - (total_all - s(m)) * pow(m_max - m, n - 1, MOD)
        ) % MOD
        # m * (contribution as max - contribution as min), scaled by n.
        ans = (ans + m * (with_max - with_min) % MOD * n) % MOD
    return ans


def main():
    """Read ``N M`` from stdin and print solve(N, M)."""
    n, m_max = map(int, input().split())
    print(solve(n, m_max))


if __name__ == "__main__":
    main()