結果
問題 | No.263 Common Palindromes Extra |
ユーザー |
![]() |
提出日時 | 2025-05-14 13:20:08 |
言語 | PyPy3 (7.3.15) |
結果 | MLE |
実行時間 | - |
コード長 | 5,836 bytes |
コンパイル時間 | 256 ms |
コンパイル使用メモリ | 82,180 KB |
実行使用メモリ | 283,176 KB |
最終ジャッジ日時 | 2025-05-14 13:21:12 |
合計ジャッジ時間 | 6,427 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 5 MLE * 3 -- * 4 |
ソースコード
import sys
import array
import gc


def _int_array(size):
    """Return a zero-filled ``array('i')`` with *size* elements.

    Sequence repetition avoids materialising a temporary Python list of
    *size* ints first (``array('i', [0] * size)`` builds such a list, which
    is larger than the resulting array and inflates peak memory).
    """
    return array.array('i', [0]) * size


def count_common_palindromes(s, t):
    """Return the number of equal palindromic substring pairs of *s* and *t*.

    Counts quadruples (i, j, k, l) with ``s[i:j] == t[k:l]`` where that
    substring is a non-empty palindrome, i.e. the sum over every distinct
    palindrome P of occ_s(P) * occ_t(P).

    A single palindromic tree (eertree) is built over both strings, so equal
    palindromes share one node. No hashing (hence no collision risk) and no
    sorting is needed, and memory is linear in ``len(s) + len(t)``
    independent of the alphabet. Children are kept as singly linked sibling
    lists instead of a dense alphabet-wide transition table.
    """
    # Each appended character creates at most one new node; plus two roots
    # (0: imaginary root of length -1, 1: empty palindrome of length 0).
    n_nodes = len(s) + len(t) + 2
    length = _int_array(n_nodes)        # palindrome length of each node
    link = _int_array(n_nodes)          # longest proper palindromic suffix
    first_child = _int_array(n_nodes)   # head of the child list (0 = none)
    next_sibling = _int_array(n_nodes)  # next node in the parent's child list
    edge_char = _int_array(n_nodes)     # code point on the edge into the node
    count_s = _int_array(n_nodes)       # occurrences in s
    count_t = _int_array(n_nodes)       # occurrences in t
    length[0] = -1

    # Text buffer layout: sentinel, s, separator, t.  The value -1 never
    # equals a code point, so no palindrome can extend across the boundary.
    text = _int_array(len(s) + len(t) + 2)
    pos = -1
    n_used = 2

    for source, counts in ((s, count_s), (t, count_t)):
        pos += 1
        text[pos] = -1
        last = 1  # restart from the empty root for each string
        for ch in source:
            c = ord(ch)
            pos += 1
            text[pos] = c
            # Find the longest palindromic suffix X such that c X c ends here.
            cur = last
            while text[pos - length[cur] - 1] != c:
                cur = link[cur]
            node = first_child[cur]
            while node and edge_char[node] != c:
                node = next_sibling[node]
            if not node:
                node = n_used
                n_used += 1
                length[node] = length[cur] + 2
                if length[node] == 1:
                    link[node] = 1
                else:
                    cand = link[cur]
                    while text[pos - length[cand] - 1] != c:
                        cand = link[cand]
                    # c cand c is a shorter palindromic suffix that already
                    # occurred earlier, so this edge is guaranteed to exist.
                    target = first_child[cand]
                    while edge_char[target] != c:
                        target = next_sibling[target]
                    link[node] = target
                # Attach the new node to cur's child list.
                edge_char[node] = c
                next_sibling[node] = first_child[cur]
                first_child[cur] = node
            counts[node] += 1
            last = node

    # Suffix links always point to lower indices, so a reverse sweep sees
    # each node's final count before pushing it to its suffix-link parent.
    total = 0
    for v in range(n_used - 1, 1, -1):
        cs = count_s[v]
        ct = count_t[v]
        total += cs * ct
        parent = link[v]
        count_s[parent] += cs
        count_t[parent] += ct
    return total


def solve():
    """Read S and T (one per line) from stdin and print the pair count."""
    s = sys.stdin.readline().strip()
    t = sys.stdin.readline().strip()
    sys.stdout.write(str(count_common_palindromes(s, t)) + "\n")


if __name__ == "__main__":
    solve()