結果
| 問題 |
No.603 hel__world (2)
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-16 00:05:03 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,003 bytes |
| コンパイル時間 | 477 ms |
| コンパイル使用メモリ | 81,824 KB |
| 実行使用メモリ | 364,728 KB |
| 最終ジャッジ日時 | 2025-04-16 00:07:05 |
| 合計ジャッジ時間 | 12,173 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 26 WA * 4 |
ソースコード
def comb_mod(n, k, fact, inv_fact, MOD):
if k < 0 or k > n:
return 0
res = 1
while n > 0 or k > 0:
ni = n % MOD
ki = k % MOD
if ki > ni:
return 0
res = res * (fact[ni] * inv_fact[ki] % MOD) * inv_fact[ni - ki] % MOD
n = n // MOD
k = k // MOD
return res
def process_T(T):
if not T:
return "", []
T_comp = []
T_runs = []
current_char = T[0]
count = 1
for c in T[1:]:
if c == current_char:
count += 1
else:
T_comp.append(current_char)
T_runs.append((current_char, count))
current_char = c
count = 1
T_comp.append(current_char)
T_runs.append((current_char, count))
return ''.join(T_comp), T_runs
def main():
MOD = 10**6 + 3
S_alpha = list(map(int, input().split()))
T = input().strip()
T_comp, T_runs = process_T(T)
# Check if any character in T_comp has S_alpha < 1
for c in set(T_comp):
idx = ord(c) - ord('a')
if S_alpha[idx] < 1:
print(0)
return
# Precompute factorial and inverse factorial modulo MOD
fact = [1] * MOD
for i in range(2, MOD):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * MOD
inv_fact[MOD-1] = pow(fact[MOD-1], MOD-2, MOD)
for i in range(MOD-2, 0, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
from collections import defaultdict
char_to_kinfo = defaultdict(list)
for i in range(len(T_comp)):
c = T_comp[i]
k_i = T_runs[i][1]
char_to_kinfo[c].append((i, k_i))
result = 1
for c in char_to_kinfo:
runs_info = char_to_kinfo[c]
k_list = [k for (idx, k) in runs_info]
sum_k = sum(k_list)
S_c = S_alpha[ord(c) - ord('a')]
if sum_k > S_c:
print(0)
return
e = S_c - sum_k
m = len(k_list)
sum_k_total = sum_k
base_e = []
remainder_i_list = []
for k in k_list:
base_e_i = (e * k) // sum_k_total
remainder_i = (e * k) % sum_k_total
base_e.append(base_e_i)
remainder_i_list.append(remainder_i)
sum_base = sum(base_e)
remainder_total = e - sum_base
if remainder_total > 0:
indexed_remainders = [(remainder_i_list[i], i) for i in range(m)]
indexed_remainders.sort(key=lambda x: (-x[0], x[1]))
for i in range(remainder_total):
idx = indexed_remainders[i][1]
base_e[idx] += 1
contribution = 1
for i in range(m):
k = k_list[i]
e_i = base_e[i]
l_i = k + e_i
comb = comb_mod(l_i, k, fact, inv_fact, MOD)
contribution = contribution * comb % MOD
result = result * contribution % MOD
print(result)
if __name__ == "__main__":
main()
lam6er