結果

問題 No.263 Common Palindromes Extra
ユーザー gew1fw
提出日時 2025-06-12 20:42:44
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 2,634 bytes
コンパイル時間 166 ms
コンパイル使用メモリ 82,280 KB
実行使用メモリ 754,720 KB
最終ジャッジ日時 2025-06-12 20:42:59
合計ジャッジ時間 9,794 ms
ジャッジサーバーID (参考情報) judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 2 MLE * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

# Polynomial-hash parameters shared by every eertree built below.
MOD = 10**18 + 3   # modulus for palindrome hashes
P = 911382629      # hash base

# powers[k] == P**k % MOD for every 0 <= k <= max_n.
max_n = 10**6 + 10
powers = [1]
for i in range(1, max_n + 1):
    powers.append(powers[-1] * P % MOD)

class EertreeNode:
    """One node of a palindromic tree (eertree).

    Attributes:
        len: length of the node's palindrome; ``build_eertree`` uses -1 for
            the imaginary root and 0 for the empty-string root.
        link: suffix-link target node (``None`` until assigned).
        transitions: maps a character ``c`` to the child node whose
            palindrome is ``c + <this palindrome> + c``.
        count: occurrence counter maintained by ``build_eertree``.
        hash: polynomial hash of the palindrome.
    """

    # One node is created per distinct palindrome of every input string, so
    # dropping the per-instance __dict__ noticeably lowers memory per node
    # (most visibly on CPython).
    __slots__ = ("len", "link", "transitions", "count", "hash")

    def __init__(self):
        self.len = 0
        self.link = None
        self.transitions = {}  # fresh dict per node; never shared
        self.count = 0
        self.hash = 0

def build_eertree(s):
    """Build the palindromic tree (eertree) of ``s`` and return its nodes.

    ``tree[0]`` is the imaginary root (``len == -1``, suffix link to itself)
    and ``tree[1]`` is the empty-string root (``len == 0``, suffix link to
    ``tree[0]``).  Every later entry is one distinct non-empty palindromic
    substring of ``s``, in the order it was first seen, carrying:

    * ``count`` -- number of occurrences of that palindrome in ``s`` (once
      counts have been pushed along suffix links at the end);
    * ``hash``  -- ``sum(ord(ch) * P**(len-1-k)) % MOD`` over its characters,
      computed with the module-level ``powers`` table.

    The roots' ``count`` fields may be incremented by the propagation step
    and carry no meaning.

    Args:
        s: the string to index (any length; the ``powers`` table is extended
           in place if it is too short).

    Returns:
        list[EertreeNode]: all nodes, roots first.
    """
    pw = powers
    # The largest exponent used below is len(s) - 1, so the table needs at
    # least len(s) entries; extend the shared module-level list if required.
    while len(pw) < len(s):
        pw.append(pw[-1] * P % MOD)

    root1 = EertreeNode()  # imaginary root of length -1
    root1.len = -1
    root1.link = root1
    root1.count = 0
    root1.hash = 0

    root2 = EertreeNode()  # empty-string root
    root2.len = 0
    root2.link = root1
    root2.count = 0
    root2.hash = 0

    tree = [root1, root2]
    last = root2  # longest palindromic suffix of the prefix processed so far

    for i, c in enumerate(s):
        code = ord(c)

        # Follow suffix links until the palindrome X can be wrapped as c+X+c,
        # i.e. the character just before X equals c.  root1 (len -1) always
        # matches, so the loop terminates.
        current = last
        while True:
            j = i - current.len - 1
            if j >= 0 and s[j] == c:
                break
            current = current.link

        existing = current.transitions.get(c)
        if existing is not None:
            # Palindrome already in the tree: record one more occurrence.
            existing.count += 1
            last = existing
            continue

        node = EertreeNode()
        node.len = current.len + 2
        node.count = 1

        if current.len == -1:
            node.hash = code
        else:
            # hash(c + X + c) = c * P**(|X| + 1) + hash(X) * P + c  (mod MOD)
            node.hash = (code * pw[current.len + 1] + current.hash * P + code) % MOD

        if node.len == 1:
            node.link = root2
        else:
            # Longest proper palindromic suffix: wrap the next-shorter
            # palindromic suffix of X that can also be extended by c.
            tmp = current.link
            while True:
                j = i - tmp.len - 1
                if j >= 0 and s[j] == c:
                    break
                tmp = tmp.link
            node.link = tmp.transitions.get(c, root2)

        # Attach only after the suffix link is resolved.
        current.transitions[c] = node
        tree.append(node)
        last = node

    # Push occurrence counts down the suffix links.  A node's suffix link is
    # always an already-existing node when the node is created, so visiting
    # nodes in reverse creation order handles every node before its link:
    # a valid topological order that needs neither a sort nor a copied list.
    # Indices 0 and 1 are the roots and are not propagated.
    for idx in range(len(tree) - 1, 1, -1):
        node = tree[idx]
        node.link.count += node.count

    return tree

def create_dict(tree):
    """Return a ``{hash: count}`` mapping for every non-root node of ``tree``.

    ``tree[0]`` and ``tree[1]`` are taken as the two roots; any node comparing
    equal to either is skipped.  If two nodes share a hash, the one that comes
    later in ``tree`` determines the stored count.
    """
    imaginary_root, empty_root = tree[0], tree[1]
    return {
        node.hash: node.count
        for node in tree
        if node != imaginary_root and node != empty_root
    }

def main():
    """Read S and T from stdin and print the number of common palindrome pairs.

    The printed value is the sum, over palindromes present in both strings
    (matched by their polynomial hash), of occurrences in S times
    occurrences in T.
    """
    import sys

    data = sys.stdin.read().split()
    S = data[0]
    T = data[1]

    # Build, summarize and release one eertree at a time, so at most one
    # full tree of nodes is alive at once; `del` drops the last reference,
    # letting the garbage collector reclaim it before the next build.
    tree = build_eertree(S)
    s_dict = create_dict(tree)
    del tree

    tree = build_eertree(T)
    t_dict = create_dict(tree)
    del tree

    # Iterate the smaller mapping and probe the larger one.
    if len(s_dict) > len(t_dict):
        s_dict, t_dict = t_dict, s_dict
    answer = sum(count * t_dict.get(h, 0) for h, count in s_dict.items())
    print(answer)


if __name__ == "__main__":
    main()
0