MOD = 998244353


def s(x):
    """Return the triangular number 1 + 2 + ... + x reduced modulo MOD (0 for x == 0)."""
    return (x * (x + 1) // 2) % MOD


def solve(n, m_max):
    """Return the answer for sequence length ``n`` and value bound ``m_max``, modulo MOD.

    For each value v in 1..m_max, with e = n - 1:

    * ``s(v) * v**e - s(v-1) * (v-1)**e`` is the sum of A_1 over all
      A in [1, m_max]^n whose maximum is exactly v (sequences in [1, v]^n
      minus those in [1, v-1]^n).
    * ``(s(m_max) - s(v-1)) * (m_max-v+1)**e - (s(m_max) - s(v)) * (m_max-v)**e``
      is the sum of A_1 over all such A whose minimum is exactly v
      (sequences in [v, m_max]^n minus those in [v+1, m_max]^n).

    Weighting the difference by v and summing gives sum_A A_1 * (max(A) - min(A)).
    Multiplying by n turns A_1 into sum(A), because every position behaves the
    same. The result is therefore

        sum over A in [1, m_max]^n of sum(A) * (max(A) - min(A))   (mod MOD).

    Args:
        n: sequence length (the exponent n - 1 must be non-negative).
        m_max: largest allowed element value; a value <= 0 yields 0.

    Returns:
        The sum above, reduced modulo MOD.
    """
    if m_max <= 0:
        # Empty value range: the original loop body never ran.
        return 0

    e = n - 1
    total_tri = s(m_max)  # loop-invariant s(M), hoisted

    # Rolling powers. Each k**e is computed once instead of twice:
    #   below == (v - 1)**e and above == (m_max - v + 1)**e at the top of iteration v.
    # Python's pow(0, 0, MOD) == 1, which matches the original's use of 0**0.
    below = pow(0, e, MOD)
    above = pow(m_max, e, MOD)

    acc = 0
    for v in range(1, m_max + 1):
        cur = pow(v, e, MOD)              # v**e
        nxt = pow(m_max - v, e, MOD)      # (m_max - v)**e
        max_is_v = s(v) * cur - s(v - 1) * below
        min_is_v = (total_tri - s(v - 1)) * above - (total_tri - s(v)) * nxt
        # Reduce before weighting by v so negative intermediates become canonical.
        acc = (acc + (max_is_v - min_is_v) % MOD * v) % MOD
        # Shift the window: this iteration's powers are the next one's "below"/"above".
        below, above = cur, nxt

    # The original multiplied by N on every iteration; doing it once is equivalent mod MOD.
    return acc * n % MOD


def main():
    """Read ``N M`` from one line of stdin and print the answer."""
    n, m_max = map(int, input().split())
    print(solve(n, m_max))


if __name__ == "__main__":
    main()