結果

問題 No.1889 K Consecutive Ks (Hard)
ユーザー pockynypockyny
提出日時 2024-09-21 18:44:10
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 157 ms / 6,000 ms
コード長 8,028 bytes
コンパイル時間 3,704 ms
コンパイル使用メモリ 151,356 KB
実行使用メモリ 26,944 KB
最終ジャッジ日時 2024-09-21 18:44:18
合計ジャッジ時間 7,183 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 12 ms
15,232 KB
testcase_01 AC 13 ms
15,232 KB
testcase_02 AC 155 ms
26,944 KB
testcase_03 AC 14 ms
15,360 KB
testcase_04 AC 14 ms
15,232 KB
testcase_05 AC 14 ms
15,232 KB
testcase_06 AC 14 ms
15,232 KB
testcase_07 AC 15 ms
15,232 KB
testcase_08 AC 15 ms
15,232 KB
testcase_09 AC 143 ms
26,748 KB
testcase_10 AC 143 ms
26,712 KB
testcase_11 AC 49 ms
18,160 KB
testcase_12 AC 80 ms
21,060 KB
testcase_13 AC 82 ms
20,772 KB
testcase_14 AC 85 ms
20,976 KB
testcase_15 AC 151 ms
26,388 KB
testcase_16 AC 157 ms
26,896 KB
testcase_17 AC 82 ms
20,776 KB
testcase_18 AC 150 ms
26,808 KB
testcase_19 AC 151 ms
26,772 KB
testcase_20 AC 155 ms
26,940 KB
testcase_21 AC 155 ms
26,944 KB
testcase_22 AC 150 ms
26,816 KB
testcase_23 AC 78 ms
21,288 KB
testcase_24 AC 154 ms
26,816 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

// これはつけてると壊れることがある (yukicoderなど)
// #pragma GCC target("avx2")
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")

#include <iostream>
#include <vector>
#include <cassert>
#include <atcoder/modint>
#include <atcoder/convolution>

using namespace std;
using namespace atcoder;
using mint = modint998244353;
// verify: https://judge.yosupo.jp/submission/202428
// もっと早くできるらしい https://paper.dropbox.com/doc/fps-EoHXQDZxfduAB8wD1PMBW
// https://qiita.com/Suu0313/items/69253873a8397323b376
vector<mint> fps_inv(vector<mint> f,int deg = -1){
    if(deg==-1) deg = f.size() - 1;
    assert(f.size());
    assert(f[0].val()!=0);
    vector<mint> g = {(mint)1/f[0]};
    vector<mint> ff = {f[0]};
    for(int i=1;i<=deg;i<<=1){
        // iでmod x^(2i)を計算する
        for(int j=ff.size();j<min((int)f.size(),2*i);j++){
            ff.push_back(f[j]);
        } 
        vector<mint> h = convolution(g,ff);
        h.resize(2*i);
        for(mint &u:h) u = -u;
        h[0] += 2;
        g = convolution(g,h);
        g.resize(2*i);
    }
    g.resize(deg + 1);
    return g;
}

vector<mint> fps_derivative(vector<mint> &f){
    vector<mint> ret;
    for(int i=1;i<f.size();i++){
        ret.push_back(f[i]*i);
    }
    if(f.size()==1) ret.push_back(0);
    return ret;
}

// get integral f(x) dx, s.t. const = 0;
vector<mint> fps_integral(vector<mint> &f){
    vector<mint> g(f.size());
    for(int i=g.size() - 1;i>=1;i--){
        g[i] = f[i - 1]/(mint)(i);
    }
    g[0] = 0;
    return g;
}

// get log(f) = -sum (1 - f)^n/n 
vector<mint> fps_log(vector<mint> f,int deg = -1){
    if(deg==-1) deg = f.size() - 1;
    assert(f[0].val()==1);
    vector<mint> g = convolution(fps_derivative(f),fps_inv(f,deg));
    g.resize(deg + 1);
    return fps_integral(g);
}

// get exp(f) = sum f^i/i!
vector<mint> fps_exp(vector<mint> f,int deg = -1){
    if(deg==-1) deg = f.size() - 1;
    assert(f.size());
    assert(f[0].val()==0);
    vector<mint> g = {(mint)1};
    for(int i=1;i<=deg;i<<=1){
        vector<mint> h(2*i);
        h[0] = 1;
        vector<mint> log_g = fps_log(g,h.size() - 1);
        for(int j=0;j<h.size();j++) h[j] -= log_g[j];
        for(int j=0;j<min(f.size(),h.size());j++) h[j] += f[j];
        g = convolution(g,h);
        g.resize(2*i);
    }
    g.resize(deg + 1);
    return g;
}

using ll = long long;
const int MX = 1000010;
mint f[MX],inv[MX],fi[MX];
constexpr ll mod = 998244353;
void solve(){
    inv[1] = 1;
    for(int i=2;i<MX;i++){
        inv[i] = mod - (mod/i)*inv[mod%i];
    }
    f[0] = fi[0] = 1;
    for(int i=1;i<MX;i++){
        f[i] = f[i-1]*i;
        fi[i] = fi[i-1]*inv[i];
    }
}

// input f(x) output g(x) = f(x + d)
// 階乗の逆元の計算が必須
vector<mint> shift(vector<mint> ff,int d){
    int i,len = ff.size();
    vector<mint> f1(len + 1),f2(len + 1);
    mint x = 1;
    for(i=0;i<len;i++){
        f1[i] = ff[i]*f[i];
        f2[len - i] = x*fi[i];
        x *= d;
    }
    vector<mint> f3 = convolution(f1,f2);
    vector<mint> ret;
    for(i=len;i<2*len;i++){
        ret.push_back(f3[i]*fi[i - len]);
    }
    return ret;
}

// input: s[0] + s[1]x + ... s[n - 1]x^n - 1
// output: (c[0] - c[1] - .. c[d])s = 0かつc[0] = 1なるcを存在すればどれか一つ
// O(n^2)-time
// copied from: https://nyaannyaan.github.io/library/fps/berlekamp-massey.hpp
// verified: https://judge.yosupo.jp/submission/199848
// algorithmの気持ち
// deg(c) = kの時に s[i] + Σ_{j<k} c[j]*s[i - 1 - j] を計算する
// 0ならOKで、0じゃないときは、1個前のcの線形結合で補正項を入れる
// 今までは成功していてたので、1個前のcの線形結合で補正したら、自分より下はつじつまが合うし、補正を入れるので徐々にあっていく
template <typename mint>
vector<mint> BerlekampMassey(const vector<mint> &s) {
    const int N = (int)s.size();
    vector<mint> b, c;
    b.reserve(N + 1);
    c.reserve(N + 1);
    b.push_back(mint(1));
    c.push_back(mint(1));
    mint y = mint(1);
    for(int ed = 1; ed <= N; ed++){
        int l = int(c.size()), m = int(b.size());
        mint x = 0;
        for (int i = 0; i < l; i++) x += c[i] * s[ed - l + i];
        b.emplace_back(mint(0));
        m++;
        if(x == mint(0)) continue;
        mint freq = x / y;
        if(l < m){
            auto tmp = c;
            c.insert(begin(c), m - l, mint(0));
            for (int i = 0; i < m; i++) c[m - 1 - i] -= freq * b[m - 1 - i];
            b = tmp;
            y = x;
        }else{
            for (int i = 0; i < m; i++) c[l - 1 - i] -= freq * b[m - 1 - i];
        }
    }
    reverse(begin(c), end(c));
    return c;
}

// BostanMori
// input: P(x)/Q(x),deg(P(x))<deg(Q(x)),
// output: [x^N](P(x)/Q(x)) 
// O(dlog(d)logN), (d := deg(Q(x)))
// TODO: MSB-firstにするともっと早いらしい
// verified: https://judge.yosupo.jp/submission/199853
mint BostanMori(vector<mint> P,vector<mint> Q,ll N){
    while(N){
        vector<mint> _Q(Q.size());
        for(int i=0;i<Q.size();i++) _Q[i] = i&1 ? -Q[i] : Q[i];
        vector<mint> nP = convolution(P,_Q);
        vector<mint> nQ = convolution(Q,_Q);
        Q.resize(nQ.size()/2 + 1);
        for(int i=0;i<nQ.size();i+=2) Q[i/2] = nQ[i];
        if(N&1){
            P.resize(nP.size()/2);
            for(int i=1;i<nP.size();i+=2) P[i/2] = nP[i];
        }else{
            P.resize((nP.size() + 1)/2);
            for(int i=0;i<nP.size();i+=2) P[i/2] = nP[i];
        }
        N /= 2;
    }
    return P[0]/Q[0];
}

mint pw(mint a,ll x){
    mint ret = 1;
    while(x){
        if(x&1) (ret *= a);
        (a *= a); x /= 2;
    }
    return ret;
}

// f^k mod x^N をsparceなfに対して計算
// [x^0]f \neq 0 を仮定
// fの非零項がm個で、O(mN) になる
// verified: https://atcoder.jp/contests/nadafes2022_day1/submissions/52160796
vector<mint> sparse_pow(vector<mint> &f,int k,int N){
    assert(f.size());
    assert(f[0]!=0);
    mint iv = pw(f[0],mod - 2);
    vector<pair<int,mint>> ff;
    for(int i=1;i<f.size();i++){
        if(f[i].val()) ff.push_back({i,f[i]});
    }
    vector<mint> ret(N);
    ret[0] = pw(f[0],k);
    for(int i=1;i<N;i++){
        for(int j=0;j<ff.size();j++){
            int deg = ff[j].first;
            if(deg>i) break;
            ret[i] += ff[j].second*deg*ret[i - deg];
        }
        ret[i] *= k;
        for(int j=0;j<ff.size();j++){
            int deg = ff[j].first;
            if(deg>i) break;
            ret[i] -= ff[j].second*(i - deg)*ret[i - deg];
        }
        ret[i] *= inv[i];
    }
    return ret;
}

template <typename = mint>
// d項の漸化式で、最初のd項と漸化式cが与えられたときに、k項目を求める
// cはa_i = c_0a_{i - 1} + c_1a_{i - 2} + ... と与える
// verify: https://judge.yosupo.jp/submission/213985
mint kthTerm(vector<mint> &ini,vector<mint> &c,ll k){
    int d = ini.size();
    assert(c.size()==d);
    vector<mint> _c(d + 1);
    for(int i=0;i<c.size();i++) _c[i + 1] = -c[i];
    _c[0] = 1;
    vector<mint> P = convolution(_c,ini);
    P.resize(d);
    return BostanMori(P,_c,k);
}

template <typename = mint>
// d項の漸化式で、最初の2d項が与えられたときに、k項目を求める (例: Fibonacchiなら前4項)
mint kthTermBySequence(vector<mint> a,ll k){
    vector<mint> c = BerlekampMassey(a);
    assert(c.size());
    // cが分母の形で変えるので微調整
    c.erase(c.begin());
    for(mint &u:c) u *= -1;
    int d = c.size();
    a.resize(d);
    return kthTerm(a,c,k);
}

int main(){
    int i,j,n,m; cin >> n >> m;
    mint ans = pw((mint)m,n);
    vector<mint> f(n + 1);
    for(i=1;i<=m;i++){
        // (x - x^i)(1 + x^i + x^2i + ... ) をfに足していく
        for(j=0;j<=n;j+=i){
            if(j + 1<=n) f[j + 1]++;
            if(j + i<=n) f[j + i]--;
        }
    }
    for(i=1;i<=n;i++) f[i] *= -1;
    f[0] = 1;
    ans -= fps_inv(f)[n];
    cout << ans.val() << "\n";
}
0