結果

問題 No.2747 Permutation Adjacent Sum
ユーザー Naru820
提出日時 2025-02-11 18:19:54
言語 C++14
(gcc 13.3.0 + boost 1.87.0)
結果
TLE  
実行時間 -
コード長 4,000 bytes
コンパイル時間 1,833 ms
コンパイル使用メモリ 165,072 KB
実行使用メモリ 47,488 KB
最終ジャッジ日時 2025-02-11 18:21:17
合計ジャッジ時間 78,108 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 29 TLE * 11
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
 
// --- Constants and basic modular functions ---
 
const long long MOD = 998244353;

// Square-and-multiply exponentiation.
// Returns base^exp reduced modulo `mod` (defaults to MOD). The exponent is
// consumed bit by bit from the least significant end; a non-positive
// exponent leaves the accumulator at its initial value 1 % mod.
long long modexp(long long base, long long exp, long long mod = MOD) {
    long long acc = 1 % mod;      // 1 % mod so that mod == 1 yields 0
    long long square = base % mod;
    for (long long e = exp; e > 0; e >>= 1) {
        if (e & 1) {
            acc = acc * square % mod;
        }
        square = square * square % mod;
    }
    return acc;
}
 
// Multiplicative inverse of x modulo `mod`, computed as x^(mod-2).
// This relies on Fermat's little theorem, so `mod` must be prime and x must
// not be a multiple of `mod`.
long long modinv(long long x, long long mod = MOD) {
    const long long fermatExponent = mod - 2;
    return modexp(x, fermatExponent, mod);
}
 
// --- Power sum via Lagrange interpolation ---
//
// Returns F(p, n) = sum_{d=1}^{n} d^p  (mod MOD).
//
// F(p, x) is a polynomial in x of degree p+1, so it is determined by its
// values at the p+2 points x = 0, 1, ..., p+1.
//   * p == 0          : the sum is simply n.
//   * n <= p+1        : the sum is short enough to add up directly.
//   * otherwise       : sample F at 0..p+1 and evaluate at n with the
//                       Lagrange formula for consecutive integer nodes.
//
// The Lagrange denominators are j! * (m-j)! with sign (-1)^(m-j); they are
// taken from an inverse-factorial table, so only ONE modular inverse is
// computed in total instead of one per term.
//
// Assumes p >= 0 and p + 1 < MOD (so that (p+1)! is invertible mod MOD).
long long powerSum(long long p, long long n) {
    if (p == 0) return n % MOD; // sum_{d=1}^{n} 1 = n
    if (n <= p + 1) {
        long long s = 0;
        for (long long d = 1; d <= n; d++) {
            s = (s + modexp(d, p)) % MOD;
        }
        return s;
    }

    // Interpolation nodes are 0, 1, ..., m.
    const int m = static_cast<int>(p + 1);

    // y[i] = F(p, i): prefix sums of i^p.
    std::vector<long long> y(m + 1, 0);
    for (int i = 1; i <= m; i++) {
        y[i] = (y[i - 1] + modexp(i, p)) % MOD;
    }

    // invfact[i] = (i!)^{-1}: invert m! once, then walk downwards using
    // (i-1)!^{-1} = i!^{-1} * i.
    long long mFactorial = 1;
    for (int i = 1; i <= m; i++) {
        mFactorial = mFactorial * i % MOD;
    }
    std::vector<long long> invfact(m + 1, 0);
    invfact[m] = modinv(mFactorial);
    for (int i = m; i >= 1; i--) {
        invfact[i - 1] = invfact[i] * i % MOD;
    }

    // pre[i] = prod_{j=0}^{i-1} (n - j)   (pre[0] = 1)
    // suf[i] = prod_{j=i}^{m}   (n - j)   (suf[m+1] = 1)
    // so pre[j] * suf[j+1] is the product over all nodes except j.
    // n > m in this branch, hence every factor (n - j) is positive.
    std::vector<long long> pre(m + 2, 0), suf(m + 2, 0);
    pre[0] = 1;
    for (int i = 0; i <= m; i++) {
        pre[i + 1] = pre[i] * ((n - i) % MOD) % MOD;
    }
    suf[m + 1] = 1;
    for (int i = m; i >= 0; i--) {
        suf[i] = suf[i + 1] * ((n - i) % MOD) % MOD;
    }

    // F(p, n) = sum_j y[j] * pre[j] * suf[j+1] * (-1)^(m-j) / (j! (m-j)!)
    long long res = 0;
    for (int j = 0; j <= m; j++) {
        long long term = y[j] * pre[j] % MOD;
        term = term * suf[j + 1] % MOD;
        term = term * invfact[j] % MOD;
        term = term * invfact[m - j] % MOD;
        if ((m - j) & 1) {
            term = (MOD - term) % MOD; // apply the (-1)^(m-j) sign
        }
        res = (res + term) % MOD;
    }
    return res;
}
 
// --- Main ---
//
// We are given integers N and K. For a permutation P of {1,2,…,N} we defined
//   f(P) = ∑_{i=1}^{N-1} |P_i - P_{i+1}|^K.
// We showed that
//   S = ∑_{P} f(P) = 2*(N-1)! * [ N*(∑_{d=1}^{N-1} d^K) - (∑_{d=1}^{N-1} d^(K+1)) ]  mod MOD.
// Also, note that if (N-1)! has a factor of MOD then it is 0 mod MOD – that is,
// if N-1 ≥ MOD (i.e. if N ≥ MOD+1) then S ≡ 0.
 
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
 
    long long N, K;
    cin >> N >> K;
    // If (N-1)! contains MOD then S ≡ 0.
    if(N >= MOD + 1) {
        cout << 0 << "\n";
        return 0;
    }
    // Compute A = (N-1)! mod MOD.
    long long A = 1;
    // (If N is very small – e.g. up to about 10^6 – a simple loop is fast.)
    for (int i = 1; i < N; i++){
        A = (A * i) % MOD;
    }
    // We now need S1 = ∑_{d=1}^{N-1} d^K and S2 = ∑_{d=1}^{N-1} d^(K+1) modulo MOD.
    long long nVal = N - 1;
    long long S1 = powerSum(K, nVal);
    long long S2 = powerSum(K + 1, nVal);
    // Our inner bracket is: N*S1 - S2  (computed mod MOD)
    long long inner = ((N % MOD) * S1) % MOD;
    inner = (inner - S2) % MOD;
    if(inner < 0) inner += MOD;
    long long ans = (2 * A) % MOD;
    ans = (ans * inner) % MOD;
    cout << ans % MOD << "\n";
    return 0;
}
0