結果
| 問題 | No.665 Bernoulli Bernoulli |
| コンテスト | |
| ユーザー | tonegawa |
| 提出日時 | 2020-10-10 11:07:36 |
| 言語 | C++17(gcc12) (gcc 12.3.0 + boost 1.89.0) |
| 結果 | AC |
| 実行時間 | 4 ms / 2,000 ms |
| コード長 | 4,539 bytes |
| 記録 | |
| コンパイル時間 | 4,061 ms |
| コンパイル使用メモリ | 152,072 KB |
| 最終ジャッジ日時 | 2025-01-15 06:15:44 |
| ジャッジサーバーID (参考情報) | judge4 / judge3 (要ログイン) |
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 15 |
ソースコード
#include <iostream>
#include <string>
#include <vector>
#include <array>
#include <queue>
#include <deque>
#include <algorithm>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <functional>
#include <cassert>
#include <iomanip>
// Shorthand container types built on ll (defined below; macros expand lazily,
// so the typedef only has to exist at the point of use).
#define vll vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vvl>
// VV(a, b, c, d): an a-by-b grid of element type d, every cell initialised to c.
#define VV(a, b, c, d) vector<vector<d>>(a, vector<d>(b, c))
// VVV(a, b, c, d): an a-by-b-by-c cube of ll, every cell initialised to d.
// No trailing semicolon, so the macro can be used inside a larger expression.
#define VVV(a, b, c, d) vector<vvl>(a, vvl(b, vll(c, d)))
// re(c, b): loop c over 0 .. b-1. The bound is parenthesised so that compound
// expressions (e.g. a conditional) are compared as a whole.
#define re(c, b) for(ll c=0;c<(b);c++)
// all(obj): begin/end iterator pair for algorithms taking a range.
#define all(obj) (obj).begin(), (obj).end()
typedef long long int ll;
typedef long double ld;
using namespace std;
// Modular exponentiation by repeated squaring: returns a^b mod p in [0, p).
//
//  a : base; any signed 64-bit value (it is reduced into [0, p) first, so
//      large or negative bases no longer overflow or yield negative results)
//  b : exponent; values <= 0 give the empty product, i.e. 1 mod p
//  p : modulus, p >= 1. Intermediate products are formed in signed 64-bit,
//      so p must satisfy (p-1)^2 <= LLONG_MAX (roughly p <= 3.03e9).
//
// `long long` is the same type as the file-wide `ll` typedef.
long long mpow(long long a, long long b, long long p){
    long long base = a % p;
    if(base < 0) base += p;          // normalise negative bases
    long long result = 1 % p;        // 0 when p == 1
    while(b > 0){
        if(b & 1) result = result * base % p;
        base = base * base % p;
        b >>= 1;
    }
    return result;
}
// Montgomery-form modular arithmetic with R = 2^32 for one fixed modulus.
// The arithmetic below relies on MOD being odd and small enough that
// x + q * MOD in reduce() fits in 64 bits (MOD below about 2^31).
struct MontgomeryReduction{
    uint64_t MOD;      // the modulus
    uint64_t NEG_INV;  // for odd MOD: MOD * NEG_INV == -1 (mod 2^64); reduce() uses the low 32 bits
    uint64_t R2;       // R^2 mod MOD, used to convert into Montgomery form

    MontgomeryReduction(uint64_t MOD_): MOD(MOD_), NEG_INV(0), R2(0){
        // Build NEG_INV one bit at a time: whenever the running value is even,
        // add MOD (making it odd) and record the current bit; then halve.
        uint64_t running = 0;
        for(int bit = 0; bit < 64; ++bit){
            if((running & 1) == 0){
                running += MOD;
                NEG_INV += (uint64_t)1 << bit;
            }
            running >>= 1;
        }
        const uint64_t rModP = ((uint64_t)1 << 32) % MOD;
        R2 = rModP * rModP % MOD;
    }

    // Converts x (expected < MOD) into Montgomery form: x * R mod MOD.
    inline uint64_t generate(uint64_t x) const{
        return reduce(x * R2);
    }

    // Montgomery reduction: x / R mod MOD, for x expected < MOD * R.
    inline uint64_t reduce(uint64_t x) const{
        // q is chosen so that x + q * MOD is divisible by 2^32 (32-bit wrap-around product).
        const uint32_t q = (uint32_t)x * (uint32_t)NEG_INV;
        const uint64_t t = (x + q * MOD) >> 32;
        if(t >= MOD) return t - MOD;
        return t;
    }
};
// Lagrange interpolation through the equally spaced nodes 0, 1, ..., n-1
// (n = y.size()), evaluated at p, modulo MOD.
//
//  y   : sample values, y[i] is the value at node i; each is expected < MOD
//        (plain, not Montgomery form)
//  p   : evaluation point, expected < MOD
//  MOD : modulus; the Fermat fallback below (exponent MOD-2) presumes a prime
//  inv : table of modular inverses in Montgomery form, inv[k] = 1/k.
//        Entries 2 .. n-1 are required. inv[n] is used only if present;
//        otherwise that inverse is computed on the fly.
//        NOTE(review): the table must have been built for the same MOD.
//
// Returns y[p] directly when p is one of the nodes, and 0 for an empty y.
//
// Fix over the previous version: it read inv[n] both while building the
// inverse-factorial table and when p - i == n, which is out of bounds when the
// caller supplies exactly n table entries.
ll inner_adjacentInterpolation(const vector<uint64_t> &y, uint64_t p, uint64_t MOD, const vector<uint64_t> &inv){
    const int n = y.size();
    if(n == 0) return 0;
    if(p < (uint64_t)n) return y[p];

    MontgomeryReduction mr(MOD);
    const uint64_t one = mr.generate(1);

    // finv[i] = 1 / i! in Montgomery form, for i = 0 .. n-1 only.
    vector<uint64_t> finv(n);
    finv[0] = one;
    if(n > 1) finv[1] = one;
    for(int i=2;i<n;i++) finv[i] = mr.reduce(finv[i-1]*inv[i]);

    // M = prod_{i<n} (p - i), converted back to plain form at the end.
    uint64_t M = one;
    for(int i=0;i<n;i++) M = mr.reduce(M * mr.generate(p - i));
    M = mr.reduce(M);

    const uint64_t minus = mr.generate(MOD-1);
    uint64_t ret = 0;
    for(int i=0;i<n;i++){
        // iQ = (-1)^(n-1-i) / (i! * (n-1-i)!)
        uint64_t iQ = mr.reduce(finv[i] * finv[n-1-i]);
        if((n-i-1)&1) iQ = mr.reduce(iQ * minus);
        const uint64_t iC = mr.reduce(mr.generate(y[i]) * iQ);

        // 1 / (p - i): table lookup when the entry exists, Fermat otherwise.
        const uint64_t diff = p - i;
        uint64_t invDiff;
        if(diff <= (uint64_t)n && diff < inv.size()){
            invDiff = inv[diff];
        }else{
            uint64_t cnt = MOD-2, mlt = mr.generate(diff);
            invDiff = one;
            while(cnt){
                if(cnt&1) invDiff = mr.reduce(invDiff * mlt);
                mlt = mr.reduce(mlt * mlt);
                cnt >>= 1;
            }
        }
        // The double reduce leaves the term in plain form; ret stays below n * MOD.
        ret += mr.reduce(mr.reduce(iC * invDiff));
    }
    return ((ret%MOD) * M)%MOD;
}
// Returns sum_{i=0}^{n-1} r^i * i^d (mod MOD), with 0^0 counted as 1.
// Returns 0 when n <= 0 (empty sum).
//
//  r   : ratio of the geometric factor; any sign, reduced into [0, MOD)
//  d   : exponent of the polynomial factor, d >= 0
//  n   : number of terms
//  MOD : modulus. NOTE(review): the Fermat inverses (exponent MOD-2) and the
//        Montgomery helper presume an odd prime with d + 2 < MOD.
//
// Outline of what the code does:
//   1. tabulate the prefix sums f(i) = sum_{j<=i} r^j * j^d for i = 0 .. d+1;
//   2. if r == 1, interpolate those d+2 samples at n-1;
//   3. otherwise compute a constant c, form g(i) = (f(i) - c) / r^i for
//      i = 0 .. d, interpolate g at n-1 and return r^(n-1) * g(n-1) + c.
//
// Fix over the previous version: the inverse table had d+2 entries (indices
// 0 .. d+1), but interpolating d+2 samples indexes entry d+2. It now holds
// entries up to d+2.
ll exPowerSum(ll r, ll d, ll n, ll MOD){
    if(n <= 0) return 0;
    n--;                                    // n is now the last index of the sum
    const ll pts = d + 2;                   // number of tabulated samples
    vector<uint64_t> y(pts), z(pts);
    vector<uint64_t> tbl(pts + 1);          // one extra slot: later reused for inverses of 1 .. d+2
    MontgomeryReduction m(MOD);

    // d-th power table: afterwards z[i] = i^d (Montgomery form).
    for(ll i=0;i<pts;i++){
        tbl[i] = m.generate(i);
        z[i] = m.generate(1);
    }
    for(ll cnt=d; cnt; cnt>>=1){
        for(ll i=0;i<pts;i++){
            if(cnt&1) z[i] = m.reduce(z[i] * tbl[i]);
            tbl[i] = m.reduce(tbl[i] * tbl[i]);
        }
    }

    r = (r%MOD + MOD)%MOD;
    const ll last = mpow(r, n%(MOD-1), MOD);   // r^n, exponent reduced modulo MOD-1
    n %= MOD;

    // y[i] = f(i) = sum_{j<=i} r^j * j^d, in plain form.
    uint64_t acc = 0, rs = m.generate(1);
    const uint64_t R = m.generate(r);
    for(ll i=0;i<pts;i++){
        acc += m.reduce(m.reduce(rs * z[i]));
        if(acc >= (uint64_t)MOD) acc -= MOD;
        y[i] = acc;
        rs = m.reduce(rs * R);
    }

    // Reuse z and tbl: z[i] = (-r)^i, tbl[i] = 1/i (both in Montgomery form).
    z[0] = m.generate(1);
    z[1] = m.generate((MOD - r)%MOD);
    tbl[0] = m.generate(0);
    tbl[1] = m.generate(1);
    for(ll i=2;i<=pts;i++){
        // 1/i = -(MOD / i) * 1/(MOD % i)
        tbl[i] = m.reduce(m.generate(MOD - (MOD/i))*tbl[MOD%i]);
        if(i < pts) z[i] = m.reduce(z[i-1] * z[1]);
    }
    if(r==1) return inner_adjacentInterpolation(y, n, MOD, tbl);

    // c = ( sum_{i=0}^{d} C(d+1, i+1) * (-r)^(d-i) * f(i) ) / (1 - r)^(d+1)
    uint64_t comb = m.generate(1), c = 0;
    for(ll i=0;i<d+1;i++){
        comb = m.reduce(comb * m.reduce(m.generate(d+1-i)*tbl[i+1]));   // C(d+1, i+1)
        const uint64_t term = m.reduce(comb * m.reduce(z[d-i] * m.generate(y[i])));
        c = (c + m.reduce(term));
    }
    c %= MOD;
    ll di = (MOD+1-r);                      // (1 - r) mod MOD
    if(di >= MOD) di -= MOD;
    c = (c * mpow(mpow(di, d+1, MOD), MOD-2, MOD))%MOD;

    // y[i] <- (f(i) - c) / r^i for i = 0 .. d.
    uint64_t powerRinv = m.generate(1);
    const uint64_t rinv = m.generate(mpow(r, MOD-2, MOD));
    for(ll i=0;i<d+1;i++){
        y[i] = (y[i] + MOD - c);
        if(y[i] >= (uint64_t)MOD) y[i] -= MOD;
        y[i] = m.reduce(m.reduce(m.generate(y[i]) * powerRinv));
        powerRinv = m.reduce(powerRinv * rinv);
    }
    y.pop_back();                           // only d+1 samples are interpolated here
    const ll ans = (last * inner_adjacentInterpolation(y, n, MOD, tbl))%MOD;
    return (ans + c)%MOD;
}
int main(){
ll n, k;std::cin >> n >> k;
std::cout << exPowerSum(1, k, n+1, 1000000007) << '\n';
}
tonegawa