結果

問題 No.263 Common Palindromes Extra
ユーザー qwewe
提出日時 2025-04-24 12:26:02
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 3,604 bytes
コンパイル時間 163 ms
コンパイル使用メモリ 81,992 KB
実行使用メモリ 526,040 KB
最終ジャッジ日時 2025-04-24 12:27:04
合計ジャッジ時間 8,740 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 2 MLE * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

class Node:
    """One vertex of a palindromic tree (eertree) over the letters 'A'-'Z'."""

    __slots__ = ['len', 'fail', 'next', 'cnt', 'h1', 'h2']

    def __init__(self):
        # len: length of the palindrome this vertex stands for
        # cnt: occurrence counter
        # h1, h2: the two polynomial hash values of the palindrome
        self.len = self.cnt = self.h1 = self.h2 = 0
        # fail: suffix link, filled in by whoever builds the tree
        self.fail = None
        # next: one child slot per letter, index 0 = 'A' ... 25 = 'Z'
        self.next = [None for _ in range(26)]

def build_eertree(s, pow_base1, pow_base2, mod1, mod2, base1, base2):
    """Count the occurrences of every distinct palindromic substring of ``s``.

    A palindromic tree (eertree) is built over ``s``. Each distinct
    palindrome is identified by a pair of polynomial hashes so that the
    result can be compared with the result for another string.

    Node data is kept in parallel lists and the transitions live in one
    dict keyed by (node, code point). This avoids a fixed 26-slot child
    table per node and lets ``s`` contain any characters, not just 'A'-'Z'.

    Args:
        s: The text to analyse (a string or a sequence of 1-char strings).
        pow_base1: Table with ``pow_base1[k] == base1**k % mod1`` for at
            least ``k < len(s)``.
        pow_base2: Same table for ``base2`` / ``mod2``.
        mod1: Modulus of the first hash.
        mod2: Modulus of the second hash.
        base1: Base of the first hash.
        base2: Base of the second hash.

    Returns:
        A ``defaultdict(int)`` mapping ``(h1, h2)`` of each distinct
        palindromic substring to the number of times it occurs in ``s``.
    """
    # The largest Unicode code point is 0x10FFFF, so 21 bits hold any ord(c)
    # and (node << 21) | ord(c) is a collision-free edge key.
    char_bits = 21

    # Node 0 is the odd root (length -1); node 1 is the even root (length 0).
    # Both roots link to the odd root, where the extension test always holds.
    lens = [-1, 0]
    fails = [0, 0]
    cnts = [0, 0]
    hash1 = [0, 0]
    hash2 = [0, 0]
    edges = {}
    last = 0

    for i, c in enumerate(s):
        c_val = ord(c)

        # Walk suffix links until the palindrome can be wrapped in c ... c.
        cur = last
        while True:
            j = i - lens[cur] - 1
            if j >= 0 and s[j] == c:
                break
            cur = fails[cur]

        edge_key = (cur << char_bits) | c_val
        found = edges.get(edge_key)
        if found is not None:
            cnts[found] += 1
            last = found
            continue

        cur_len = lens[cur]
        if cur_len == -1:
            # A single character: its hash is the code point itself and its
            # longest proper palindromic suffix is the empty string.
            new_h1 = new_h2 = c_val
            link = 1
        else:
            # hash(c + X + c) = c * B^(|X|+1) + hash(X) * B + c
            new_h1 = (c_val * pow_base1[cur_len + 1]
                      + hash1[cur] * base1 + c_val) % mod1
            new_h2 = (c_val * pow_base2[cur_len + 1]
                      + hash2[cur] * base2 + c_val) % mod2
            # The suffix link is found the same way, starting one link lower.
            # Its target is a shorter palindrome that already occurred, so
            # the edge is guaranteed to exist. It is looked up before the new
            # edge is inserted.
            t = fails[cur]
            while True:
                j = i - lens[t] - 1
                if j >= 0 and s[j] == c:
                    break
                t = fails[t]
            link = edges[(t << char_bits) | c_val]

        new = len(lens)
        lens.append(cur_len + 2)
        fails.append(link)
        cnts.append(1)
        hash1.append(new_h1)
        hash2.append(new_h2)
        edges[edge_key] = new
        last = new

    # A suffix link always points to an earlier-created node, so walking the
    # nodes newest-first pushes complete counts down the links without sorting.
    for node in range(len(lens) - 1, 1, -1):
        cnts[fails[node]] += cnts[node]

    count = defaultdict(int)
    for node in range(2, len(lens)):
        count[(hash1[node], hash2[node])] += cnts[node]
    return count

def main():
    """Read S and T from stdin and print the number of common palindrome pairs.

    For every palindrome that occurs in both strings, the product of its
    occurrence counts in S and in T is added to the answer. Palindromes are
    matched by the (h1, h2) hash pairs returned by ``build_eertree``.
    Prints 0 when either line is empty.
    """
    read_line = sys.stdin.readline
    s = read_line().strip()
    t = read_line().strip()
    if not (s and t):
        print(0)
        return

    base1, mod1 = 911382629, 10**18 + 3
    base2, mod2 = 3571428571, 10**18 + 7

    # Power tables for both hashes, sized past the longer input.
    table_size = max(len(s), len(t)) + 4
    pow_base1 = [1] * table_size
    pow_base2 = [1] * table_size
    for k in range(1, table_size):
        pow_base1[k] = (pow_base1[k - 1] * base1) % mod1
        pow_base2[k] = (pow_base2[k - 1] * base2) % mod2

    count_s = build_eertree(s, pow_base1, pow_base2, mod1, mod2, base1, base2)
    count_t = build_eertree(t, pow_base1, pow_base2, mod1, mod2, base1, base2)

    # Only hash pairs present in both maps contribute.
    result = sum(
        occurrences * count_t[key]
        for key, occurrences in count_s.items()
        if key in count_t
    )
    print(result)

# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
0