結果

問題 No.681 Fractal Gravity Glue
ユーザー lam6er
提出日時 2025-04-09 21:00:56
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 3,790 bytes
コンパイル時間 179 ms
コンパイル使用メモリ 82,376 KB
実行使用メモリ 83,188 KB
最終ジャッジ日時 2025-04-09 21:01:47
合計ジャッジ時間 4,173 ms
ジャッジサーバーID (参考情報) judge2 / judge4
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 1 -- * 11
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 1_000_000_007  # modulus applied to every sum and matrix product below

def multiply(A, B):
    """Return the 3x3 product ``A @ B`` with every entry reduced modulo MOD.

    Only the leading 3x3 part of each operand is read.
    """
    return [
        [sum(A[row][k] * B[k][col] for k in range(3)) % MOD for col in range(3)]
        for row in range(3)
    ]

def matrix_pow(matrix, power):
    """Raise a 3x3 ``matrix`` to ``power`` by binary exponentiation.

    Products go through ``multiply``, so entries come back reduced modulo
    MOD. A ``power`` of 0 (or any non-positive value) yields the identity.
    """
    acc = [[int(r == c) for c in range(3)] for r in range(3)]
    base = matrix
    remaining = power
    while remaining > 0:
        remaining, bit = divmod(remaining, 2)
        if bit == 1:
            acc = multiply(acc, base)
        base = multiply(base, base)
    return acc

def compute_S(b, d):
    """Return S_b modulo MOD for the recurrence S_1 = d, S_k = (d+1)*S_{k-1} + d*k.

    S_0 is defined as 0. For b >= 2 the state vector (S, k, 1) is advanced
    b-1 times with the companion matrix below, starting from (S_1, 1, 1).
    """
    if b == 0:
        return 0
    if b == 1:
        return d % MOD

    # Row 0: S' = (d+1)*S + d*k + d*1; row 1: k' = k + 1; row 2: keeps the 1.
    step = [
        [d + 1, d, d],
        [0, 1, 1],
        [0, 0, 1],
    ]
    powered = matrix_pow(step, b - 1)

    start = (d % MOD, 1 % MOD, 1 % MOD)  # (S_1, k = 1, constant 1)
    return sum(coef * value for coef, value in zip(powered[0], start)) % MOD

def is_h_prev_larger_than_n(current_b, d, n):
    """Tell whether ``(d+1)**(current_b-1) - 1`` exceeds ``n``.

    Equivalently: whether ``(d+1)**(current_b-1) > n + 1``. The power is
    built one factor at a time and the check stops as soon as it passes
    ``n + 1``, so no huge integer is formed.

    Special cases: always False when ``d == 0``; when ``current_b <= 1``
    the result is simply ``n < 0``.
    """
    base = d + 1
    if base == 1:
        return False

    exponent = current_b - 1
    limit = n + 1
    if exponent <= 0:
        return limit < 1
    if limit < 1:
        # Even the empty product 1 is already above the limit.
        return True

    acc = 1
    for _ in range(exponent):
        acc *= base
        if acc > limit:
            return True
    return False

def compute_h_prev(current_b, d, n):
    """Return the inner block length ``(d+1)**(current_b-1) - 1``.

    The previous implementation ran a capped multiplication loop against
    ``n + 1``. It then fell back to ``(d+1)**(current_b-1) - 1`` whenever the
    cap was hit, so both paths returned this same closed form; the loop is
    removed. ``n`` is kept only so existing callers keep working.

    Returns 0 when ``current_b == 0`` or ``d == 0``.

    NOTE(review): for a large ``current_b`` this builds a very large
    integer; callers are expected to have reduced ``current_b`` first.
    """
    if current_b == 0 or d + 1 == 1:
        return 0
    return (d + 1) ** (current_b - 1) - 1

def _fit_level(n, current_b, d):
    """Return ``(level, h_prev)`` for the deepest level whose inner block fits in ``n``.

    ``level`` is the largest value in ``[1, current_b]`` satisfying
    ``(d+1)**(level-1) - 1 <= n``. ``h_prev`` is that inner block length,
    ``(d+1)**(level-1) - 1``. The power is grown by capped multiplication,
    so the cost is O(log_{d+1}(n)) no matter how large ``current_b`` is.
    Assumes ``n >= 0`` and ``current_b >= 1``.
    """
    if d == 0:
        # Every inner block is empty, so no level needs to be skipped.
        return current_b, 0
    base = d + 1
    limit = n + 1
    power = 1
    exponent = 0
    while exponent < current_b - 1 and power * base <= limit:
        power *= base
        exponent += 1
    return exponent + 1, power - 1


def A(n, current_b, d, memo_S):
    """Return the sum of the first ``n`` terms of the level-``current_b`` sequence, mod MOD.

    As this code decomposes it, level k is the level-(k-1) sequence followed
    by ``d`` copies of ``[k] + (level-(k-1) sequence)``. The level-(k-1)
    sequence ("inner block") has length ``(d+1)**(k-1) - 1`` and total
    ``compute_S(k-1, d)``. Level 1 is ``d`` copies of ``1``.

    Args:
        n: number of leading terms to sum; ``n <= 0`` gives 0.
        current_b: level of the sequence.
        d: number of ``[k] + inner block`` pairs per level.
        memo_S: unused; kept so existing callers keep working.

    Returns:
        The prefix sum reduced modulo MOD. If ``n`` exceeds the sequence
        length, the whole sequence is summed rather than reading past its end.
    """
    total = 0
    while n > 0 and current_b >= 1:
        # If the inner block of level k is longer than n, the first n terms of
        # level k are just the first n terms of level k-1. Jump straight to the
        # deepest level that still needs splitting; stepping down one level at
        # a time was O(current_b) and caused the TLE.
        current_b, h_prev = _fit_level(n, current_b, d)

        sum0 = compute_S(current_b - 1, d)  # total of the leading inner block
        len_pair = h_prev + 1                 # length of one "[k] + inner block"
        n_remaining = n - h_prev
        full_pairs = min(n_remaining // len_pair, d)
        total = (total + sum0 + full_pairs * (current_b + sum0)) % MOD
        n_remaining -= full_pairs * len_pair

        # All d pairs consumed means the whole level is summed: stop instead of
        # reading past the end of the sequence.
        if full_pairs == d or n_remaining <= 0:
            break

        # Partial pair: its leading term k, then a prefix of the inner block.
        total = (total + current_b) % MOD
        n = n_remaining - 1
        current_b -= 1
    return total

def main():
    """Read ``n b d`` from stdin and print one value modulo MOD.

    The value is ``compute_S(b, d)`` minus ``A(n, b, d, ...)``. When ``n`` is 0
    it is ``compute_S(b, d)`` alone, and when ``b`` is 0 it is 0.
    """
    import sys
    tokens = sys.stdin.read().split()
    n, b, d = int(tokens[0]), int(tokens[1]), int(tokens[2])

    if n == 0:
        result = compute_S(b, d) % MOD
    elif b == 0:
        result = 0
    else:
        whole = compute_S(b, d)
        prefix = A(n, b, d, {})
        result = (whole - prefix) % MOD
    print(result)

# Run the solver only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
0