結果

問題 No.263 Common Palindromes Extra
ユーザー qwewe
提出日時 2025-04-24 12:26:05
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 3,677 bytes
コンパイル時間 215 ms
コンパイル使用メモリ 82,160 KB
実行使用メモリ 828,728 KB
最終ジャッジ日時 2025-04-24 12:26:57
合計ジャッジ時間 8,324 ms
ジャッジサーバーID (参考情報) judge3 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 2 MLE * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

# Double hashing: every palindrome is identified by its hash under two
# separate (modulus, base) pairs, so a false match needs both to collide.
MOD1, BASE1 = 10**9 + 7, 911382629
# BASE2 is larger than MOD2; every use of it is reduced modulo MOD2.
MOD2, BASE2 = 10**9 + 3, 3571428571

def precompute_powers(max_len, base, mod):
    """Return the powers base**0 .. base**(max_len + 1) reduced modulo ``mod``.

    The result has ``max_len + 2`` entries (an empty list when that count is
    not positive). Entry 0 is the literal 1; each later entry is the previous
    one times ``base``, taken modulo ``mod``.
    """
    count = max_len + 2
    if count <= 0:
        return []
    powers = [1]
    current = 1
    while len(powers) < count:
        current = current * base % mod
        powers.append(current)
    return powers

# Power tables: pow_baseK[e] == BASEK**e % MODK for exponents e in 0..500003.
pow_base1, pow_base2 = (
    precompute_powers(500002, base, mod)
    for base, mod in ((BASE1, MOD1), (BASE2, MOD2))
)

class PalindromicTree:
    """Palindromic tree (eertree) over one string, with per-palindrome hashes.

    Node 0 is the imaginary root of length -1 and node 1 is the empty
    palindrome (length 0). Every node from index 2 on is a distinct non-empty
    palindromic substring. After ``build(s)``:

    * ``nodes[i].cnt`` is the number of occurrences of node i's palindrome
      in ``s``;
    * ``nodes[i].hash1`` / ``nodes[i].hash2`` are its polynomial hashes
      modulo ``MOD1`` / ``MOD2``, with the first character in the highest
      power and each character contributing ``ord(ch) + 1``.

    Relies on the module-level ``MOD1``, ``MOD2``, ``BASE1``, ``BASE2``,
    ``pow_base1`` and ``pow_base2``.
    """

    class Node:
        """One palindrome in the tree."""

        # The tree holds up to len(s) + 2 nodes, so per-node overhead
        # dominates memory; __slots__ drops the per-instance __dict__.
        __slots__ = ("next", "link", "len", "hash1", "hash2", "cnt")

        def __init__(self):
            self.next = {}  # character -> index of the node c + self + c
            self.link = 0  # index of the longest proper palindromic suffix
            self.len = 0  # palindrome length (-1 for the imaginary root)
            self.hash1 = 0  # polynomial hash modulo MOD1
            self.hash2 = 0  # polynomial hash modulo MOD2
            self.cnt = 0  # end-position count; build() turns it into occurrences

    def __init__(self):
        self.nodes = [self.Node(), self.Node()]
        self.nodes[0].len = -1
        self.nodes[0].link = 0
        self.nodes[1].len = 0
        self.nodes[1].link = 0
        self.last = 1  # node of the longest palindromic suffix of the prefix read so far
        self.size = 2

    def get_link(self, u, s, idx):
        """Return the first node on u's suffix-link chain extendable by s[idx].

        A node X qualifies when ``s[idx - len(X) - 1] == s[idx]``, i.e.
        ``s[idx] + X + s[idx]`` ends at ``idx``. The length -1 root always
        qualifies, so the walk terminates.
        """
        nodes = self.nodes
        c = s[idx]
        while True:
            j = idx - nodes[u].len - 1
            if j >= 0 and s[j] == c:
                return u
            u = nodes[u].link

    def add_char(self, s, idx):
        """Extend the tree with s[idx]; s[:idx] must already have been added."""
        nodes = self.nodes
        c = s[idx]
        u = self.get_link(self.last, s, idx)
        parent = nodes[u]

        existing = parent.next.get(c)
        if existing is not None:
            # The longest palindromic suffix ending at idx is already known.
            self.last = existing
            nodes[existing].cnt += 1
            return

        # ord(c) + 1 is >= 1 for every character. A zero digit would make
        # e.g. "x" and "xx" hash identically.
        code = ord(c) + 1
        new_node = self.Node()
        new_node.len = parent.len + 2
        new_node.cnt = 1

        if parent.len == -1:
            # Single-character palindrome; its longest proper palindromic
            # suffix is the empty palindrome.
            new_node.hash1 = code % MOD1
            new_node.hash2 = code % MOD2
            new_node.link = 1
        else:
            # hash(c + P + c) = c * B^(len(P) + 1) + hash(P) * B + c
            p_len = parent.len
            new_node.hash1 = (
                code * pow_base1[p_len + 1] + parent.hash1 * BASE1 + code
            ) % MOD1
            new_node.hash2 = (
                code * pow_base2[p_len + 1] + parent.hash2 * BASE2 + code
            ) % MOD2
            # Must be resolved before parent.next[c] is set below, so the
            # new node can never be chosen as its own suffix link.
            w = self.get_link(parent.link, s, idx)
            new_node.link = nodes[w].next.get(c, 1)

        nodes.append(new_node)
        parent.next[c] = self.size
        self.last = self.size
        self.size += 1

    def build(self, s):
        """Insert every character of s, then turn end counts into occurrence counts."""
        for i in range(len(s)):
            self.add_char(s, i)
        nodes = self.nodes
        # A suffix link always targets a node that existed when the linking
        # node was created, so link < index. Sweeping indices downward
        # finishes every node before its count is pushed to its link. This
        # is O(n) and avoids sorting the nodes by length.
        for u in range(self.size - 1, 1, -1):
            node = nodes[u]
            nodes[node.link].cnt += node.cnt

def main():
    """Read S and T from stdin and print the number of common palindrome pairs.

    The printed value is the sum, over every distinct non-empty palindrome P,
    of (occurrences of P in S) * (occurrences of P in T). Palindromes from
    the two strings are matched by their (hash1, hash2) pair, so the result
    is exact unless both hashes collide.
    """
    S = sys.stdin.readline().strip()
    T = sys.stdin.readline().strip()

    # Build the trees one at a time. The first tree becomes unreachable once
    # its counts are extracted, so both trees are never alive together.
    count_S = _palindrome_counts(S)
    count_T = _palindrome_counts(T)

    # The sum is symmetric, so walk the smaller map and probe the larger.
    if len(count_S) > len(count_T):
        count_S, count_T = count_T, count_S
    result = 0
    for key, occurrences in count_S.items():
        other = count_T.get(key)
        if other is not None:
            result += occurrences * other
    print(result)


def _palindrome_counts(s):
    """Map each distinct non-empty palindrome of s (by packed hash) to its occurrence count."""
    pt = PalindromicTree()
    pt.build(s)
    nodes = pt.nodes
    counts = {}
    # Nodes 0 and 1 are the length -1 and length 0 roots. Every later node has
    # length parent.len + 2 >= 1, so it is a non-empty palindrome.
    for i in range(2, pt.size):
        node = nodes[i]
        # hash2 is reduced modulo MOD2, so hash1 * MOD2 + hash2 encodes the
        # pair losslessly as a single int, with no tuple allocated per key.
        key = node.hash1 * MOD2 + node.hash2
        counts[key] = counts.get(key, 0) + node.cnt
    return counts

if __name__ == "__main__":
    # Solve only when run as a script, not when this module is imported.
    main()
0