結果

問題 No.263 Common Palindromes Extra
ユーザー lam6er
提出日時 2025-03-26 15:54:21
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 4,014 bytes
コンパイル時間 233 ms
コンパイル使用メモリ 82,040 KB
実行使用メモリ 550,772 KB
最終ジャッジ日時 2025-03-26 15:55:26
合計ジャッジ時間 9,118 ms
ジャッジサーバーID (参考情報) judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 2 MLE * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

class Node:
    """A single vertex of the palindromic tree built by ``build_pal_tree``.

    Fields (all start at 0):
      len    -- palindrome length; the two roots use -1 and 0
      link   -- index (into the node list) of the suffix-link target
      next   -- 26 child indices, one per letter 'A'..'Z'; 0 means "no child"
      cnt    -- occurrence counter
      h1, h2 -- the two polynomial hashes of the palindrome
    """

    __slots__ = ['len', 'link', 'next', 'cnt', 'h1', 'h2']

    def __init__(self):
        self.len = self.link = 0
        # Each instance gets its own fresh child table (never shared).
        self.next = [0 for _ in range(26)]
        self.cnt = 0
        self.h1 = self.h2 = 0

def build_pal_tree(s, base1, base2, mod1, mod2, pow_base1, pow_base2):
    """Build a palindromic tree (eertree) of *s* and count its palindromes.

    Each distinct non-empty palindromic substring P of *s* is keyed by the
    pair ``(h1, h2)``. Here ``hK`` is the polynomial hash
    ``(sum(ord(P[k]) * baseK**(len(P)-1-k))) % modK``. A single character is
    stored as the raw ``ord(ch)``, which equals the reduced value whenever
    ``modK > ord(ch)``.

    Args:
        s: Input string. Any characters are accepted.
        base1, base2: Hash bases.
        mod1, mod2: Hash moduli.
        pow_base1, pow_base2: Tables with ``pow_baseK[e] == baseK**e % modK``.
            They must cover exponents up to ``len(s) - 1``.

    Returns:
        A ``defaultdict(int)`` mapping ``(h1, h2)`` to the number of
        occurrences of that palindrome in *s*.
    """
    # Flat per-node arrays replace one object per node. The old object held
    # a 26-slot child list each; with up to len(s) + 2 nodes, that was the
    # dominant memory cost.
    # Node 0 is the imaginary root (length -1). Node 1 is the empty
    # palindrome (length 0).
    length = [-1, 0]
    link = [0, 0]
    cnt = [0, 0]
    h1 = [0, 0]
    h2 = [0, 0]
    # All child edges live in one dict keyed by (node << 21) | ord(ch).
    # 21 bits hold any Unicode code point, so each (node, char) key is unique.
    edges = {}
    last = 1

    for i, c in enumerate(s):
        o = ord(c)
        # Find the longest palindromic suffix of s[:i] that can be wrapped by c
        # on both sides. The length -1 root always qualifies (j == i).
        p = last
        while True:
            j = i - length[p] - 1
            if j >= 0 and s[j] == c:
                break
            p = link[p]

        key = (p << 21) | o
        child = edges.get(key)
        if child is not None:
            # Palindrome already known: just record another occurrence.
            cnt[child] += 1
            last = child
            continue

        plen = length[p]
        if plen == -1:
            # Single character: its hash is its code point, and its suffix
            # link is the empty palindrome.
            nh1 = nh2 = o
            new_link = 1
        else:
            # hash(c + P + c) = c * base**(len(P)+1) + hash(P) * base + c
            nh1 = (o * pow_base1[plen + 1] + h1[p] * base1 + o) % mod1
            nh2 = (o * pow_base2[plen + 1] + h2[p] * base2 + o) % mod2
            # The suffix link is the longest proper palindromic suffix: c + Q + c
            # for the longest suffix-palindrome Q of P's chain that c wraps.
            q = link[p]
            while True:
                j = i - length[q] - 1
                if j >= 0 and s[j] == c:
                    break
                q = link[q]
            # The eertree invariant guarantees this edge exists. 1 (the empty
            # node) is kept as the same defensive fallback as before.
            new_link = edges.get((q << 21) | o, 1)

        node = len(length)
        length.append(plen + 2)
        link.append(new_link)
        cnt.append(1)
        h1.append(nh1)
        h2.append(nh2)
        edges[key] = node
        last = node

    # A suffix link always targets an older (lower-index) node. One sweep from
    # the newest node down therefore finalizes each count before pushing it
    # into its suffix palindrome; no sort by length is needed.
    for v in range(len(length) - 1, 1, -1):
        cnt[link[v]] += cnt[v]

    # Collect all nodes except the two roots.
    hash_map = defaultdict(int)
    for v in range(2, len(length)):
        hash_map[(h1[v], h2[v])] += cnt[v]
    return hash_map

def main():
    """Read S and T (one per line) from stdin and print the answer.

    The answer is the sum, over every palindrome P, of
    (occurrences of P in S) * (occurrences of P in T). Palindromes from the
    two strings are matched by their ``(h1, h2)`` hash pair as produced by
    ``build_pal_tree``.
    """
    readline = sys.stdin.readline
    s = readline().strip()
    t = readline().strip()

    # Two independent (base, modulus) pairs for double hashing.
    base1, mod1 = 911382629, 10**18 + 3
    base2, mod2 = 3571428571, 10**18 + 7

    # pow_baseK[e] == baseK**e % modK for e = 0 .. limit (inclusive).
    limit = max(len(s), len(t)) + 2
    pow_base1 = [1]
    pow_base2 = [1]
    for _ in range(limit):
        pow_base1.append(pow_base1[-1] * base1 % mod1)
        pow_base2.append(pow_base2[-1] * base2 % mod2)

    hash_params = (base1, base2, mod1, mod2, pow_base1, pow_base2)
    counts_s = build_pal_tree(s, *hash_params)
    counts_t = build_pal_tree(t, *hash_params)

    # The membership test avoids inserting missing keys into the defaultdict.
    total = sum(
        counts_s[key] * occ_t
        for key, occ_t in counts_t.items()
        if key in counts_s
    )
    print(total)


# Run the solver only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0