#include #include #define rep(i,n) for(int i=0;i vi; typedef vector vl; typedef vector> vvi; typedef vector> vvl; typedef long double ld; typedef pair P; ostream& operator<<(ostream& os, const modint& a) {os << a.val(); return os;} template ostream& operator<<(ostream& os, const static_modint& a) {os << a.val(); return os;} template istream& operator>>(istream& is, vector& v){int n = v.size(); assert(n > 0); rep(i, n) is >> v[i]; return is;} template ostream& operator<<(ostream& os, const pair& p){os << p.first << ' ' << p.second << "\n"; return os;} template ostream& operator<<(ostream& os, const vector& v){int n = v.size(); rep(i, n) os << v[i] << (i == n - 1 ? "\n" : " "); return os;} template ostream& operator<<(ostream& os, const vector>& v){int n = v.size(); rep(i, n) os << v[i] << (i == n - 1 ? "\n" : ""); return os;} using mint = modint998244353; int main(){ int n, k; cin >> n >> k; vector dp(n + 1); dp[0] = 1; mint ans = 0; mint rev = mint(1) / n; rep(_, k){ vector dp_old(n + 1); swap(dp, dp_old); rep(i, n + 1){ if(i > 0){ mint plus = dp_old[i] * i * rev; dp[i - 1] += plus; ans += plus; } if(i < n) dp[i + 1] += dp_old[i] * (n - i) * rev; } } cout << ans + n; return 0; }