結果
| 問題 |
No.295 hel__world
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-15 22:49:03 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,968 bytes |
| コンパイル時間 | 144 ms |
| コンパイル使用メモリ | 81,676 KB |
| 実行使用メモリ | 293,112 KB |
| 最終ジャッジ日時 | 2025-04-15 22:51:13 |
| 合計ジャッジ時間 | 6,470 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 46 WA * 7 |
ソースコード
import sys
import math
def main():
S_alpha = list(map(int, sys.stdin.readline().split()))
T = sys.stdin.readline().strip()
# Process T into T_compressed and groups
if not T:
print(0)
return
T_compressed = []
prev = T[0]
T_compressed.append(prev)
for c in T[1:]:
if c != prev:
T_compressed.append(c)
prev = c
groups = []
prev = T[0]
count = 1
for c in T[1:]:
if c == prev:
count +=1
else:
groups.append((prev, count))
prev = c
count = 1
groups.append((prev, count))
# Check if groups match T_compressed
if len(groups) != len(T_compressed):
print(0)
return
for i in range(len(groups)):
if groups[i][0] != T_compressed[i]:
print(0)
return
# Check for each character c, sum of count_i <= S_alpha[c]
from collections import defaultdict
sum_counts = defaultdict(int)
for c, cnt in groups:
sum_counts[c] += cnt
for c in sum_counts:
idx = ord(c) - ord('a')
if sum_counts[c] > S_alpha[idx]:
print(0)
return
# Prepare runs for each character
runs = defaultdict(list)
for i, (c, cnt) in enumerate(groups):
runs[c].append( (i, cnt) ) # (position in groups, count_i)
# For each character c, compute R_c and distribute x_i
x = [0] * len(groups)
for c in runs:
idx_list = runs[c]
sum_count = sum(cnt for pos, cnt in idx_list)
c_idx = ord(c) - ord('a')
R_c = S_alpha[c_idx] - sum_count
if R_c <0:
print(0)
return
# Distribute R_c
sum_cnt = sum(cnt for pos, cnt in idx_list)
if sum_cnt ==0:
continue
sorted_runs = sorted(idx_list, key=lambda x: -x[1])
cnt_list = [cnt for pos, cnt in sorted_runs]
positions = [pos for pos, cnt in sorted_runs]
total = sum_cnt
# Allocate x_i = (cnt_i * R_c) // total
x_list = []
for cnt_i in cnt_list:
xi = (cnt_i * R_c) // total
x_list.append(xi)
remaining = R_c - sum(x_list)
# Distribute remaining
for i in range(len(x_list)):
if remaining <=0:
break
x_list[i] +=1
remaining -=1
# Update x for each position
for i in range(len(sorted_runs)):
pos = positions[i]
x[pos] = x_list[i]
# Compute l_i = count_i + x_i for each group
product = 1
hel = 2**62
log_product = 0.0
hel_log = math.log(hel)
for i in range(len(groups)):
c, cnt = groups[i]
l_i = cnt + x[i]
if cnt ==0:
continue
if l_i < cnt:
print(0)
return
# Compute log(comb(l_i, cnt))
log_comb = math.lgamma(l_i +1) - math.lgamma(l_i - cnt +1) - math.lgamma(cnt +1)
log_product += log_comb
if log_product > hel_log + 1e-12:
print("hel")
return
# Now compute the actual product if needed
product = 1
for i in range(len(groups)):
c, cnt = groups[i]
l_i = cnt + x[i]
if cnt ==0:
comb =1
else:
comb = 1
for d in range(cnt):
term = l_i - d
comb *= term
if comb > hel:
print("hel")
return
# Compute factorial(cnt)
fact = 1
for d in range(1, cnt+1):
fact *=d
comb //= fact
if comb > hel:
print("hel")
return
product *= comb
if product > hel:
print("hel")
return
print(product)
if __name__ == "__main__":
main()
lam6er