結果

問題 No.603 hel__world (2)
ユーザー lam6er
提出日時 2025-04-16 00:03:06
言語 PyPy3 (7.3.15)
結果 WA
実行時間 -
コード長 3,003 bytes
コンパイル時間 179 ms
コンパイル使用メモリ 81,928 KB
実行使用メモリ 365,292 KB
最終ジャッジ日時 2025-04-16 00:04:50
合計ジャッジ時間 11,738 ms
ジャッジサーバーID (参考情報) judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 26 WA * 4
権限があれば一括ダウンロードができます

ソースコード

diff #

def comb_mod(n, k, fact, inv_fact, MOD):
    """Return C(n, k) modulo the prime MOD using Lucas' theorem.

    ``n`` and ``k`` are split into base-MOD digits and the result is the
    product of the per-digit binomials.  ``fact[i]`` / ``inv_fact[i]`` must
    hold i! and its modular inverse for every digit value that can occur,
    i.e. every i up to min(n, MOD - 1).  Returns 0 when k is outside [0, n].
    """
    if not 0 <= k <= n:
        return 0
    result = 1
    while n or k:
        n, n_digit = divmod(n, MOD)
        k, k_digit = divmod(k, MOD)
        if k_digit > n_digit:
            # A digit of k exceeding the matching digit of n makes the
            # whole coefficient vanish modulo MOD.
            return 0
        term = fact[n_digit] * inv_fact[k_digit] % MOD * inv_fact[n_digit - k_digit]
        result = result * term % MOD
    return result

def process_T(T):
    """Run-length encode ``T``.

    Returns a pair ``(compressed, runs)``: ``compressed`` is T with every run
    of equal adjacent characters collapsed to one character, and ``runs`` is
    the list of ``(char, run_length)`` tuples in order.  An empty T gives
    ``("", [])``.
    """
    groups = []  # mutable [char, length] pairs while scanning
    for ch in T:
        if groups and groups[-1][0] == ch:
            groups[-1][1] += 1
        else:
            groups.append([ch, 1])
    runs = [(ch, length) for ch, length in groups]
    compressed = ''.join(ch for ch, _ in runs)
    return compressed, runs

def _optimal_run_lengths(k_list, total):
    """Choose lengths l_i >= k_i with sum(l_i) == total maximising prod C(l_i, k_i).

    Args:
        k_list: non-empty list of run lengths k_i (each >= 1) of one
            character in T.
        total: letters available for that character; must be
            >= sum(k_list).

    Returns:
        A list of lengths parallel to ``k_list``.

    Growing run i from length j - 1 to j multiplies the product by
    j / (j - k_i), which shrinks as j grows, so giving out single letters
    greedily (always to the run whose next step has the largest factor,
    i.e. the smallest j / k_i) is optimal.

    The greedy choice of every step with j / k_i <= total / sum(k_list) is
    exactly l_i = floor(k_i * total / sum(k_list)).  It leaves fewer than
    len(k_list) letters over, and these are assigned one at a time from a
    heap.  Ranking the leftovers by the fractional part of
    k_i * total / sum(k_list) alone (largest remainder) ignores the 1 / k_i
    weighting and can pick a worse run.  Example: k=[1, 3], total=6 must
    give [1, 5] (product 10), not [2, 4] (product 8).
    """
    import heapq

    sum_k = sum(k_list)
    lengths = [k * total // sum_k for k in k_list]
    leftover = total - sum(lengths)
    if leftover:
        # Compare the keys j / k exactly using integers. Two different
        # fractions with denominators <= d differ by at least 1 / d**2, so
        # floor(j * d**2 / k) preserves their order. Equal fractions tie, and
        # tied steps multiply the product by the same factor, so the index
        # tie-break does not affect the answer.
        scale = max(k_list) ** 2
        heap = [((l + 1) * scale // k, i)
                for i, (k, l) in enumerate(zip(k_list, lengths))]
        heapq.heapify(heap)
        for _ in range(leftover):
            _, i = heapq.heappop(heap)
            lengths[i] += 1
            # The same run may deserve another letter, so re-queue its next step.
            heapq.heappush(heap, ((lengths[i] + 1) * scale // k_list[i], i))
    return lengths


def main():
    """Read S_a..S_z and T from stdin and print the maximised product mod 10**6 + 3.

    For every character c of T, the lengths k_1..k_m of its runs in T are
    stretched to lengths l_i >= k_i with sum(l_i) == S_c, chosen to maximise
    prod C(l_i, k_i).  The product over all characters is printed modulo
    10**6 + 3.  If some character's runs in T need more letters than S_c
    provides, 0 is printed.
    NOTE(review): the product presumably counts occurrences of T as a
    subsequence, per the problem statement; confirm against the statement.
    """
    from collections import defaultdict

    MOD = 10**6 + 3
    S_alpha = list(map(int, input().split()))
    T = input().strip()

    _, T_runs = process_T(T)

    # Characters are independent: group the run lengths of T by character.
    runs_by_char = defaultdict(list)
    for ch, run_len in T_runs:
        runs_by_char[ch].append(run_len)

    allocations = []
    for ch, k_list in runs_by_char.items():
        S_c = S_alpha[ord(ch) - ord('a')]
        if sum(k_list) > S_c:
            # Too few letters to cover T's own runs of this character
            # (this includes S_c == 0).
            print(0)
            return
        allocations.append((k_list, _optimal_run_lengths(k_list, S_c)))

    # Lucas' theorem only indexes the tables with base-MOD digits of the run
    # lengths, and every l_i <= S_c <= max(S_alpha), so at most
    # min(MOD, max(S_alpha) + 1) entries are ever needed.
    size = min(MOD, max(S_alpha, default=0) + 1)
    fact = [1] * size
    for i in range(2, size):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * size
    inv_fact[size - 1] = pow(fact[size - 1], MOD - 2, MOD)
    for i in range(size - 2, 0, -1):
        inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD

    result = 1
    for k_list, lengths in allocations:
        for k, l in zip(k_list, lengths):
            result = result * comb_mod(l, k, fact, inv_fact, MOD) % MOD

    print(result)


if __name__ == "__main__":
    main()
0