結果
| 問題 | No.295 hel__world | 
| コンテスト | |
| ユーザー |  gew1fw | 
| 提出日時 | 2025-06-12 13:11:25 | 
| 言語 | PyPy3 (7.3.15) | 
| 結果 | 
                                WA
                                 
                             | 
| 実行時間 | - | 
| コード長 | 4,153 bytes | 
| コンパイル時間 | 182 ms | 
| コンパイル使用メモリ | 82,896 KB | 
| 実行使用メモリ | 338,672 KB | 
| 最終ジャッジ日時 | 2025-06-12 13:14:33 | 
| 合計ジャッジ時間 | 11,741 ms | 
| ジャッジサーバーID (参考情報) | judge2 / judge4 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 4 | 
| other | AC * 44 WA * 9 | 
ソースコード
import sys
import math
from collections import defaultdict
def main():
    # Read S_alpha
    S_alpha = list(map(int, sys.stdin.readline().split()))
    T = sys.stdin.readline().strip()
    # Precompute factorial and log_fact up to 1e6
    max_r = 10**6
    factorial = [1] * (max_r + 1)
    log_fact = [0.0] * (max_r + 1)
    for i in range(1, max_r + 1):
        factorial[i] = factorial[i-1] * i
        log_fact[i] = log_fact[i-1] + math.log(i)
        if factorial[i] > (1 << 62):
            factorial[i] = (1 << 62) + 1  # To prevent overflow in exact computation
    # Process T into T_comp and runs
    if not T:
        print(0)
        return
    T_comp = []
    runs = []
    prev_char = T[0]
    count = 1
    for c in T[1:]:
        if c == prev_char:
            count += 1
        else:
            T_comp.append(prev_char)
            runs.append(count)
            prev_char = c
            count = 1
    T_comp.append(prev_char)
    runs.append(count)
    m = len(T_comp)
    # Check if each character in T_comp has at least 1 in S_alpha
    for c in T_comp:
        idx = ord(c) - ord('a')
        if S_alpha[idx] < 1:
            print(0)
            return
    # For each character in T_comp, collect their groups and run_lengths
    groups = list(zip(T_comp, runs))
    char_groups = defaultdict(list)
    for i in range(m):
        c, r = groups[i]
        char_groups[c].append((i, r))
    # Check if sum of required run_lengths for each character is <= S_alpha
    for c in char_groups:
        indices = [i for i, r in char_groups[c]]
        sum_r = sum(r for i, r in char_groups[c])
        idx = ord(c) - ord('a')
        if sum_r > S_alpha[idx]:
            print(0)
            return
    # Now, allocate the remaining characters for each character
    new_x = []
    for i in range(m):
        new_x.append(groups[i][1])  # start with run_length
    for c in char_groups:
        idx_c = ord(c) - ord('a')
        group_info = char_groups[c]
        sum_r = sum(r for i, r in group_info)
        available = S_alpha[idx_c]
        rem = available - sum_r
        if rem < 0:
            print(0)
            return
        if rem == 0:
            continue
        # Extract the run_lengths and their indices
        run_lengths = [r for i, r in group_info]
        total_r = sum_r
        num_groups = len(run_lengths)
        # Compute k_i for each group
        k_list = []
        for r in run_lengths:
            k = (rem * r) // total_r
            k_list.append(k)
        rem_remaining = rem - sum(k_list)
        # Sort the groups by run_length descending, then index ascending
        sorted_group_info = sorted(group_info, key=lambda x: (-x[1], x[0]))
        sorted_indices = [i for i, r in sorted_group_info]
        # Distribute rem_remaining
        for i in range(rem_remaining):
            k_list[i] += 1
        # Update new_x for each group
        for i in range(num_groups):
            original_index = sorted_group_info[i][0]
            r = sorted_group_info[i][1]
            new_x[original_index] = r + k_list[i]
    # Now compute the product
    log_sum = 0.0
    for i in range(m):
        x = new_x[i]
        r = groups[i][1]
        if x < r:
            print(0)
            return
        # Compute log(C(x, r))
        current_log = 0.0
        for k in range(r):
            current_log += math.log(x - k)
        current_log -= log_fact[r]
        log_sum += current_log
        if log_sum > math.log(1 << 62):
            print("hel")
            return
    # Now compute the product exactly
    product = 1
    for i in range(m):
        x = new_x[i]
        r = groups[i][1]
        if x < r:
            print(0)
            return
        # Compute C(x, r)
        numerator = 1
        for k in range(r):
            numerator *= (x - k)
            if numerator > (1 << 62):
                print("hel")
                return
        denominator = factorial[r]
        term = numerator // denominator
        product *= term
        if product > (1 << 62):
            print("hel")
            return
    print(product)
if __name__ == "__main__":
    main()
            
            
            
        