#include using namespace std; using LL = long long int; const LL M = 998244353; LL ex(LL a, LL b) { LL ans = 1, e = a; while(b) { if(b & 1) { (ans *= e) %= M; } (e *= e) %= M; b >>= 1; } return ans; } int main() { LL n, k; cin >> n >> k; LL ans = ex(k, n + 1); for(LL i = 1; i <= k - 1; i++) { (ans -= (k * n - i * (n - 1)) % M * ex(i, n - 1)) %= M; } ans = (ans + M) % M; cout << ans << endl; }