#include const int Mod = 998244353; long long div_mod(long long x, long long y, long long z) { if (x % y == 0) return x / y; else return (div_mod((1 + x / y) * y - x, (z % y), y) * z + x) / y; } long long pow_mod(int n, long long k) { long long N, ans = 1; for (N = n; k > 0; k >>= 1, N = N * N % Mod) if (k & 1) ans = ans * N % Mod; return ans; } int main() { int N, M; scanf("%d %d", &N, &M); int i; long long ans = N, tmp; for (i = 1; i < M; i++) { tmp = div_mod(i, M, Mod); ans += div_mod((pow_mod(tmp, N + 1) - tmp + Mod) % Mod, tmp - 1, Mod); } printf("%lld\n", ans % Mod * pow_mod(M, N) % Mod); fflush(stdout); return 0; }