結果

問題 No.263 Common Palindromes Extra
ユーザー gew1fw
提出日時 2025-06-12 19:12:36
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 3,740 bytes
コンパイル時間 258 ms
コンパイル使用メモリ 82,324 KB
実行使用メモリ 499,076 KB
最終ジャッジ日時 2025-06-12 19:12:48
合計ジャッジ時間 9,690 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 2 MLE * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

class EertreeNode:
    __slots__ = ['length', 'suffix_link', 'transitions', 'count', 'hash1', 'hash2']
    def __init__(self, length, suffix_link, hash1, hash2):
        self.length = length
        self.suffix_link = suffix_link
        self.transitions = dict()
        self.count = 0
        self.hash1 = hash1
        self.hash2 = hash2

def main():
    S = sys.stdin.readline().strip()
    T = sys.stdin.readline().strip()

    # Precompute powers for two different bases and mods
    base1 = 911382629
    mod1 = 10**18 + 3
    base2 = 3571428571
    mod2 = 10**18 + 7
    max_len = max(len(S), len(T)) + 2
    pow1 = [1] * (max_len + 1)
    pow2 = [1] * (max_len + 1)
    for i in range(1, max_len + 1):
        pow1[i] = (pow1[i-1] * base1) % mod1
        pow2[i] = (pow2[i-1] * base2) % mod2

    def build_eertree(s, pow1, pow2, base1, mod1, base2, mod2):
        root1 = EertreeNode(-1, None, 0, 0)
        root1.suffix_link = root1
        root2 = EertreeNode(0, root1, 0, 0)
        root2.suffix_link = root1
        tree = [root1, root2]
        last = root2

        for i in range(len(s)):
            c = s[i]
            current = last
            while True:
                edge_len = current.length
                check_pos = i - edge_len - 1
                if check_pos >= 0 and s[check_pos] == c:
                    break
                current = current.suffix_link

            if c in current.transitions:
                last = current.transitions[c]
                last.count += 1
            else:
                new_length = current.length + 2
                if current.length == -1:
                    h1 = ord(c) % mod1
                    h2 = ord(c) % mod2
                else:
                    h1 = (ord(c) * pow1[current.length + 1] + current.hash1 * base1 + ord(c)) % mod1
                    h2 = (ord(c) * pow2[current.length + 1] + current.hash2 * base2 + ord(c)) % mod2
                new_node = EertreeNode(new_length, None, h1, h2)
                if new_length == 1:
                    new_node.suffix_link = root2
                else:
                    suffix_current = current.suffix_link
                    while True:
                        edge_len_suffix = suffix_current.length
                        check_pos_suffix = i - edge_len_suffix - 1
                        if check_pos_suffix >= 0 and s[check_pos_suffix] == c:
                            break
                        suffix_current = suffix_current.suffix_link
                    if c in suffix_current.transitions:
                        new_node.suffix_link = suffix_current.transitions[c]
                    else:
                        new_node.suffix_link = root2
                current.transitions[c] = new_node
                tree.append(new_node)
                last = new_node
                last.count = 1

        # Propagate counts
        nodes = sorted(tree, key=lambda x: -x.length)
        for node in nodes:
            if node.suffix_link is not None and node.suffix_link != node:
                node.suffix_link.count += node.count

        # Collect hash counts, excluding roots (length <=0)
        hash_counts = defaultdict(int)
        for node in tree:
            if node.length > 0:
                key = (node.hash1, node.hash2)
                hash_counts[key] += node.count
        return hash_counts

    s_counts = build_eertree(S, pow1, pow2, base1, mod1, base2, mod2)
    t_counts = build_eertree(T, pow1, pow2, base1, mod1, base2, mod2)

    result = 0
    for key in s_counts:
        if key in t_counts:
            result += s_counts[key] * t_counts[key]
    print(result)

if __name__ == "__main__":
    main()
0