結果
問題 |
No.295 hel__world
|
ユーザー |
![]() |
提出日時 | 2025-06-12 13:08:22 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 4,153 bytes |
コンパイル時間 | 180 ms |
コンパイル使用メモリ | 82,596 KB |
実行使用メモリ | 339,104 KB |
最終ジャッジ日時 | 2025-06-12 13:12:15 |
合計ジャッジ時間 | 11,614 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 44 WA * 9 |
ソースコード
import sys import math from collections import defaultdict def main(): # Read S_alpha S_alpha = list(map(int, sys.stdin.readline().split())) T = sys.stdin.readline().strip() # Precompute factorial and log_fact up to 1e6 max_r = 10**6 factorial = [1] * (max_r + 1) log_fact = [0.0] * (max_r + 1) for i in range(1, max_r + 1): factorial[i] = factorial[i-1] * i log_fact[i] = log_fact[i-1] + math.log(i) if factorial[i] > (1 << 62): factorial[i] = (1 << 62) + 1 # To prevent overflow in exact computation # Process T into T_comp and runs if not T: print(0) return T_comp = [] runs = [] prev_char = T[0] count = 1 for c in T[1:]: if c == prev_char: count += 1 else: T_comp.append(prev_char) runs.append(count) prev_char = c count = 1 T_comp.append(prev_char) runs.append(count) m = len(T_comp) # Check if each character in T_comp has at least 1 in S_alpha for c in T_comp: idx = ord(c) - ord('a') if S_alpha[idx] < 1: print(0) return # For each character in T_comp, collect their groups and run_lengths groups = list(zip(T_comp, runs)) char_groups = defaultdict(list) for i in range(m): c, r = groups[i] char_groups[c].append((i, r)) # Check if sum of required run_lengths for each character is <= S_alpha for c in char_groups: indices = [i for i, r in char_groups[c]] sum_r = sum(r for i, r in char_groups[c]) idx = ord(c) - ord('a') if sum_r > S_alpha[idx]: print(0) return # Now, allocate the remaining characters for each character new_x = [] for i in range(m): new_x.append(groups[i][1]) # start with run_length for c in char_groups: idx_c = ord(c) - ord('a') group_info = char_groups[c] sum_r = sum(r for i, r in group_info) available = S_alpha[idx_c] rem = available - sum_r if rem < 0: print(0) return if rem == 0: continue # Extract the run_lengths and their indices run_lengths = [r for i, r in group_info] total_r = sum_r num_groups = len(run_lengths) # Compute k_i for each group k_list = [] for r in run_lengths: k = (rem * r) // total_r k_list.append(k) rem_remaining = rem - sum(k_list) # Sort the groups by run_length descending, then index ascending sorted_group_info = sorted(group_info, key=lambda x: (-x[1], x[0])) sorted_indices = [i for i, r in sorted_group_info] # Distribute rem_remaining for i in range(rem_remaining): k_list[i] += 1 # Update new_x for each group for i in range(num_groups): original_index = sorted_group_info[i][0] r = sorted_group_info[i][1] new_x[original_index] = r + k_list[i] # Now compute the product log_sum = 0.0 for i in range(m): x = new_x[i] r = groups[i][1] if x < r: print(0) return # Compute log(C(x, r)) current_log = 0.0 for k in range(r): current_log += math.log(x - k) current_log -= log_fact[r] log_sum += current_log if log_sum > math.log(1 << 62): print("hel") return # Now compute the product exactly product = 1 for i in range(m): x = new_x[i] r = groups[i][1] if x < r: print(0) return # Compute C(x, r) numerator = 1 for k in range(r): numerator *= (x - k) if numerator > (1 << 62): print("hel") return denominator = factorial[r] term = numerator // denominator product *= term if product > (1 << 62): print("hel") return print(product) if __name__ == "__main__": main()