#include using namespace std; #include typedef atcoder::modint998244353 mint; int main(){ int n, k; cin >> n >> k; mint ans = n; mint ninv = mint(n).inv(); vector dp(k+1, vector(n+1)); dp[0][0] = 1; for (int i=0; i 0){ dp[i+1][j-1] += mint(j) * ninv * dp[i][j]; ans += mint(j) * ninv * dp[i][j]; } // select other if (j < n){ dp[i+1][j+1] += mint(n-j) * ninv * dp[i][j]; } } } cout << ans.val() << '\n'; }