結果

問題 No.295 hel__world
ユーザー gew1fw
提出日時 2025-06-12 13:08:22
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,153 bytes
コンパイル時間 180 ms
コンパイル使用メモリ 82,596 KB
実行使用メモリ 339,104 KB
最終ジャッジ日時 2025-06-12 13:12:15
合計ジャッジ時間 11,614 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 44 WA * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
from collections import defaultdict

def main():
    # Read S_alpha
    S_alpha = list(map(int, sys.stdin.readline().split()))
    T = sys.stdin.readline().strip()

    # Precompute factorial and log_fact up to 1e6
    max_r = 10**6
    factorial = [1] * (max_r + 1)
    log_fact = [0.0] * (max_r + 1)
    for i in range(1, max_r + 1):
        factorial[i] = factorial[i-1] * i
        log_fact[i] = log_fact[i-1] + math.log(i)
        if factorial[i] > (1 << 62):
            factorial[i] = (1 << 62) + 1  # To prevent overflow in exact computation

    # Process T into T_comp and runs
    if not T:
        print(0)
        return
    T_comp = []
    runs = []
    prev_char = T[0]
    count = 1
    for c in T[1:]:
        if c == prev_char:
            count += 1
        else:
            T_comp.append(prev_char)
            runs.append(count)
            prev_char = c
            count = 1
    T_comp.append(prev_char)
    runs.append(count)
    m = len(T_comp)

    # Check if each character in T_comp has at least 1 in S_alpha
    for c in T_comp:
        idx = ord(c) - ord('a')
        if S_alpha[idx] < 1:
            print(0)
            return

    # For each character in T_comp, collect their groups and run_lengths
    groups = list(zip(T_comp, runs))
    char_groups = defaultdict(list)
    for i in range(m):
        c, r = groups[i]
        char_groups[c].append((i, r))

    # Check if sum of required run_lengths for each character is <= S_alpha
    for c in char_groups:
        indices = [i for i, r in char_groups[c]]
        sum_r = sum(r for i, r in char_groups[c])
        idx = ord(c) - ord('a')
        if sum_r > S_alpha[idx]:
            print(0)
            return

    # Now, allocate the remaining characters for each character
    new_x = []
    for i in range(m):
        new_x.append(groups[i][1])  # start with run_length

    for c in char_groups:
        idx_c = ord(c) - ord('a')
        group_info = char_groups[c]
        sum_r = sum(r for i, r in group_info)
        available = S_alpha[idx_c]
        rem = available - sum_r
        if rem < 0:
            print(0)
            return
        if rem == 0:
            continue

        # Extract the run_lengths and their indices
        run_lengths = [r for i, r in group_info]
        total_r = sum_r
        num_groups = len(run_lengths)

        # Compute k_i for each group
        k_list = []
        for r in run_lengths:
            k = (rem * r) // total_r
            k_list.append(k)
        rem_remaining = rem - sum(k_list)

        # Sort the groups by run_length descending, then index ascending
        sorted_group_info = sorted(group_info, key=lambda x: (-x[1], x[0]))
        sorted_indices = [i for i, r in sorted_group_info]

        # Distribute rem_remaining
        for i in range(rem_remaining):
            k_list[i] += 1

        # Update new_x for each group
        for i in range(num_groups):
            original_index = sorted_group_info[i][0]
            r = sorted_group_info[i][1]
            new_x[original_index] = r + k_list[i]

    # Now compute the product
    log_sum = 0.0
    for i in range(m):
        x = new_x[i]
        r = groups[i][1]
        if x < r:
            print(0)
            return
        # Compute log(C(x, r))
        current_log = 0.0
        for k in range(r):
            current_log += math.log(x - k)
        current_log -= log_fact[r]
        log_sum += current_log
        if log_sum > math.log(1 << 62):
            print("hel")
            return

    # Now compute the product exactly
    product = 1
    for i in range(m):
        x = new_x[i]
        r = groups[i][1]
        if x < r:
            print(0)
            return
        # Compute C(x, r)
        numerator = 1
        for k in range(r):
            numerator *= (x - k)
            if numerator > (1 << 62):
                print("hel")
                return
        denominator = factorial[r]
        term = numerator // denominator
        product *= term
        if product > (1 << 62):
            print("hel")
            return

    print(product)

if __name__ == "__main__":
    main()
0