結果

問題 No.263 Common Palindromes Extra
ユーザー lam6er
提出日時 2025-04-09 21:05:18
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,139 bytes
コンパイル時間 353 ms
コンパイル使用メモリ 82,644 KB
実行使用メモリ 132,348 KB
最終ジャッジ日時 2025-04-09 21:06:55
合計ジャッジ時間 3,596 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 2 WA * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def _feed_string(text, which, length, link, trans, counts):
    """Insert every prefix of *text* into a shared palindromic tree (eertree).

    The tree is stored as parallel lists indexed by node id:

    * ``length[v]`` -- length of the palindrome represented by node ``v``;
    * ``link[v]``   -- id of its longest proper palindromic suffix;
    * ``trans[v]``  -- dict mapping a character ``c`` to the node for
      ``c + pal(v) + c``;
    * ``counts[k][v]`` -- for string number ``k``, how many positions had
      ``v`` as their longest palindromic suffix (before propagation).

    Node 0 is the imaginary root (length -1) and node 1 the empty root
    (length 0).  Suffix links and transitions depend only on the palindrome
    a node represents, not on the string it was found in, so several strings
    can share one tree: ``last`` is simply reset to the empty root for each
    new string.

    Args:
        text: The string to insert.
        which: Index into *counts* of the counter row for this string.
        length, link, trans, counts: The shared tree (mutated in place).
    """
    last = 1
    for i, c in enumerate(text):
        # Walk suffix links from the longest palindromic suffix of the
        # previous prefix -- starting with ``last`` itself -- until we find
        # a palindrome X such that c + X + c ends at position i.
        cur = last
        while True:
            j = i - length[cur] - 1
            # ``j >= 0`` matters: a negative index would wrap around in
            # Python.  For the imaginary root, j == i and the test is
            # trivially true, so the walk always terminates.
            if j >= 0 and text[j] == c:
                break
            cur = link[cur]

        node = trans[cur].get(c)
        if node is None:
            node = len(length)
            length.append(length[cur] + 2)
            trans.append({})
            for row in counts:
                row.append(0)
            if length[node] == 1:
                # Single characters link to the empty palindrome.
                link.append(1)
            else:
                # Longest proper palindromic suffix of c + pal(cur) + c:
                # search from link[cur] (inclusive) for the next palindrome
                # that can also be wrapped in c.  That wrapped palindrome is
                # a proper prefix of the new one, hence already in the tree.
                w = link[cur]
                while True:
                    j = i - length[w] - 1
                    if j >= 0 and text[j] == c:
                        break
                    w = link[w]
                link.append(trans[w][c])
            trans[cur][c] = node

        counts[which][node] += 1
        last = node


def _count_common_palindromes(s, t):
    """Return sum over palindromes P of occ_s(P) * occ_t(P).

    ``occ_x(P)`` is the number of positions in ``x`` where ``P`` occurs as
    a substring, so the result is the number of (occurrence in *s*,
    occurrence in *t*) pairs of equal palindromic substrings.
    """
    length = [-1, 0]
    link = [0, 0]
    trans = [{}, {}]
    counts = [[0, 0], [0, 0]]

    _feed_string(s, 0, length, link, trans, counts)
    _feed_string(t, 1, length, link, trans, counts)

    cnt_s, cnt_t = counts
    # Every node is created after the node its suffix link points to, so a
    # reverse sweep over ids pushes each palindrome's count down to all of
    # its palindromic suffixes (i.e. turns "longest suffix" hits into total
    # occurrence counts).  The two roots (ids 0 and 1) are skipped.
    for v in range(len(length) - 1, 1, -1):
        p = link[v]
        cnt_s[p] += cnt_s[v]
        cnt_t[p] += cnt_t[v]

    return sum(cnt_s[v] * cnt_t[v] for v in range(2, len(length)))


def main():
    """Read S and T (one per line) from stdin and print the pair count.

    Prints the number of pairs (occurrence in S, occurrence in T) of equal
    palindromic substrings.  Palindromes are compared exactly by sharing a
    single eertree between both strings, so no hashing is needed and there
    is no hard-coded length limit.
    """
    s = sys.stdin.readline().strip()
    t = sys.stdin.readline().strip()
    print(_count_common_palindromes(s, t))

if __name__ == '__main__':
    # Run the solver only when executed as a script; main() reads S and T
    # from stdin and prints the answer.
    main()
0  # NOTE(review): stray trailing token from the page scrape; it is a no-op expression statement -- likely safe to delete.