結果

問題 No.263 Common Palindromes Extra
ユーザー qwewe
提出日時 2025-04-24 12:21:53
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 3,232 bytes
コンパイル時間 179 ms
コンパイル使用メモリ 82,528 KB
実行使用メモリ 537,868 KB
最終ジャッジ日時 2025-04-24 12:23:49
合計ジャッジ時間 11,968 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 2 MLE * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

# Moduli and bases of the two polynomial hashes that build_eertree computes
# for every palindrome node; a palindrome is keyed by its (hash1, hash2) pair.
MOD1 = 10**18 + 3
BASE1 = 911382629
MOD2 = 10**18 + 7
BASE2 = 3571428571

def precompute_powers(max_len, base, mod):
    """Return the table of successive powers of ``base`` reduced modulo ``mod``.

    Args:
        max_len: size parameter; the table holds ``max_len + 2`` entries.
        base: the number being raised to successive powers.
        mod: modulus applied after every multiplication.

    Returns:
        A list ``t`` of length ``max_len + 2`` where ``t[0]`` is ``1`` and
        ``t[i] == (t[i - 1] * base) % mod`` for every later index.
    """
    size = max_len + 2
    table = [1] * size
    running = 1
    for exponent in range(1, size):
        # Carry the previous power in a local instead of re-reading the list.
        running = running * base % mod
        table[exponent] = running
    return table

# Tables of BASE**i % MOD for i in 0 .. 10**6 + 3, built once at import time.
# build_eertree indexes them when it wraps an existing palindrome's hash with
# one extra character on each side.
pow1 = precompute_powers(10**6 + 2, BASE1, MOD1)
pow2 = precompute_powers(10**6 + 2, BASE2, MOD2)

class Node:
    """One vertex of the palindromic tree, i.e. one palindrome.

    Attributes are declared through ``__slots__`` so that instances carry no
    per-object ``__dict__``.
    """

    __slots__ = [
        'length',       # length of the palindrome this node stands for
        'suffix_link',  # another Node (or None until it is assigned)
        'transitions',  # maps a character to the child Node reached by it
        'count',        # occurrence counter, starts at 0
        'hash1',        # first hash value, starts at 0
        'hash2',        # second hash value, starts at 0
    ]

    def __init__(self, length, suffix_link):
        """Create a node of the given ``length`` pointing at ``suffix_link``."""
        self.length, self.suffix_link = length, suffix_link
        self.transitions = {}
        self.count = self.hash1 = self.hash2 = 0

def build_eertree(s):
    """Build the palindromic tree of ``s`` and tally its palindromes by hash.

    One node is created per distinct non-empty palindromic substring of
    ``s``.  Each node stores a pair of polynomial hashes of its palindrome
    (most significant character first, character codes from ``ord``) and,
    after the final propagation pass, the number of times the palindrome
    occurs in ``s``.

    Args:
        s: the string to index.

    Returns:
        A ``defaultdict(int)`` mapping ``(hash1, hash2)`` to the summed
        occurrence count of the palindromes carrying that hash pair.

    Note:
        The module-level tables ``pow1`` / ``pow2`` are indexed with
        exponents up to ``len(s) - 1``; a longer input raises ``IndexError``.
    """
    # Two artificial roots: length -1 (parent of single characters) and
    # length 0 (parent of two-character palindromes).  Their hashes are 0.
    root_neg = Node(-1, None)
    root_neg.suffix_link = root_neg
    root_0 = Node(0, root_neg)

    nodes = []  # real nodes in creation order
    last = root_0
    for idx, c in enumerate(s):
        code = ord(c)

        # Walk suffix links until the palindrome can be extended by c on
        # both sides.  root_neg always matches (prev_pos == idx).
        current = last
        while True:
            prev_pos = idx - current.length - 1
            if prev_pos >= 0 and s[prev_pos] == c:
                break
            current = current.suffix_link

        existing = current.transitions.get(c)
        if existing is not None:
            # Palindrome already known: just record one more occurrence as
            # "longest palindromic suffix ending here".
            existing.count += 1
            last = existing
            continue

        new_node = Node(current.length + 2, None)
        if current is root_neg:
            # Single character: its only proper palindromic suffix is empty.
            new_node.hash1 = code % MOD1
            new_node.hash2 = code % MOD2
            new_node.suffix_link = root_0
        else:
            # hash(c + X + c) = c * B^(|X| + 1) + hash(X) * B + c
            exponent = current.length + 1
            new_node.hash1 = (
                code * pow1[exponent] + current.hash1 * BASE1 + code
            ) % MOD1
            new_node.hash2 = (
                code * pow2[exponent] + current.hash2 * BASE2 + code
            ) % MOD2

            # Longest proper palindromic suffix of the new palindrome.  It
            # is also a prefix of the new palindrome, so it ended at an
            # earlier index and its node is guaranteed to exist already.
            candidate = current.suffix_link
            while True:
                prev_pos = idx - candidate.length - 1
                if prev_pos >= 0 and s[prev_pos] == c:
                    break
                candidate = candidate.suffix_link
            new_node.suffix_link = candidate.transitions[c]

        new_node.count = 1
        current.transitions[c] = new_node
        nodes.append(new_node)
        last = new_node

    # Push counts down the suffix links.  A suffix link always targets a
    # node created earlier, so reverse creation order visits every node
    # before its link target; no sort by length is required.
    for node in reversed(nodes):
        link = node.suffix_link
        if link is not root_0:
            link.count += node.count

    hash_map = defaultdict(int)
    for node in nodes:
        hash_map[(node.hash1, node.hash2)] += node.count
    return hash_map

def main():
    """Read two lines from stdin and print the combined palindrome tally.

    Each line is stripped and passed to ``build_eertree``.  For every hash
    key present in both resulting maps, the two counts are multiplied; the
    sum of those products is printed.
    """
    read_line = sys.stdin.readline
    first = read_line().strip()
    second = read_line().strip()

    counts_first = build_eertree(first)
    counts_second = build_eertree(second)

    total = 0
    for key, occurrences in counts_first.items():
        # .get() does not insert missing keys, even on a defaultdict.
        other = counts_second.get(key)
        if other is not None:
            total += occurrences * other
    print(total)

# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
0