結果

問題 No.978 Fibonacci Convolution Easy
ユーザー tarattata1tarattata1
提出日時 2019-07-14 02:22:21
言語 C++11
(gcc 11.4.0)
結果
AC  
実行時間 2 ms / 2,000 ms
コード長 3,331 bytes
コンパイル時間 313 ms
コンパイル使用メモリ 38,576 KB
実行使用メモリ 4,348 KB
最終ジャッジ日時 2023-10-19 00:55:36
合計ジャッジ時間 1,183 ms
ジャッジサーバーID
(参考情報)
judge13 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
4,348 KB
testcase_01 AC 1 ms
4,348 KB
testcase_02 AC 1 ms
4,348 KB
testcase_03 AC 1 ms
4,348 KB
testcase_04 AC 2 ms
4,348 KB
testcase_05 AC 2 ms
4,348 KB
testcase_06 AC 2 ms
4,348 KB
testcase_07 AC 1 ms
4,348 KB
testcase_08 AC 2 ms
4,348 KB
testcase_09 AC 2 ms
4,348 KB
testcase_10 AC 2 ms
4,348 KB
testcase_11 AC 2 ms
4,348 KB
testcase_12 AC 1 ms
4,348 KB
testcase_13 AC 1 ms
4,348 KB
testcase_14 AC 2 ms
4,348 KB
testcase_15 AC 2 ms
4,348 KB
testcase_16 AC 1 ms
4,348 KB
testcase_17 AC 1 ms
4,348 KB
testcase_18 AC 2 ms
4,348 KB
testcase_19 AC 2 ms
4,348 KB
testcase_20 AC 1 ms
4,348 KB
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp: In function ‘int main(int, char**)’:
main.cpp:100:10: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  100 |     scanf("%d%d", &n, &p);
      |     ~~~~~^~~~~~~~~~~~~~~~

ソースコード

diff #

#include <stdio.h>
#include <vector>
#pragma warning(disable:4996) 

typedef long long ll;
const long long MOD = 1000000007;
using namespace std;

ll mpow(ll x, ll n){ //x^n(mod M)
    ll ans = 1;
    while(n != 0){
        if(n&1) ans = ans*x % MOD;
        x = x*x % MOD;
        n = n >> 1;
    }
    return ans;
}

ll minv(ll x){
    return mpow( x, MOD-2 );
}

void mtxprd( int p, int q, int r, const ll* A, const ll* B, ll* C )
{
    int i, j, k;
    for(i=0; i<p; i++) {
        for(k=0; k<r; k++) {
            C[i*r+k] = 0;
            for(j=0; j<q; j++) {
                C[i*r+k] = (C[i*r+k] + A[i*q+j] * B[j*r+k])%MOD;
            }
        }
    } 
    return;
}
 
void mtxuni(int dim, ll* mtx){
    int i,j;
    for(i=0; i<dim; i++) {
        for(j=0; j<dim; j++) {
            mtx[i*dim+j]=(i==j? 1: 0);
        }
    }
    return;
}
 
void mtxcpy(int dim, const ll* mtx, ll* mtx2){
    int i,j;
    for(i=0; i<dim; i++) {
        for(j=0; j<dim; j++) {
            mtx2[i*dim+j]=mtx[i*dim+j];
        }
    }
    return;
}
 
void mtxpow(int dim, const ll* mtx, ll n, ll* mtx2){ //x^n(mod M)
    vector<ll> mtx0(dim*dim);
    vector<ll> mtxtmp(dim*dim);
    mtxcpy( dim, mtx, &mtx0[0] );
    mtxuni( dim, mtx2 );
    while(n != 0){        
        if(n&1) {
            mtxprd( dim, dim, dim, mtx2, &mtx0[0], &mtxtmp[0] );
            mtxcpy( dim, &mtxtmp[0], mtx2 );
        }
        mtxprd( dim, dim, dim, &mtx0[0], &mtx0[0], &mtxtmp[0] );
        mtxcpy( dim, &mtxtmp[0], &mtx0[0] );
        n = n >> 1;
    }
    return;
}

// <解法>
// 求める答えは、(Σa[n])^2 + (Σ(a[n]^2))/2
// この計算は、O(n)ならば簡単に求めることができるが、
// nが大きい場合について解くには何か工夫が必要。
// → 行列累乗を使う
//
// s[n]=Σ(k=1..n)a[k] を求めるには以下の式を使う
//   ┌      ┐ ┌           ┐┌      ┐
//   │s[n]  │ │ 1   p   1 ││s[n-1]│
//   │a[n]  │=│ 0   p   1 ││a[n-1]│
//   │a[n-1]│ │ 0   1   0 ││a[n-2]│
//   └      ┘ └           ┘└      ┘
//
// s2[n]=Σ(k=1..n)(a[k]^2) を求めるには以下の式を使う
//   ┌             ┐ ┌                ┐┌             ┐
//   │s2[n]        │ │ 1   p^2  2p  1 ││s2[n-1]      │
//   │a[n]  *a[n]  │=│ 0   p^2  2p  1 ││a[n-1]*a[n-1]│
//   │a[n]  *a[n-1]│ │ 0   p    1   0 ││a[n-1]*a[n-2]│
//   │a[n-1]*a[n-1]│ │ 0   1    0   0 ││a[n-2]*a[n-2]│
//   └             ┘ └                ┘└             ┘
//


int main(int argc, char* argv[])
{
    int n,p;
    scanf("%d%d", &n, &p);

    if(n==1) {
        printf("0\n"); return 0;
    }

    ll ans0, ans1;
    {
        ll a[3*3] = {1, p, 1, 0, p, 1, 0, 1, 0};
        ll x0[3] = {1, 1, 0};
        ll an[3*3];
        mtxpow( 3, a, n-2, an );
        ll x1[3];
        mtxprd( 3, 3, 1, an, x0, x1 );
        ans0 = x1[0];
    }
    
    {
        ll a[4*4] = {1, (ll)p*p%MOD, (ll)p*2%MOD, 1, 0, (ll)p*p%MOD, p*2%MOD, 1, 0, p, 1, 0, 0, 1, 0, 0};
        ll x0[4] = {1, 1, 0, 0};
        ll an[4*4];
        mtxpow( 4, a, n-2, an );
        ll x1[4];
        mtxprd( 4, 4, 1, an, x0, x1 );
        ans1 = x1[0];
    }

    ll ans = (ans0*ans0 + ans1)%MOD * minv(2) %MOD;
    printf("%lld\n", ans);
    return 0;
}

0