#include using namespace std; const long long MOD = 998244353; long long modpow(long long a, long long b){ long long ans = 1; while (b > 0){ if (b % 2 == 1){ ans *= a; ans %= MOD; } a *= a; a %= MOD; b /= 2; } return ans; } int main(){ int N, K; cin >> N >> K; long long ans = 0; for (int i = 1; i <= K; i++){ long long add = modpow(i, N - 1) - modpow(i - 1, N - 1) + MOD; add *= K - i; add %= MOD; add *= N; add %= MOD; add *= i; add %= MOD; ans += add; } for (int i = 1; i <= K; i++){ long long add = modpow(i, N); add += MOD - modpow(i - 1, N); add += (MOD - modpow(i - 1, N - 1)) * N; add %= MOD; add *= i; add %= MOD; ans += add; } ans %= MOD; cout << ans << endl; }