結果

問題 No.2303 Frog on Grid
ユーザー Jeroen Op de BeekJeroen Op de Beek
提出日時 2023-05-12 22:40:01
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
TLE  
実行時間 -
コード長 3,970 bytes
コンパイル時間 3,860 ms
コンパイル使用メモリ 229,972 KB
実行使用メモリ 44,548 KB
最終ジャッジ日時 2024-05-06 12:43:08
合計ジャッジ時間 7,345 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 11 ms
12,800 KB
testcase_01 AC 17 ms
7,424 KB
testcase_02 TLE -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma") 
#include "bits/stdc++.h"
using namespace std;
#define all(x) begin(x),end(x)
template<typename A, typename B> ostream& operator<<(ostream &os, const pair<A, B> &p) { return os << '(' << p.first << ", " << p.second << ')'; }
template<typename T_container, typename T = typename enable_if<!is_same<T_container, string>::value, typename T_container::value_type>::type> ostream& operator<<(ostream &os, const T_container &v) { string sep; for (const T &x : v) os << sep << x, sep = " "; return os; }
#define debug(a) cerr << "(" << #a << ": " << a << ")\n";
typedef long long ll;
typedef vector<int> vi;
typedef vector<vi> vvi;
typedef pair<int,int> pi;

const int mxN = 1<<19, oo = 1e9;
const long long MOD =  998244353;
class mint {
public:
    int d;
    mint () {d=0;}
    mint (long long _d) : d(_d%MOD){};
    mint operator*(const mint& o) const {
        return ((ll)d*o.d)%MOD;
    }
    mint operator+(const mint& o) const {
        long long tmp = d+o.d;
        if(tmp>=MOD) tmp-=MOD;
        return tmp;
    }
    mint operator-(const mint& o) const {
        long long tmp = d-o.d;
        if(tmp<0) tmp+=MOD;
        return tmp;
    }
    mint operator^(long long b) const {
        mint tmp = 1;
        mint power = *this;
        while(b) {
            if(b&1) {
                tmp = tmp*power;
            }
            power = power*power;
            b/=2;
        }
        return tmp;
    }
    mint operator/(const mint& o) {
        return *this * (o^(MOD-2));
    }
    bool operator==(const mint& o) {
        return d==o.d;
    }
    friend mint& operator+=(mint& a, const mint& o) {
        a.d+=o.d;
        if(a.d>=MOD) a.d-=MOD;
        return a;
    }
};
typedef mint cd;
void revperm(cd* in, int n) {
    for(int i=0,j=0;i<n;++i) {
        if(i<j) swap(in[i],in[j]);
		for(int k = n >> 1; (j ^= k) < k; k >>= 1);
    }
}
cd w[mxN+1]; // stores w^j for each j in [0,n-1]
void precomp() {
    w[0] = 1;
    int pw = (MOD-1)/mxN;
    w[1] = mint(3)^pw;
    for(int i= 2;i<=mxN;++i) {
        w[i] = w[i-1]*w[1];
    }
}
void fft(cd* in, int n, bool reverse=false) {
    int lg = __lg(n);
    assert(1<<lg == n);
    int stride = mxN/n;
    revperm(in,n);
    for(int s=1;s<=lg;++s) {
        int pw = 1<<s;
        int mstride = stride*(n>>s);
        for(int j=0;j<n;j+=pw) {
            // do FFT merging on out array
            cd* even = in+j, *odd = in+j+pw/2;
            for(int i=0;i<pw/2;++i) {
                cd& power = w[reverse?mxN-mstride*i:mstride*i];
                auto tmp = power*odd[i];
                odd[i] = even[i] - tmp;
                even[i] = even[i] + tmp;
            }
        }
    }
    if(reverse) {
        mint fac = mint(1)/n;
        for(int i=0;i<n;++i) in[i]=in[i]*fac;
    }
}
int total;

vector<cd> polymul(vector<cd> a, vector<cd> b) {
    int n = a.size(), m = b.size(), ptwo = 1;
    while(ptwo<(n+m)) ptwo*=2;
    a.resize(ptwo), b.resize(ptwo);
    fft(a.data(),ptwo); 
    fft(b.data(),ptwo);
    for(int i=0;i<ptwo;++i) 
        a[i] = a[i]*b[i];
    fft(a.data(),ptwo,true);
    a.resize(n+m-1);
    return a;
}
map<int,vector<mint>> dp;
int m;
vector<mint> fib;
vector<mint>& solve(int n) {
    if(n==0) {
        return fib;
    }
    if(dp.count(n)) return dp[n];
    int split = n/2;
    if(n%2==0) split=0;
    // jump of length 1
    auto& ans = dp[n];
    {
    auto& res = solve(split), &res2 = solve(n-split-1);
    ans = polymul(res,res2);
    }
    ans.resize(m+1);
    
    if(n-split-1>=1) {
        auto& res = solve(split), &res2 = solve(n-split-2);
        auto mul = polymul(res,res2);
        for(int i=0;i<=m;++i) ans[i]+=mul[i]*(1 + (n%2==1));
    }

    return ans;
}
int main() {
    precomp();
    fib = {1,1};
    fib.resize(mxN);
    for(int i=2;i<mxN;++i) fib[i]=fib[i-1]+fib[i-2];

    int n; cin >> n >> m;
    fib.resize(m+1);
    auto res = solve(n);
    
    cout << res[m].d << '\n';
    
}
0