結果
| 問題 |
No.263 Common Palindromes Extra
|
| コンテスト | |
| ユーザー |
qwewe
|
| 提出日時 | 2025-05-14 13:20:08 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
MLE
|
| 実行時間 | - |
| コード長 | 5,836 bytes |
| コンパイル時間 | 256 ms |
| コンパイル使用メモリ | 82,180 KB |
| 実行使用メモリ | 283,176 KB |
| 最終ジャッジ日時 | 2025-05-14 13:21:12 |
| 合計ジャッジ時間 | 6,427 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 5 MLE * 3 -- * 4 |
ソースコード
import sys
import array
import gc
def solve():
S_str = sys.stdin.readline().strip()
T_str = sys.stdin.readline().strip()
class PalindromicTree:
def __init__(self, n_max_len, alphabet_size=26, char_offset=ord('A')):
# n_max_len is the maximum length of the string this tree will process
self.alphabet_size = alphabet_size
self.char_offset = char_offset
# s[0] is a dummy char, string is s[1...s_len]
self.s = array.array('b', [-1] * (n_max_len + 1))
self.s_len = 0
# Max nodes: n_max_len (distinct palindromes) + 2 roots
# Array sizes should accommodate this.
num_potential_nodes = n_max_len + 2
self.len = array.array('i', [0] * num_potential_nodes)
self.link = array.array('i', [0] * num_potential_nodes)
self.next_transitions = array.array('i', [0] * (num_potential_nodes * alphabet_size))
self.count = array.array('l', [0] * num_potential_nodes)
self.node_hash1 = array.array('L', [0] * num_potential_nodes)
self.node_hash2 = array.array('L', [0] * num_potential_nodes)
self.nodes_count = 0
self.last = 0
self.H_BASE1, self.H_MOD1 = 31, 10**9 + 7
self.H_BASE2, self.H_MOD2 = 37, 10**9 + 9
self.h_powers1 = array.array('L', [1] * (n_max_len + 1))
self.h_powers2 = array.array('L', [1] * (n_max_len + 1))
for i in range(1, n_max_len + 1):
self.h_powers1[i] = (self.h_powers1[i-1] * self.H_BASE1) % self.H_MOD1
self.h_powers2[i] = (self.h_powers2[i-1] * self.H_BASE2) % self.H_MOD2
self.nodes_count = 1
self.len[0] = -1
self.link[0] = 0
self.nodes_count = 2
self.len[1] = 0
self.link[1] = 0
self.last = 1
def _get_next(self, node_idx, char_code):
return self.next_transitions[node_idx * self.alphabet_size + char_code]
def _set_next(self, node_idx, char_code, next_node_idx):
self.next_transitions[node_idx * self.alphabet_size + char_code] = next_node_idx
def _get_char_code(self, char_str_val):
return ord(char_str_val) - self.char_offset
def add_char(self, char_val_str):
c = self._get_char_code(char_val_str)
self.s_len += 1
self.s[self.s_len] = c
cur = self.last
while self.s[self.s_len - self.len[cur] - 1] != c:
cur = self.link[cur]
if self._get_next(cur, c) == 0:
v_new = self.nodes_count
self.nodes_count += 1
self.len[v_new] = self.len[cur] + 2
val_c = c + 1
if self.len[v_new] == 1:
self.node_hash1[v_new] = val_c
self.node_hash2[v_new] = val_c
else:
len_p_cur = self.len[cur]
h1_cur = self.node_hash1[cur]
h2_cur = self.node_hash2[cur]
self.node_hash1[v_new] = val_c
self.node_hash1[v_new] = (self.node_hash1[v_new] + h1_cur * self.H_BASE1) % self.H_MOD1
self.node_hash1[v_new] = (self.node_hash1[v_new] + val_c * self.h_powers1[len_p_cur + 1]) % self.H_MOD1
self.node_hash2[v_new] = val_c
self.node_hash2[v_new] = (self.node_hash2[v_new] + h2_cur * self.H_BASE2) % self.H_MOD2
self.node_hash2[v_new] = (self.node_hash2[v_new] + val_c * self.h_powers2[len_p_cur + 1]) % self.H_MOD2
link_cand = self.link[cur]
while self.s[self.s_len - self.len[link_cand] - 1] != c:
link_cand = self.link[link_cand]
if self.len[v_new] == 1:
self.link[v_new] = 1
else:
self.link[v_new] = self._get_next(link_cand, c)
self._set_next(cur, c, v_new)
self.last = self._get_next(cur, c)
self.count[self.last] += 1
def finalize_counts(self):
for i in range(self.nodes_count - 1, 1, -1):
self.count[self.link[i]] += self.count[i]
def get_palindromes_list(self):
pal_list = []
for i in range(2, self.nodes_count):
if self.count[i] > 0:
h_pair = (self.node_hash1[i], self.node_hash2[i])
pal_list.append( (h_pair, self.count[i]) )
return pal_list
pt_S = PalindromicTree(len(S_str))
for char_s in S_str:
pt_S.add_char(char_s)
pt_S.finalize_counts()
list_S = pt_S.get_palindromes_list()
del pt_S
gc.collect()
pt_T = PalindromicTree(len(T_str))
for char_t in T_str:
pt_T.add_char(char_t)
pt_T.finalize_counts()
list_T = pt_T.get_palindromes_list()
del pt_T
gc.collect()
list_S.sort()
list_T.sort()
total_ans = 0
ptr_S, ptr_T = 0, 0
len_list_S, len_list_T = len(list_S), len(list_T)
while ptr_S < len_list_S and ptr_T < len_list_T:
hash_S_pair, count_S_val = list_S[ptr_S]
hash_T_pair, count_T_val = list_T[ptr_T]
if hash_S_pair == hash_T_pair:
total_ans += count_S_val * count_T_val
ptr_S += 1
ptr_T += 1
elif hash_S_pair < hash_T_pair:
ptr_S += 1
else:
ptr_T += 1
sys.stdout.write(str(total_ans) + "\n")
solve()
qwewe