結果

問題 No.665 Bernoulli Bernoulli
ユーザー tonegawatonegawa
提出日時 2020-10-10 11:07:36
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 4 ms / 2,000 ms
コード長 4,539 bytes
コンパイル時間 1,557 ms
コンパイル使用メモリ 123,028 KB
実行使用メモリ 4,384 KB
最終ジャッジ日時 2023-09-27 21:11:00
合計ジャッジ時間 2,268 ms
ジャッジサーバーID
(参考情報)
judge12 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
4,384 KB
testcase_01 AC 1 ms
4,380 KB
testcase_02 AC 2 ms
4,380 KB
testcase_03 AC 3 ms
4,380 KB
testcase_04 AC 4 ms
4,380 KB
testcase_05 AC 3 ms
4,380 KB
testcase_06 AC 3 ms
4,376 KB
testcase_07 AC 4 ms
4,376 KB
testcase_08 AC 3 ms
4,376 KB
testcase_09 AC 4 ms
4,376 KB
testcase_10 AC 3 ms
4,376 KB
testcase_11 AC 4 ms
4,380 KB
testcase_12 AC 4 ms
4,380 KB
testcase_13 AC 4 ms
4,376 KB
testcase_14 AC 3 ms
4,376 KB
testcase_15 AC 3 ms
4,380 KB
testcase_16 AC 3 ms
4,376 KB
testcase_17 AC 3 ms
4,376 KB
testcase_18 AC 3 ms
4,380 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <string>
#include <vector>
#include <array>
#include <queue>
#include <deque>
#include <algorithm>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <functional>
#include <cassert>
#include <iomanip>
// Competitive-programming shorthands (expanded textually by the preprocessor).
#define vll vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vvl>
// VV(rows, cols, init, T): a rows x cols vector<vector<T>> filled with init.
#define VV(a, b, c, d) vector<vector<d>>(a, vector<d>(b, c))
// VVV(x, y, z, init): an x * y * z vector of ll filled with init.
// Deliberately has no trailing ';' so it can be used inside expressions.
#define VVV(a, b, c, d) vector<vvl>(a, vvl(b, vll (c, d)))
// re(i, n): loop i = 0 .. n-1 with a long long counter.
#define re(c, b) for(ll c=0;c<b;c++)
// all(v): the begin/end iterator pair of a container.
#define all(obj) (obj).begin(), (obj).end()
typedef long long int ll;
typedef long double ld;
using namespace std;

// Modular exponentiation by repeated squaring: returns a^b mod p, in [0, p).
//   a : base of any sign/magnitude; it is reduced into [0, p) first, which both
//       keeps num*num from overflowing and makes negative bases give
//       non-negative results.
//   b : exponent; b <= 0 is treated as 0 (the result is then 1 % p).
//   p : modulus, 1 <= p <= 3037000499 so that (p-1)^2 fits in a long long.
// (long long here is the same type as the file's `ll`.)
long long mpow(long long a, long long b, long long p){
  long long ret = 1 % p;                // 1 % p so that p == 1 yields 0
  long long num = ((a % p) + p) % p;    // canonical residue of the base
  while(b > 0){
    if(b & 1) ret = ret * num % p;
    num = num * num % p;
    b >>= 1;
  }
  return ret;
}
// Montgomery arithmetic modulo MOD with R = 2^32.
// A value x is held in "Montgomery form" as x * R mod MOD. generate() converts
// into that form, reduce() of a product of two Montgomery-form values yields the
// Montgomery form of the product, and reduce() of a single Montgomery-form value
// converts it back to normal form.
// NOTE(review): the 32-bit truncation in reduce() and the overflow bounds assume
// MOD is odd and comfortably below 2^32 (e.g. 1e9+7); confirm before reusing
// with other moduli.
struct MontgomeryReduction{
  uint64_t MOD;      // the modulus
  uint64_t NEG_INV;  // -MOD^{-1} mod 2^64 (only the low 32 bits are used)
  uint64_t R2;       // R^2 mod MOD, i.e. (2^32 mod MOD)^2 mod MOD

  MontgomeryReduction(uint64_t MOD_): MOD(MOD_), NEG_INV(0), R2(0){
    // Choose NEG_INV bit by bit so that every low bit of MOD * NEG_INV is 1,
    // i.e. MOD * NEG_INV == -1 (mod 2^64). After k bits have been processed,
    // `carry` equals floor(MOD * NEG_INV / 2^k).
    uint64_t carry = 0;
    for(uint64_t bit = 1; bit != 0; bit <<= 1){
      const bool low_bit_clear = (carry & 1) == 0;
      if(low_bit_clear){
        carry += MOD;
        NEG_INV |= bit;
      }
      carry >>= 1;
    }
    const uint64_t r_mod = (uint64_t(1) << 32) % MOD;
    R2 = r_mod * r_mod % MOD;
  }

  // Converts x (expected < MOD) into Montgomery form: x * R mod MOD.
  inline uint64_t generate(uint64_t x) const{
    const uint64_t scaled = x * R2;
    return reduce(scaled);
  }

  // Montgomery reduction (REDC): returns x * R^{-1} mod MOD, in [0, MOD).
  // x is expected to be below MOD * MOD (e.g. a product of two reduced values).
  inline uint64_t reduce(uint64_t x) const{
    const uint32_t q = static_cast<uint32_t>(x) * static_cast<uint32_t>(NEG_INV);
    const uint64_t t = (x + q * MOD) >> 32;
    if(t >= MOD) return t - MOD;
    return t;
  }
};

// Evaluates, modulo MOD, the unique polynomial f of degree < n (n = y.size())
// with f(i) = y[i] on the consecutive nodes i = 0..n-1, at the point x = p,
// using Lagrange interpolation:
//   f(p) = prod_j (p - j) * sum_i y[i] / ((p - i) * i! * (n-1-i)! * (-1)^{n-1-i}).
// Parameters:
//   y   : sample values in normal (non-Montgomery) form, each < MOD.
//   p   : evaluation point, expected already reduced into [0, MOD).
//   MOD : prime modulus usable with MontgomeryReduction (Fermat inverses below).
//   inv : inv[k] = k^{-1} in Montgomery form w.r.t. MontgomeryReduction(MOD);
//         must have at least n entries. Inverses past inv.size() are computed
//         with Fermat's little theorem instead of being read out of bounds.
// Returns f(p) in normal form, in [0, MOD); returns 0 when y is empty.
ll inner_adjacentInterpolation(const vector<uint64_t> &y, uint64_t p, uint64_t MOD, const vector<uint64_t> &inv){
  const int n = y.size();
  if(n == 0) return 0;  // no samples: the zero polynomial
  MontgomeryReduction mr(MOD);
  const uint64_t one = mr.generate(1);

  // finv[i] = (i!)^{-1} in Montgomery form. Only indices 0..n-1 are needed, so
  // inv is read only up to inv[n-1] (inv may have exactly n entries).
  vector<uint64_t> finv(n);
  finv[0] = one;
  if(n > 1) finv[1] = one;
  for(int i=2;i<n;i++) finv[i] = mr.reduce(finv[i-1]*inv[i]);

  // M = prod_{i=0}^{n-1} (p - i), in normal form. If p is a node, return its sample.
  uint64_t M = one;
  for(int i=0;i<n;i++){
    if(static_cast<uint64_t>(i) == p) return y[p];
    const uint64_t diff = p - i;
    M = mr.reduce(M * mr.generate(diff));
  }
  M = mr.reduce(M);

  // Here p >= n, so every p - i below lies in [1, MOD).
  const uint64_t minus = mr.generate(MOD-1);
  uint64_t ret = 0;
  for(int i=0;i<n;i++){
    // Lagrange denominator factor 1 / (i! (n-1-i)! (-1)^{n-1-i}).
    uint64_t iQ = mr.reduce(finv[i] * finv[n-1-i]);
    if((n-i-1)&1) iQ = mr.reduce(iQ * minus);
    const uint64_t iC = mr.reduce(mr.generate(y[i]) * iQ);
    const uint64_t diff = p - i;
    uint64_t diff_inv;  // (p - i)^{-1} in Montgomery form
    if(diff < inv.size()){
      diff_inv = inv[diff];  // table lookup, bounds-checked against inv itself
    }else{
      // Fermat: diff^{MOD-2}, computed in Montgomery form.
      uint64_t e = MOD-2, base = mr.generate(diff);
      diff_inv = one;
      while(e){
        if(e&1) diff_inv = mr.reduce(diff_inv * base);
        base = mr.reduce(base * base);
        e >>= 1;
      }
    }
    ret += mr.reduce(mr.reduce(iC * diff_inv));  // normal form, < MOD
  }
  return ((ret%MOD) * M)%MOD;
}

// Returns S = sum_{i=0}^{n-1} r^i * i^d (mod MOD), with the convention 0^0 = 1.
//   r   : ratio of any sign; reduced into [0, MOD) internally.
//   d   : non-negative exponent; work and memory grow with d (vectors of d+2).
//   n   : number of terms; n <= 0 is the empty sum and yields 0. May be huge.
//   MOD : prime modulus accepted by MontgomeryReduction (inverses use Fermat).
// Method: compute the prefix sums T(i) = sum_{j=0}^{i} r^j j^d for i = 0..d+1.
//   * r == 1: T is a polynomial in the upper index; interpolate it at n-1.
//   * r != 1: T(m) = c + r^m g(m) with g a polynomial of degree <= d. c is
//     recovered from a (d+1)-th finite difference, g is interpolated from
//     (T(i) - c) r^{-i}, and S = c + r^{n-1} g(n-1).
ll exPowerSum(ll r, ll d, ll n, ll MOD){
  // Empty range. Also guards negative n, which would otherwise give garbage.
  if(n <= 0) return 0;
  n--;  // from here on n is the last index: S = T(n)
  vector<uint64_t> y(d+2), z(d+2), tbl(d+2);
  MontgomeryReduction m(MOD);

  // z[i] = i^d in Montgomery form for i = 0..d+1, via one shared square-and-multiply.
  ll cnt = d;
  for(int i=0;i<d+2;i++) {
    tbl[i] = m.generate(i);
    z[i] = m.generate(1);
  }
  while(cnt){
    for(int i=0;i<d+2;i++) {
      if(cnt&1) z[i] = m.reduce(z[i] * tbl[i]);
      tbl[i] = m.reduce(tbl[i] * tbl[i]);
    }
    cnt >>= 1;
  }

  r = (r%MOD + MOD)%MOD;
  ll last = mpow(r, n%(MOD-1), MOD);  // r^n (exponent reduced mod MOD-1)
  n %= MOD;  // the polynomial parts only depend on n modulo MOD

  // y[i] = T(i) in normal form.
  uint64_t prefix = 0, rs = m.generate(1), R = m.generate(r);
  for(int i=0;i<d+2;i++){
    prefix += m.reduce(m.reduce(rs * z[i]));
    if(prefix >= MOD) prefix -= MOD;
    y[i] = prefix;
    rs = m.reduce(rs * R);
  }

  // Reuse z and tbl: z[i] = (-r)^i and tbl[i] = i^{-1}, both in Montgomery form.
  z[0] = m.generate(1);
  z[1] = m.generate((MOD - r)%MOD);
  tbl[0] = m.generate(0);
  tbl[1] = m.generate(1);
  for(int i=2;i<=d+1;i++) {
    // Linear-time inverse recurrence: i^{-1} = -(MOD / i) * (MOD % i)^{-1}.
    tbl[i] = m.reduce(m.generate(MOD - (MOD/i))*tbl[MOD%i]);
    z[i] = m.reduce(z[i-1] * z[1]);
  }
  if(r==1) return inner_adjacentInterpolation(y, n, MOD, tbl);

  // c * (1-r)^{d+1} = sum_{i=0}^{d} C(d+1, i+1) * (-r)^{d-i} * T(i).
  uint64_t comb = m.generate(1), c = 0;
  for(ll i=0;i<d+1;i++){
    comb = m.reduce(comb * m.reduce(m.generate(d+1-i)*tbl[i+1]));  // C(d+1, i+1)
    const uint64_t term = m.reduce(comb * m.reduce(z[d-i] * m.generate(y[i])));
    c = (c + m.reduce(term));
  }
  c %= MOD;
  ll di = (MOD+1-r);  // 1 - r (mod MOD)
  if(di >= MOD) di -= MOD;
  c = (c * mpow(mpow(di, d+1, MOD), MOD-2, MOD))%MOD;

  // y[i] <- g(i) = (T(i) - c) * r^{-i} for i = 0..d, in normal form.
  uint64_t powerRinv = m.generate(1);
  uint64_t rinv = m.generate(mpow(r, MOD-2, MOD));
  for(int i=0;i<d+1;i++){
    y[i] = (y[i] + MOD - c);
    if(y[i] >= MOD) y[i] -= MOD;
    y[i] = m.reduce(m.reduce(m.generate(y[i]) * powerRinv));
    powerRinv = m.reduce(powerRinv * rinv);
  }
  y.pop_back();  // g has degree <= d: d+1 samples suffice
  ll ans = (last * inner_adjacentInterpolation(y, n, MOD, tbl))%MOD;
  return (ans + c)%MOD;
}

int main(){
  ll n, k;std::cin >> n >> k;
  std::cout << exPowerSum(1, k, n+1, 1000000007) << '\n';
}
0