結果

問題 No.263 Common Palindromes Extra
ユーザー qwewe
提出日時 2025-05-14 12:52:24
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 4,281 bytes
コンパイル時間 237 ms
コンパイル使用メモリ 82,408 KB
実行使用メモリ 777,116 KB
最終ジャッジ日時 2025-05-14 12:54:25
合計ジャッジ時間 10,040 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 1 MLE * 3
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

class Node:
    """One distinct palindrome stored in a PalindromicTree.

    A tree allocates at most one node per inserted character plus two roots,
    so instances are numerous; ``__slots__`` drops the per-instance
    ``__dict__`` to keep their memory footprint small.
    """

    __slots__ = ("len", "link", "next", "cnt", "last_pos")

    def __init__(self):
        self.len = 0         # palindrome length (-1 and 0 are used for the roots)
        self.link = None     # node of the longest proper palindromic suffix
        self.next = dict()   # char c -> node for the palindrome c + this + c
        self.cnt = 0         # occurrence counter maintained by the tree
        self.last_pos = -1   # end index of the most recent occurrence (-1: none)

class PalindromicTree:
    """Palindromic tree (eertree) of all distinct palindromic substrings.

    ``root1`` has length -1 and is the parent of odd-length palindromes;
    ``root2`` has length 0 and is the parent of even-length ones.  For a node
    holding palindrome P, ``next[c]`` is the node for cPc and ``link`` is the
    node of P's longest proper palindromic suffix.  After :meth:`build_tree`,
    every non-root node's ``cnt`` is its number of occurrences in the string
    and ``last_pos`` is the end index of its last occurrence.
    """

    def __init__(self):
        self.nodes = []  # every node, in creation order (the two roots first)
        self.root1 = self.new_node(-1, None)
        self.root1.link = self.root1  # the length -1 root links to itself
        self.root2 = self.new_node(0, self.root1)
        self.last = self.root2  # node of the longest palindromic suffix so far

    def new_node(self, length, link):
        """Create a node with *length* and suffix *link*, register it, and return it."""
        node = Node()
        node.len = length
        node.link = link
        self.nodes.append(node)
        return node

    @staticmethod
    def _find_extendable(node, s, index, c):
        """Follow suffix links from *node* until ``c + palindrome + c`` ends at *index*.

        Terminates at the latest at the length -1 root, for which the checked
        position is *index* itself.
        """
        while True:
            pos = index - node.len - 1
            if pos >= 0 and s[pos] == c:
                return node
            node = node.link

    def add_char(self, s, index):
        """Extend the tree with ``s[index]``; ``s[:index]`` must already be inserted."""
        c = s[index]
        current = self._find_extendable(self.last, s, index, c)

        existing = current.next.get(c)
        if existing is not None:
            # Palindrome already known: record one more occurrence ending here.
            existing.cnt += 1
            existing.last_pos = index
            self.last = existing
            return

        if current.len == -1:
            # Single character: its longest proper palindromic suffix is empty.
            link = self.root2
        else:
            # Compute the link before adding the edge below, so the search
            # cannot return the node being created.
            link_parent = self._find_extendable(current.link, s, index, c)
            link = link_parent.next.get(c, self.root2)

        node = self.new_node(current.len + 2, link)
        node.cnt = 1
        node.last_pos = index
        current.next[c] = node
        self.last = node

    def build_tree(self, s):
        """Insert all of *s*, then turn per-end-position counts into occurrence counts."""
        for i in range(len(s)):
            self.add_char(s, i)
        # A suffix link is always looked up among nodes that existed before the
        # new node was created, so reverse creation order visits every node
        # before its link target -- a valid order without sorting by length.
        for node in reversed(self.nodes):
            if node.link is not None:
                node.link.cnt += node.cnt

def compute_prefix_hash(s, base, mod):
    """Return the polynomial prefix hashes of *s*.

    Element ``k`` of the result is the hash of ``s[:k]``, built as
    ``h(k+1) = (h(k) * base + ord(s[k])) % mod`` starting from ``h(0) = 0``;
    the result therefore has ``len(s) + 1`` entries.
    """
    running = 0
    hashes = [running]
    for ch in s:
        running = (running * base + ord(ch)) % mod
        hashes.append(running)
    return hashes

def compute_powers(n, base, mod):
    """Return a list of length ``n + 1`` whose entry ``i`` is ``base**i % mod``.

    Entry 0 is always 1.  The local list is named ``powers`` so that it no
    longer shadows the built-in ``pow``.
    """
    powers = [1] * (n + 1)
    for i in range(1, n + 1):
        powers[i] = powers[i - 1] * base % mod
    return powers

def _count_common_palindromes(tree_s, tree_t):
    """Return the sum of ``cnt_s * cnt_t`` over palindromes present in both trees.

    In a palindromic tree a palindrome is identified by its root (odd or
    even) plus the sequence of characters on the ``next`` edges leading to
    it.  Walking both trees in lock-step therefore pairs exactly the nodes
    that hold the same palindrome, with no hashing and no collision risk.
    """
    total = 0
    stack = [(tree_s.root1, tree_t.root1), (tree_s.root2, tree_t.root2)]
    while stack:
        node_s, node_t = stack.pop()
        children_t = node_t.next
        for ch, child_s in node_s.next.items():
            child_t = children_t.get(ch)
            if child_t is not None:
                total += child_s.cnt * child_t.cnt
                stack.append((child_s, child_t))
    return total


def main():
    """Read S and T (one per line) and print sum over palindromes P of occ_S(P) * occ_T(P).

    Matching palindromes structurally across the two trees avoids building
    double rolling-hash tables for both strings (eight length-n lists and two
    dicts), which lowers peak memory and makes the count exact.
    """
    s = sys.stdin.readline().strip()
    t = sys.stdin.readline().strip()

    pt_s = PalindromicTree()
    pt_s.build_tree(s)
    pt_t = PalindromicTree()
    pt_t.build_tree(t)

    print(_count_common_palindromes(pt_s, pt_t))

# Script entry point; the stray no-op expression `0` that followed is removed.
if __name__ == "__main__":
    main()