結果

問題 No.263 Common Palindromes Extra
ユーザー gew1fw
提出日時 2025-06-12 20:45:27
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 2,946 bytes
コンパイル時間 297 ms
コンパイル使用メモリ 82,276 KB
実行使用メモリ 539,204 KB
最終ジャッジ日時 2025-06-12 20:45:35
合計ジャッジ時間 6,907 ms
ジャッジサーバーID (参考情報) judge4 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other AC * 8 MLE * 4
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

class EertreeNode:
    """A node of a palindromic tree (eertree).

    A tree holds one node per distinct palindrome plus two roots, so the
    class declares ``__slots__`` to avoid a per-instance ``__dict__``.
    """

    __slots__ = ("length", "edges", "suffix_link", "hash_val", "count")

    def __init__(self):
        # Palindrome length; the builder gives the two roots -1 (imaginary
        # root) and 0 (empty palindrome).
        self.length = 0
        # Outgoing edges: character -> child node for the palindrome c + self + c.
        self.edges = {}
        # Node of the longest proper palindromic suffix; filled in by the builder.
        self.suffix_link = None
        # Polynomial hash of the palindrome represented by this node.
        self.hash_val = 0
        # Occurrence counter accumulated by the builder.
        self.count = 0

def compute_hash(c_val, parent, power, P, mod):
    """Return the hash of the palindrome formed by wrapping ``parent`` in ``c``.

    The new palindrome is ``c + parent + c``.  Its hash is obtained from the
    parent's hash by placing ``c_val`` at weight ``power[len(parent) + 1]``
    (leading copy) and weight 1 (trailing copy), and shifting the parent's
    hash up by ``power[1]``.  When ``parent`` is the imaginary root (length
    -1) the result is the single character ``c`` and its hash is ``c_val``.
    ``P`` is accepted but not read; ``power[1]`` plays its role.
    """
    plen = parent.length
    if plen != -1:
        outer = c_val * (power[plen + 1] + 1)
        inner = parent.hash_val * power[1]
        return (outer % mod + inner % mod) % mod
    # Child of the imaginary root: a one-character palindrome.
    return c_val % mod

def build_eertree(s, power, P, mod):
    """Count occurrences of every distinct palindromic substring of ``s``.

    Builds a palindromic tree (eertree) over ``s`` and returns a dict that
    maps each palindrome's hash to the number of times that palindrome occurs
    as a substring of ``s``.  The hash of a palindrome ``t`` is
    ``sum(v(t[k]) * P**(len(t) - 1 - k)) % mod`` with
    ``v(ch) = ord(ch) - ord('A') + 1``.  This is the same value
    ``compute_hash`` produces incrementally, so maps built for two strings
    can be matched key by key.

    The tree is stored in flat parallel lists indexed by node id, with all
    child edges in one shared dict, instead of one object and one edge dict
    per node.  This keeps per-node memory low on long inputs.

    Args:
        s: Input string.  Any characters work as tree edges; the hash
            encoding ``v`` is intended for uppercase ASCII.
        power: Table with ``power[k] == P**k % mod``; must have at least
            ``len(s)`` entries.
        P: Hash base (kept for interface compatibility; ``power[1]`` is used).
        mod: Hash modulus.

    Returns:
        dict mapping palindrome hash -> occurrence count, in node-creation
        order (empty when ``s`` is empty).
    """
    # Node 0: imaginary root (length -1).  Node 1: empty palindrome (length 0).
    # Both suffix links point at node 0.
    length = [-1, 0]
    link = [0, 0]
    hash_of = [0, 0]
    count = [0, 0]
    # All child edges in one dict keyed by node * char_span + ord(ch).
    # char_span exceeds every Unicode code point, so keys never collide.
    char_span = 0x110000
    child = {}
    last = 1
    base = ord('A') - 1

    for i, c in enumerate(s):
        code = ord(c)
        c_val = code - base

        # Walk suffix links from the last palindrome until one can be
        # extended by c on both sides.  Node 0 always matches (j == i).
        cur = last
        while True:
            j = i - length[cur] - 1
            if j >= 0 and s[j] == c:
                break
            cur = link[cur]

        key = cur * char_span + code
        existing = child.get(key)
        if existing is not None:
            count[existing] += 1
            last = existing
            continue

        parent_len = length[cur]
        if parent_len == -1:
            # Single-character palindrome; its longest proper palindromic
            # suffix is the empty palindrome.
            h = c_val % mod
            suffix = 1
        else:
            h = (c_val * (power[parent_len + 1] + 1) + hash_of[cur] * power[1]) % mod
            # Longest proper palindromic suffix: extend a strictly shorter
            # suffix palindrome of the parent by c.  That node already exists.
            t = link[cur]
            while True:
                j = i - length[t] - 1
                if j >= 0 and s[j] == c:
                    break
                t = link[t]
            suffix = child.get(t * char_span + code, 1)

        node = len(length)
        length.append(parent_len + 2)
        link.append(suffix)
        hash_of.append(h)
        count.append(1)
        child[key] = node
        last = node

    # A suffix link always targets a node created earlier, so visiting nodes
    # newest-first pushes each node's final count to its suffix before the
    # suffix itself is visited.  No sort by length is needed.
    for v in range(len(length) - 1, 1, -1):
        count[link[v]] += count[v]

    counts_by_hash = {}
    for v in range(2, len(length)):
        h = hash_of[v]
        counts_by_hash[h] = counts_by_hash.get(h, 0) + count[v]
    return counts_by_hash

def main():
    """Read S and T from stdin and print the number of matching palindrome pairs.

    For every palindrome (identified by its hash) occurring in both strings,
    the answer gains ``occurrences_in_S * occurrences_in_T``.
    """
    S = sys.stdin.readline().strip()
    T = sys.stdin.readline().strip()

    P = 911382629
    mod = 10**18 + 3

    # build_eertree reads power[k] for k < len(string).  Size the table to the
    # actual input instead of a fixed 500000, so longer inputs cannot index
    # past the end and short inputs skip needless work.
    max_len = max(len(S), len(T))
    power = [1] * (max_len + 2)
    for i in range(1, max_len + 2):
        power[i] = (power[i - 1] * P) % mod

    hash_S = build_eertree(S, power, P, mod)
    hash_T = build_eertree(T, power, P, mod)

    # Iterate the smaller map and probe the larger one.
    small, large = (hash_S, hash_T) if len(hash_S) <= len(hash_T) else (hash_T, hash_S)
    result = sum(cnt * large.get(key, 0) for key, cnt in small.items())

    print(result)


if __name__ == "__main__":
    main()
0