結果
問題 |
No.263 Common Palindromes Extra
|
ユーザー |
![]() |
提出日時 | 2025-05-14 12:51:35 |
言語 | PyPy3 (7.3.15) |
結果 |
MLE
|
実行時間 | - |
コード長 | 3,677 bytes |
コンパイル時間 | 387 ms |
コンパイル使用メモリ | 81,868 KB |
実行使用メモリ | 828,700 KB |
最終ジャッジ日時 | 2025-05-14 12:52:03 |
合計ジャッジ時間 | 8,498 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 8 TLE * 1 MLE * 3 |
ソースコード
import sys from collections import defaultdict MOD1 = 10**9 + 7 BASE1 = 911382629 MOD2 = 10**9 + 3 BASE2 = 3571428571 def precompute_powers(max_len, base, mod): pow_base = [1] * (max_len + 2) for i in range(1, max_len + 2): pow_base[i] = (pow_base[i-1] * base) % mod return pow_base pow_base1 = precompute_powers(500002, BASE1, MOD1) pow_base2 = precompute_powers(500002, BASE2, MOD2) class PalindromicTree: class Node: def __init__(self): self.next = dict() self.link = 0 self.len = 0 self.hash1 = 0 self.hash2 = 0 self.cnt = 0 def __init__(self): self.nodes = [self.Node(), self.Node()] self.nodes[0].len = -1 self.nodes[0].link = 0 self.nodes[1].len = 0 self.nodes[1].link = 0 self.last = 1 self.size = 2 def get_link(self, u, s, idx): while True: len_node = self.nodes[u].len if idx - len_node - 1 >= 0 and s[idx] == s[idx - len_node - 1]: break u = self.nodes[u].link return u def add_char(self, s, idx): c = s[idx] u = self.get_link(self.last, s, idx) if c in self.nodes[u].next: self.last = self.nodes[u].next[c] self.nodes[self.last].cnt += 1 return new_node = self.Node() new_node.len = self.nodes[u].len + 2 new_node.cnt = 1 if self.nodes[u].len == -1: h1 = (ord(c) - ord('A') + 1) % MOD1 h2 = (ord(c) - ord('A') + 1) % MOD2 else: p_h1 = self.nodes[u].hash1 p_h2 = self.nodes[u].hash2 p_len = self.nodes[u].len part1_1 = ( (ord(c) - ord('A') + 1) * pow_base1[p_len + 1] ) % MOD1 part2_1 = (p_h1 * BASE1) % MOD1 part3_1 = (ord(c) - ord('A') + 1) % MOD1 h1 = (part1_1 + part2_1 + part3_1) % MOD1 part1_2 = ( (ord(c) - ord('A') + 1) * pow_base2[p_len + 1] ) % MOD2 part2_2 = (p_h2 * BASE2) % MOD2 part3_2 = (ord(c) - ord('A') + 1) % MOD2 h2 = (part1_2 + part2_2 + part3_2) % MOD2 new_node.hash1 = h1 new_node.hash2 = h2 link_u = self.nodes[u].link link_u = self.get_link(link_u, s, idx) if c in self.nodes[link_u].next: new_node.link = self.nodes[link_u].next[c] else: new_node.link = 1 self.nodes.append(new_node) self.nodes[u].next[c] = self.size self.last = self.size self.size += 1 def build(self, s): for i in range(len(s)): self.add_char(s, i) order = sorted(range(2, self.size), key=lambda x: -self.nodes[x].len) for u in order: link = self.nodes[u].link if link >= 0 and link < self.size: self.nodes[link].cnt += self.nodes[u].cnt def main(): S = sys.stdin.readline().strip() T = sys.stdin.readline().strip() pt_S = PalindromicTree() pt_S.build(S) count_S = defaultdict(int) for i in range(2, pt_S.size): node = pt_S.nodes[i] if node.len > 0: h = (node.hash1, node.hash2) count_S[h] += node.cnt pt_T = PalindromicTree() pt_T.build(T) count_T = defaultdict(int) for i in range(2, pt_T.size): node = pt_T.nodes[i] if node.len > 0: h = (node.hash1, node.hash2) count_T[h] += node.cnt result = 0 for h in count_S: if h in count_T: result += count_S[h] * count_T[h] print(result) if __name__ == "__main__": main()