#include <iostream>
#include <vector>

// --- Constants and basic modular functions ---

// Prime modulus used for every result in this file.
const long long MOD = 998244353;

// Fast modular exponentiation by repeated squaring.
//
// Returns base^exp reduced into [0, mod). `exp` is expected to be
// non-negative (a negative exponent yields 1 % mod because the loop never
// runs). `mod` must be small enough that (mod-1)^2 fits in a long long.
long long modexp(long long base, long long exp, long long mod = MOD) {
    long long result = 1 % mod;  // "1 % mod" so that mod == 1 yields 0.
    base %= mod;
    if (base < 0) base += mod;   // keep the residue non-negative for negative bases
    while (exp > 0) {
        if (exp & 1) result = (result * base) % mod;
        base = (base * base) % mod;
        exp >>= 1;
    }
    return result;
}

// Modular inverse via Fermat's little theorem: x^(mod-2) mod `mod`.
// Only valid when `mod` is prime and x is not a multiple of `mod`.
long long modinv(long long x, long long mod = MOD) {
    return modexp(x, mod - 2, mod);
}

// --- Power sum via Lagrange interpolation ---
//
// Computes F(p, n) = sum_{d=1}^{n} d^p modulo MOD.
//
// F(p, x) is a polynomial in x of degree p+1, so it is determined by its
// values at the p+2 points x = 0, 1, ..., p+1. When n <= p+1 the sum is
// evaluated directly; otherwise it is interpolated from those sample points in
// O(p log MOD) time. The interpolation needs the factorials up to (p+1)! to be
// invertible, i.e. it requires p + 1 < MOD.
long long powerSum(long long p, long long n) {
    if (p == 0) return n % MOD;  // sum_{d=1}^{n} 1 = n

    if (n <= p + 1) {
        long long sum = 0;
        for (long long d = 1; d <= n; ++d) {
            sum = (sum + modexp(d, p)) % MOD;
        }
        return sum;
    }

    // Sample points are 0, 1, ..., m.
    const int m = static_cast<int>(p + 1);

    // y[i] = F(p, i), built as a running prefix sum of i^p.
    std::vector<long long> y(m + 1, 0);
    for (int i = 1; i <= m; ++i) {
        y[i] = (y[i - 1] + modexp(i, p)) % MOD;
    }

    // Factorials and inverse factorials up to m.
    std::vector<long long> fact(m + 1, 0), invfact(m + 1, 0);
    fact[0] = 1;
    for (int i = 1; i <= m; ++i) {
        fact[i] = (fact[i - 1] * i) % MOD;
    }
    invfact[m] = modinv(fact[m]);
    for (int i = m; i >= 1; --i) {
        invfact[i - 1] = (invfact[i] * i) % MOD;
    }

    // pre[i] = prod_{j=0}^{i-1} (n - j)   for i = 0..m+1, with pre[0]   = 1.
    // suf[i] = prod_{j=i}^{m}   (n - j)   for i = 0..m+1, with suf[m+1] = 1.
    // Here n > m, so every factor (n - j) is positive before reduction.
    std::vector<long long> pre(m + 2, 0), suf(m + 2, 0);
    pre[0] = 1;
    for (int i = 0; i <= m; ++i) {
        pre[i + 1] = (pre[i] * ((n - i) % MOD)) % MOD;
    }
    suf[m + 1] = 1;
    for (int i = m; i >= 0; --i) {
        suf[i] = (suf[i + 1] * ((n - i) % MOD)) % MOD;
    }

    // Lagrange interpolation. For j = 0..m the basis polynomial at n is
    //   L_j(n) = pre[j] * suf[j+1] * (-1)^(m-j) / (j! * (m-j)!)
    // and F(p, n) = sum_j y[j] * L_j(n). The division uses the precomputed
    // inverse factorials instead of one modular exponentiation per term.
    long long result = 0;
    for (int j = 0; j <= m; ++j) {
        long long term = y[j];
        term = (term * pre[j]) % MOD;
        term = (term * suf[j + 1]) % MOD;
        term = (term * invfact[j]) % MOD;
        term = (term * invfact[m - j]) % MOD;
        if ((m - j) & 1) {
            term = (MOD - term) % MOD;  // apply the (-1)^(m-j) sign
        }
        result = (result + term) % MOD;
    }
    return result;
}

// --- Closed form for the sum over all permutations ---
//
// For a permutation P of {1, 2, ..., N} define
//   f(P) = sum_{i=1}^{N-1} |P_i - P_{i+1}|^K.
// Each ordered pair of distinct values occupies each of the N-1 adjacent
// position pairs in (N-2)! permutations, and exactly 2*(N-d) ordered pairs
// have absolute difference d, hence
//   S = sum_P f(P)
//     = 2 * (N-1)! * [ N * sum_{d=1}^{N-1} d^K  -  sum_{d=1}^{N-1} d^(K+1) ]  (mod MOD).
// Returns S mod MOD. The factorial is computed with a plain O(N) loop, which
// is fast as long as N is moderate (around 10^6).
long long sumOverPermutations(long long N, long long K) {
    // If N - 1 >= MOD then (N-1)! contains the factor MOD, so S is 0.
    if (N >= MOD + 1) return 0;

    // factorial = (N-1)! mod MOD.
    long long factorial = 1;
    for (long long i = 1; i < N; ++i) {
        factorial = (factorial * i) % MOD;
    }

    const long long upper = N - 1;
    const long long s1 = powerSum(K, upper);      // sum_{d=1}^{N-1} d^K
    const long long s2 = powerSum(K + 1, upper);  // sum_{d=1}^{N-1} d^(K+1)

    // inner = N*s1 - s2, normalised into [0, MOD).
    long long inner = ((N % MOD) * s1) % MOD;
    inner = (inner - s2) % MOD;
    if (inner < 0) inner += MOD;

    long long answer = (2 * factorial) % MOD;
    answer = (answer * inner) % MOD;
    return answer;
}

// --- Main ---
//
// Reads N and K from standard input and prints the sum of f(P) over all
// permutations P of {1, ..., N}, modulo MOD.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    long long N, K;
    if (!(std::cin >> N >> K)) {
        return 0;  // malformed or missing input: nothing to compute
    }

    std::cout << sumOverPermutations(N, K) << "\n";
    return 0;
}