結果

問題 No.213 素数サイコロと合成数サイコロ (3-Easy)
ユーザー rpy3cpprpy3cpp
提出日時 2015-08-24 18:13:29
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
AC  
実行時間 373 ms / 3,000 ms
コード長 4,976 bytes
コンパイル時間 71 ms
コンパイル使用メモリ 12,800 KB
実行使用メモリ 11,008 KB
最終ジャッジ日時 2024-07-18 13:09:14
合計ジャッジ時間 1,133 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 175 ms
11,008 KB
testcase_01 AC 373 ms
11,008 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

def solve(N, P, C):
    mod = 10 ** 9 + 7
    ps = [2-2, 3-2, 5-2, 7-2, 11-2, 13-2]
    cs = [4-4, 6-4, 8-4, 9-4, 10-4, 12-4]
    distp = get_dist(ps, P)
    distc = get_dist(cs, C)
    dist = merge_dists(distp, distc)
    coefs = [0] * (2 * P + 4 * C - 1) + dist
    coefs.reverse()
    inits = set_inits(coefs, mod)
    inits_tricked = trick_inits(inits, coefs, mod)
    return LRS(inits_tricked, coefs, N - 1, mod)


def set_inits(coefs, mod):
    '''
    coefs を係数に持つ線形漸化式であらわされる数列 bi について、
    初項を b[0] = 1, b[-i] = 0 としたときに、第0項からはじまる最初の len(coefs)項を返す。
    b[k] は、サイコロの目の合計がちょうど k となる、目の出方の場合の数を表わす。
    '''
    inits = [1]
    n = len(coefs)
    for i in range(1, n):
        v = sum(a * c for a, c, in zip(inits, coefs[-i:])) % mod
        inits.append(v)
    return inits


def trick_inits(bs, coefs, mod):
    '''サイコロの目の合計が N 以上となる目の出方の場合の数を求めることが出来るように、
    漸化式の数列の初期値を細工する。
    k = len(coefs)

    LRS([bs, coefs, N - 1, mod]) = b[N - 1]
    LRS([0] + bs[:-1], coefs, N - 1, mod) = b[N - 2]
    LRS([0] * i + b[:k-i], coefs, N - 1, mod) = b[N - 1 - i]
    ...

    N以上となる目の出方の場合の数は、
      b[N - 1 - 0] * sum(coefs[:k - 0]) # N - 1 にいて、次に、1以上の目が出る。
    + b[N - 1 - 1] * sum(coefs[:k - 1]) # N - 2 にいて、次に、2以上の目が出る。
    + b[N - 1 - 2] * sum(coefs[:k - 2]) # N - 3 にいて、次に、3以上の目が出る。
    ...
    + b[N - 1 - (k-1)] * sum(coefs[:k - (k-1)]) # N - k にいて、次に、k以上の目が出る。

    LRSに渡す初項を
    sum(([0]*i + b[:k-i]) * sum(coefs[:k-i]) for i in range(k))
    とすれば、良い。(実際には、[0,1]+[3,5]=[3,6] といった計算はできないので、リストの要素毎に計算する)
    '''
    k = len(coefs)
    inits = [0] * k
    for i in range(k):
        bbs = [0] * i + bs[:k - i]
        tmp = sum(coefs[:k - i])
        for j in range(k):
            inits[j] += bbs[j] * tmp
    return inits


def get_dist(qs, Q):
    '''
    qs を合計Q個使った和が、何通りの作り方があるかを返す。
    dp[i][n][s]: qs[:i] までを合計n個使って合計sとなる組み合わせの数
    qs[i] = qi とすると
    dp[i + 1][0][0] = 1
    dp[i + 1][0][s] = 0, s > 0
    dp[i + 1][1][s] = 1 if s == qi else dp[i][1][s]
    dp[i + 1][n][s] = dp[i][n][s] + dp[i + 1][n - 1][s - qi]
    [i] を落とすと、
    dp[n][s] += dp[n-1][s-n] で更新する。
    '''
    len_dp = qs[-1] * Q + 1
    dp = [[0] * len_dp for n in range(Q + 1)]
    dp[0][0] = 1
    for q in qs:
        for n in range(1, Q + 1):
            current_dp = dp[n]
            prev_dp = dp[n - 1]
            for s in range(q, q * n + 1):
                current_dp[s] += prev_dp[s - q]
    return dp[Q]


def merge_dists(distp, distc):
    mod = 10 ** 9 + 7
    len_p = len(distp)
    len_c = len(distc)
    dist = [0] * (len_p + len_c - 1)
    for i, pi in enumerate(distp):
        for ij, cj in enumerate(distc, i):
            dist[ij] += pi * cj
            dist[ij] %= mod
    return dist


def poly_mult(poly1, poly2, f, mod):
    '''
    多項式 poly1 と 多項式 poly2 の積を 多項式 x^n - fi * x^i で除した余りを求める。
    poly1, poly2 は、最高次数が n - 1 の多項式をあらわし、0乗からn-1乗までの係数のリスト。
    f は、0乗からn-1乗までの係数のリスト。
    mod は、整数で、係数は、mod で除した余りを求める。
    '''
    n = len(f)
    poly_long = [0] * (2 * n - 1)
    for i, p1 in enumerate(poly1):
        for j, p2 in enumerate(poly2):
            poly_long[i + j] += p1 * p2
    for i in range(2 * n - 2, n - 1, -1):
        p = poly_long[i] % mod
        for j, fk in enumerate(f, i - n):
            poly_long[j] += p * fk
    poly = [p % mod for p in poly_long[:n]]
    return poly


def poly_pow(f, p, mod):
    n = len(f)
    polyR = [0] * n
    polyR[0] = 1
    poly = [0] * n
    poly[1] = 1
    while p:
        if p & 1: polyR = poly_mult(poly, polyR, f, mod)
        poly = poly_mult(poly, poly, f, mod)
        p >>= 1
    return polyR


def LRS(As, Cs, n, mod):
    ''' Linear Recursive Seuqence 線形漸化式の第 n 項の値(をmodで除した余り)を求める。
    As は、数列の初期値たち。Cs は、漸化式の係数。 As と Cs の長さは同じ。
    b(i) = As[i] (0 <= i < k)
    b(i + k) = sum(Cs[j]*b(i + j) for j in range(k))
    と表される数列 b(i) の第 n 項を求める。
    '''
    poly = poly_pow(Cs, n, mod)
    return sum(p * a for p, a in zip(poly, As)) % mod


if __name__ == '__main__':
    N, P, C = map(int, input().split())
    print(solve(N, P, C))
0