結果

問題 No.263 Common Palindromes Extra
ユーザー qwewe
提出日時 2025-04-24 12:26:05
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 3,672 bytes
コンパイル時間 1,124 ms
コンパイル使用メモリ 81,940 KB
実行使用メモリ 527,968 KB
最終ジャッジ日時 2025-04-24 12:27:08
合計ジャッジ時間 11,530 ms
ジャッジサーバーID (参考情報): judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 2 MLE * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

def count_common_palindromes(s, t):
    """Count pairs of equal palindromic substrings drawn from *s* and *t*.

    The result is the number of tuples (i, j, k, l) such that
    s[i..j] == t[k..l] and that substring is a palindrome. Equivalently, it
    is the sum over every distinct palindrome P of occ_s(P) * occ_t(P).

    A single palindromic tree (eertree) is grown over *s* and then over *t*.
    Its nodes are exactly the distinct palindromes of both strings, so equal
    palindromes from the two strings share one node. Matching is therefore
    exact and needs no hashing. Nodes live in flat parallel lists and all
    edges in one dict, so memory stays linear in len(s) + len(t) with a small
    constant (no per-node objects or per-node dicts).

    Args:
        s: First string; any characters are allowed.
        t: Second string; any characters are allowed.

    Returns:
        The number of (substring of s, substring of t) pairs that are equal
        palindromes. Returns 0 when either string is empty.
    """
    # Edge keys pack (node, code point) into one int. 0x110000 is one past
    # the largest Unicode code point, so distinct pairs never collide.
    shift = 0x110000
    # Node 0: imaginary root of length -1 whose suffix link is itself.
    # Node 1: empty-string root of length 0 whose suffix link is node 0.
    length = [-1, 0]
    link = [0, 0]
    # counts[0][v] / counts[1][v]: how many positions of s / t end with
    # palindrome v as their longest palindromic suffix (before propagation).
    counts = ([0, 0], [0, 0])
    edges = {}  # node * shift + ord(char) -> child node

    for word, cnt in zip((s, t), counts):
        # The tree is shared, but "longest palindromic suffix so far" is
        # per-string state, so restart from the empty root for each word.
        last = 1
        for i, ch in enumerate(word):
            code = ord(ch)
            cur = last
            # Follow suffix links until ch + palindrome(cur) + ch fits at i.
            # The imaginary root (length -1) always matches, ending the walk.
            while True:
                j = i - length[cur] - 1
                if j >= 0 and word[j] == ch:
                    break
                cur = link[cur]
            key = cur * shift + code
            node = edges.get(key)
            if node is None:
                node = len(length)
                new_len = length[cur] + 2
                if new_len == 1:
                    suffix = 1
                else:
                    suf = link[cur]
                    while True:
                        j = i - length[suf] - 1
                        if j >= 0 and word[j] == ch:
                            break
                        suf = link[suf]
                    # The longest proper palindromic suffix of the new
                    # palindrome already occurred earlier in this word, so
                    # its node (and this edge) is guaranteed to exist.
                    suffix = edges[suf * shift + code]
                length.append(new_len)
                link.append(suffix)
                for arr in counts:
                    arr.append(0)
                edges[key] = node
            cnt[node] += 1
            last = node

    cnt_s, cnt_t = counts
    # Every suffix link targets an older (smaller-index) node. Sweeping from
    # the newest node to the oldest therefore finalizes each node's counts
    # before they are added to its link target. No sort is needed.
    for node in range(len(length) - 1, 1, -1):
        parent = link[node]
        cnt_s[parent] += cnt_s[node]
        cnt_t[parent] += cnt_t[node]

    return sum(cnt_s[v] * cnt_t[v] for v in range(2, len(length)))


def main():
    """Read S and T (one per line) from stdin and print the common count."""
    s = sys.stdin.readline().strip()
    t = sys.stdin.readline().strip()
    print(count_common_palindromes(s, t))

if __name__ == '__main__':
    # Script entry point: solve only when executed directly, not on import.
    main()
0