結果
| 問題 | No.2883 K-powered Sum of Fibonacci |
| ユーザー |
|
| 提出日時 | 2024-09-24 23:25:05 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 5 ms / 3,000 ms |
| コード長 | 3,231 bytes |
| コンパイル時間 | 995 ms |
| コンパイル使用メモリ | 78,152 KB |
| 最終ジャッジ日時 | 2025-02-24 12:07:53 |
|
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 40 |
ソースコード
#include <iostream>
#include <vector>
// Short names for the integer types used throughout this file.
using i32 = int;
using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;

/// Residue class modulo the compile-time constant MOD.
///
/// The stored value `n` is kept in the canonical range [0, MOD). The
/// constructor reduces its argument into that range and every operator
/// preserves it, assuming both operands are canonical.
template <i32 MOD>
struct Mint {
    static_assert(MOD > 0, "Mint requires a positive modulus");

    i32 n;

    /// Builds the residue of `v`; negative or oversized values are reduced into [0, MOD).
    constexpr Mint(i32 v = 0): n(v % MOD) { if(n < 0) n += MOD; }

    /// Additive inverse (0 maps to 0).
    constexpr Mint operator-() const { return Mint(n ? MOD - n : 0); }

    constexpr Mint &operator+=(const Mint &rhs){
        // Subtract (MOD - rhs) instead of adding rhs so the intermediate
        // value never exceeds the range of i32, whatever MOD is.
        n -= MOD - rhs.n;
        if(n < 0) n += MOD;
        return *this;
    }
    constexpr Mint &operator-=(const Mint &rhs){
        n -= rhs.n;
        if(n < 0) n += MOD;
        return *this;
    }
    constexpr Mint &operator*=(const Mint &rhs){
        // Widen to 64 bits: the product of two residues may not fit in i32.
        n = static_cast<i32>(static_cast<i64>(n) * rhs.n % MOD);
        return *this;
    }

    /// Multiplicative inverse computed with the extended Euclidean algorithm.
    ///
    /// The result is meaningful only when gcd(n, MOD) == 1. For n == 0 the
    /// loop body never runs and 0 is returned.
    constexpr Mint inv() const {
        i32 x = MOD;
        i32 y = n;
        i32 b = 0, d = 1;
        while(y){
            const i32 q = x / y;
            const i32 rem = x % y;
            const i32 next = b - q * d;
            // Rotate by hand rather than via std::swap, which is not
            // constexpr before C++20 and would need <utility>.
            x = y; y = rem;
            b = d; d = next;
        }
        if(b < 0) b += MOD;
        return b;
    }

    /// Division is multiplication by rhs.inv(); see inv() for its caveats.
    constexpr Mint &operator/=(const Mint &rhs){
        n = static_cast<i32>(static_cast<i64>(n) * rhs.inv().n % MOD);
        return *this;
    }

    friend constexpr Mint operator+(const Mint &lhs, const Mint &rhs){ Mint r = lhs; r += rhs; return r; }
    friend constexpr Mint operator-(const Mint &lhs, const Mint &rhs){ Mint r = lhs; r -= rhs; return r; }
    friend constexpr Mint operator*(const Mint &lhs, const Mint &rhs){ Mint r = lhs; r *= rhs; return r; }
    friend constexpr Mint operator/(const Mint &lhs, const Mint &rhs){ Mint r = lhs; r /= rhs; return r; }
    friend constexpr bool operator==(const Mint &lhs, const Mint &rhs){ return lhs.n == rhs.n; }
    friend constexpr bool operator!=(const Mint &lhs, const Mint &rhs){ return lhs.n != rhs.n; }

    /// Prints the canonical residue as a plain integer.
    friend std::ostream &operator<<(std::ostream &os, const Mint &rhs){ return os << rhs.n; }
};
/// Raises `x` to the power `n` by repeated squaring.
///
/// T must be constructible from 1 and support `*=`. For n >= 0 the result is
/// x^n (with x^0 == T(1)). A negative `n` is read as its unsigned bit
/// pattern, so the loop always terminates instead of spinning forever on a
/// sign-extending right shift.
template <class T>
T modpow(T x, int n){
    T r(1);
    for(unsigned e = static_cast<unsigned>(n); e != 0; e >>= 1){
        if(e & 1u) r *= x;
        x *= x;
    }
    return r;
}
// Modulus used for every computation in this program (998244353 is prime).
constexpr u32 mod = 998244353;
// Integers modulo `mod`.
using mint = Mint<mod>;
// Polynomial stored as a coefficient vector; index i holds the coefficient of x^i.
using poly = std::vector<mint>;
/// Schoolbook product of two coefficient vectors.
///
/// Returns a vector of lhs.size() + rhs.size() - 1 coefficients, or an empty
/// vector when either operand is empty (without that guard the size
/// computation would wrap around for two empty inputs).
/// Runs in O(lhs.size() * rhs.size()) multiplications.
poly mult(const poly &lhs, const poly &rhs){
    if(lhs.empty() || rhs.empty()) return {};
    poly r(lhs.size() + rhs.size() - 1);
    for(poly::size_type i = 0; i < lhs.size(); i++){
        for(poly::size_type j = 0; j < rhs.size(); j++){
            r[i+j] += lhs[i] * rhs[j];
        }
    }
    return r;
}
void printv(const poly a){
for(u32 i = 0; i < a.size(); i++) std::cout << a[i] << ' ';
std::cout << std::endl;
}
/// Builds the denominator polynomial for parameters (k, e1, e2).
///
/// First tabulates the sequence s with s[0] = 0, s[1] = 1 and
/// s[i] = e1 * s[i-1] - e2 * s[i-2] for i < k+2. It then fills k+2
/// coefficients with
///     c[0] = 1,  c[i] = -e2^(i-1) * c[i-1] * s[k+2-i] / s[i],
/// and finally multiplies the result by (1 - x), giving k+3 coefficients.
/// NOTE(review): the division assumes s[i] is non-zero modulo `mod` for
/// 1 <= i <= k+1; a zero term would silently yield 0 from inv().
poly denom(u32 k, mint e1, mint e2){
    const u32 len = k + 2;

    std::vector<mint> seq(len);
    seq[0] = 0;
    seq[1] = 1;
    for(u32 i = 2; i < len; i++){
        seq[i] = e1 * seq[i-1] - e2 * seq[i-2];
    }

    poly coef(len);
    coef[0] = 1;
    mint e2pow = 1;  // holds e2^(i-1) at the top of iteration i
    for(u32 i = 1; i < len; i++){
        coef[i] = -e2pow * coef[i-1] * seq[len-i] / seq[i];
        e2pow *= e2;
    }

    // Extra factor (1 - x); mod-1 represents -1.
    return mult(coef, { 1, mod-1 });
}
/// Reads `n` and `k` from standard input and prints one value modulo `mod`:
/// the coefficient of x^n in num(x) / den(x), where the first k+2 series
/// coefficients are the prefix sums of fib[i]^k.
/// Returns 1 without printing anything if the input cannot be parsed.
int main(){
    u64 n;
    u32 k;
    // Stop on missing or malformed input rather than continuing with an
    // uninitialized k.
    if(!(std::cin >> n >> k)) return 1;

    // Fibonacci numbers fib[0..k+1] modulo `mod`.
    std::vector<mint> fib(k+2);
    fib[0] = 0;
    fib[1] = 1;
    for(u32 i = 2; i < fib.size(); i++) fib[i] = fib[i-1] + fib[i-2];

    // Initial denominator: e1 = 1, e2 = -1 reproduces the Fibonacci recurrence
    // inside denom().
    poly den = denom(k, 1, mod-1);

    // num[i] = fib[0]^k + ... + fib[i]^k for i < k+2; multiplying by den and
    // truncating to k+2 terms turns these series terms into a numerator.
    poly num(k+2);
    for(u32 i = 1; i < num.size(); i++) num[i] = num[i-1] + modpow(fib[i], k);
    num = mult(num, den);
    num.resize(k+2);

    mint e1 = 1;
    mint e2 = mod-1;
    // Consume n one bit at a time, least significant bit first.
    for(; n; n >>= 1){
        // dem(x) = den(-x): negate the odd-index coefficients.
        poly dem = den;
        for(u32 i = 1; i < dem.size(); i += 2) dem[i] = -dem[i];
        num = mult(num, dem);
        // Keep only the coefficients whose index parity matches the current
        // low bit of n, compacted to the front, then drop the rest.
        for(u32 i = n&1; i < num.size(); i += 2) num[i/2] = num[i];
        num.resize(num.size()/2);
        // Advance the recurrence parameters and rebuild the denominator
        // for the halved index.
        e1 = e1 * e1 - 2*e2;
        e2 = e2 * e2;
        den = denom(k, e1, e2);
    }
    // Standard output is flushed at normal program exit, so no explicit flush
    // is needed here.
    std::cout << num[0] << '\n';
    return 0;
}