結果

問題 No.295 hel__world
ユーザー lam6er
提出日時 2025-04-16 15:58:38
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 3,968 bytes
コンパイル時間 369 ms
コンパイル使用メモリ 82,104 KB
実行使用メモリ 292,896 KB
最終ジャッジ日時 2025-04-16 16:00:42
合計ジャッジ時間 7,411 ms
ジャッジサーバーID (参考情報) judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 46 WA * 7
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math

def main():
    S_alpha = list(map(int, sys.stdin.readline().split()))
    T = sys.stdin.readline().strip()
    
    # Process T into T_compressed and groups
    if not T:
        print(0)
        return
    
    T_compressed = []
    prev = T[0]
    T_compressed.append(prev)
    for c in T[1:]:
        if c != prev:
            T_compressed.append(c)
            prev = c
    
    groups = []
    prev = T[0]
    count = 1
    for c in T[1:]:
        if c == prev:
            count +=1
        else:
            groups.append((prev, count))
            prev = c
            count = 1
    groups.append((prev, count))
    
    # Check if groups match T_compressed
    if len(groups) != len(T_compressed):
        print(0)
        return
    for i in range(len(groups)):
        if groups[i][0] != T_compressed[i]:
            print(0)
            return
    
    # Check for each character c, sum of count_i <= S_alpha[c]
    from collections import defaultdict
    sum_counts = defaultdict(int)
    for c, cnt in groups:
        sum_counts[c] += cnt
    
    for c in sum_counts:
        idx = ord(c) - ord('a')
        if sum_counts[c] > S_alpha[idx]:
            print(0)
            return
    
    # Prepare runs for each character
    runs = defaultdict(list)
    for i, (c, cnt) in enumerate(groups):
        runs[c].append( (i, cnt) )  # (position in groups, count_i)
    
    # For each character c, compute R_c and distribute x_i
    x = [0] * len(groups)
    for c in runs:
        idx_list = runs[c]
        sum_count = sum(cnt for pos, cnt in idx_list)
        c_idx = ord(c) - ord('a')
        R_c = S_alpha[c_idx] - sum_count
        if R_c <0:
            print(0)
            return
        
        # Distribute R_c
        sum_cnt = sum(cnt for pos, cnt in idx_list)
        if sum_cnt ==0:
            continue
        sorted_runs = sorted(idx_list, key=lambda x: -x[1])
        cnt_list = [cnt for pos, cnt in sorted_runs]
        positions = [pos for pos, cnt in sorted_runs]
        total = sum_cnt
        
        # Allocate x_i = (cnt_i * R_c) // total
        x_list = []
        for cnt_i in cnt_list:
            xi = (cnt_i * R_c) // total
            x_list.append(xi)
        remaining = R_c - sum(x_list)
        
        # Distribute remaining
        for i in range(len(x_list)):
            if remaining <=0:
                break
            x_list[i] +=1
            remaining -=1
        
        # Update x for each position
        for i in range(len(sorted_runs)):
            pos = positions[i]
            x[pos] = x_list[i]
    
    # Compute l_i = count_i + x_i for each group
    product = 1
    hel = 2**62
    log_product = 0.0
    hel_log = math.log(hel)
    for i in range(len(groups)):
        c, cnt = groups[i]
        l_i = cnt + x[i]
        if cnt ==0:
            continue
        if l_i < cnt:
            print(0)
            return
        
        # Compute log(comb(l_i, cnt))
        log_comb = math.lgamma(l_i +1) - math.lgamma(l_i - cnt +1) - math.lgamma(cnt +1)
        log_product += log_comb
        if log_product > hel_log + 1e-12:
            print("hel")
            return
    
    # Now compute the actual product if needed
    product = 1
    for i in range(len(groups)):
        c, cnt = groups[i]
        l_i = cnt + x[i]
        if cnt ==0:
            comb =1
        else:
            comb = 1
            for d in range(cnt):
                term = l_i - d
                comb *= term
                if comb > hel:
                    print("hel")
                    return
            # Compute factorial(cnt)
            fact = 1
            for d in range(1, cnt+1):
                fact *=d
            comb //= fact
            if comb > hel:
                print("hel")
                return
        product *= comb
        if product > hel:
            print("hel")
            return
    
    print(product)

if __name__ == "__main__":
    # Script entry point: solve the single test case supplied on stdin.
    main()
0