結果

問題 No.263 Common Palindromes Extra
ユーザー lam6er
提出日時 2025-03-20 19:02:44
言語 PyPy3 (7.3.15)
結果 MLE
実行時間 -
コード長 3,260 bytes
コンパイル時間 169 ms
コンパイル使用メモリ 82,844 KB
実行使用メモリ 449,944 KB
最終ジャッジ日時 2025-03-20 19:03:14
合計ジャッジ時間 8,016 ms
ジャッジサーバーID (参考情報) judge1 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other AC * 8 MLE * 4
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

# Moduli and multiplier bases for a double polynomial string hash:
# each (MODk, BASEk) pair defines one independent hash function.
MOD1 = 1_000_000_000_000_000_003
MOD2 = 1_000_000_000_000_000_007
BASE1 = 911_382_629
BASE2 = 3_571_428_571

def _count_common_palindromes(s, t):
    """Count pairs of equal non-empty palindromic substrings of *s* and *t*.

    Returns the number of tuples (i, j, k, l) such that s[i:j] == t[k:l]
    and that substring is a non-empty palindrome, i.e. the sum over every
    distinct palindrome P of occ_s(P) * occ_t(P).

    A single palindromic tree (eertree) is built over ``s + SEP + t``,
    where SEP never occurs in either string. Occurrences are tallied
    separately for end positions inside *s* and inside *t*. A palindrome
    that contains SEP can only end after SEP, so its *s*-count stays 0 and
    it contributes nothing to the sum.

    Nodes live in flat parallel int lists with one shared transition dict,
    rather than one object plus one dict per node. Counts are exact
    because no hashing is involved.
    """
    # One past the largest Unicode code point: never equal to ord() of a real
    # character, and still < 2**21 so it fits the edge-key packing below.
    sep = 0x110000
    text = [ord(ch) for ch in s]
    text.append(sep)
    text.extend(ord(ch) for ch in t)
    len_s = len(s)

    # Node 0 is the imaginary root (length -1); node 1 is the empty palindrome.
    length = [-1, 0]
    link = [0, 0]    # suffix link: longest proper palindromic suffix node
    cnt_s = [0, 0]   # end positions in s where node is the longest pal. suffix
    cnt_t = [0, 0]   # same, for end positions in t
    edges = {}       # (parent << 21) | char -> child node id

    last = 1
    for i, c in enumerate(text):
        # Walk suffix links until the palindrome can be wrapped by c on both
        # sides; the imaginary root always succeeds because then j == i.
        cur = last
        while True:
            j = i - length[cur] - 1
            if j >= 0 and text[j] == c:
                break
            cur = link[cur]

        key = (cur << 21) | c
        node = edges.get(key)
        if node is None:
            node = len(length)
            if length[cur] == -1:
                # Single character: its longest proper palindromic suffix is "".
                suffix = 1
            else:
                # Next palindromic suffix of cur that c can wrap; its
                # extension by c is already present in the tree.
                suf = link[cur]
                while True:
                    j = i - length[suf] - 1
                    if j >= 0 and text[j] == c:
                        break
                    suf = link[suf]
                suffix = edges[(suf << 21) | c]
            length.append(length[cur] + 2)
            link.append(suffix)
            cnt_s.append(0)
            cnt_t.append(0)
            edges[key] = node
        last = node

        # Position len_s is the separator itself and is not counted.
        if i < len_s:
            cnt_s[node] += 1
        elif i > len_s:
            cnt_t[node] += 1

    # Each node is created after its suffix-link target, so a reverse sweep
    # turns "longest palindromic suffix here" tallies into occurrence counts.
    for v in range(len(length) - 1, 1, -1):
        p = link[v]
        cnt_s[p] += cnt_s[v]
        cnt_t[p] += cnt_t[v]

    # Skip the two roots: only non-empty palindromes are counted.
    return sum(cnt_s[v] * cnt_t[v] for v in range(2, len(length)))


def main():
    """Read S and T (one per line) from stdin and print the number of pairs
    of equal non-empty palindromic substrings taken from S and T."""
    S = sys.stdin.readline().strip()
    T = sys.stdin.readline().strip()
    print(_count_common_palindromes(S, T))
    
if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()
0