#include using namespace std; //* #include using namespace atcoder; using mint = modint998244353; // using mint = modint1000000007; //*/ using ll = long long; using i128 = __int128_t; using pll = pair; using tlll = tuple; using ld = long double; const int INF = 1000100100; const ll INFF = 1000100100100100100LL; const int dx[4] = {0, 0, 1, -1}, dy[4] = {1, -1, 0, 0}; #define overload4(_1, _2, _3, _4, name, ...) name #define rep1(i, n) for (ll i = 0; i < ll(n); i++) #define rep2(i, l, r) for (ll i = ll(l); i < ll(r); i++) #define rep3(i, l, r, d) for (ll i = ll(l); (d) > 0 ? i < ll(r) : i > ll(r); i += d) #define rep(...) overload4(__VA_ARGS__, rep3, rep2, rep1)(__VA_ARGS__) #define per(i, n) for (int i = (n) - 1; i >= 0; --i) #define yesno(f) cout << (f ? "Yes" : "No") << endl; #define YESNO(f) cout << (f ? "YES" : "NO") << endl; #define all(a) (a).begin(), (a).end() #define popc(x) __builtin_popcountll(ll(x)) template ostream &operator<<(ostream &os, const pair &p) { return os << p.first << ' ' << p.second; } template void printvec(const vector &v) { int n = v.size(); rep(i, n) cout << v[i] << (i == n - 1 ? "" : " "); cout << '\n'; } template void printvect(const vector &v) { for (auto &vi : v) cout << vi << '\n'; } template void printvec2(const vector> &v) { for (auto &vi : v) printvec(vi); } template bool chmax(S &x, const T &y) { return (x < y) ? (x = y, true) : false; } template bool chmin(S &x, const T &y) { return (x > y) ? (x = y, true) : false; } #ifdef LOCAL // https://zenn.dev/sassan/articles/19db660e4da0a4 #include "cpp-dump-main/dump.hpp" #define dump(...) cpp_dump(__VA_ARGS__) CPP_DUMP_DEFINE_DANGEROUS_EXPORT_OBJECT(val()); #else #define dump(...) #endif struct io_setup { io_setup() { ios_base::sync_with_stdio(false); cin.tie(NULL); cout << fixed << setprecision(15); } } io_setup; void solve() { ll n, m; cin >> n >> m; mint ans = mint(m).pow(n) * n; rep(i, 2, m + 1) { mint res = mint(m + 1 - i) / m; ans += mint(m).pow(n) * (m + 1 - i) / (i - 1) * (1 - mint(res).pow(n)); } cout << ans.val() << endl; } int main() { int t; // cin >> t; t = 1; while (t--) { solve(); } }