結果

問題 No.263 Common Palindromes Extra
ユーザー lam6er
提出日時 2025-04-09 21:05:23
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 3,647 bytes
コンパイル時間 240 ms
コンパイル使用メモリ 82,576 KB
実行使用メモリ 645,640 KB
最終ジャッジ日時 2025-04-09 21:07:30
合計ジャッジ時間 9,966 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 8 TLE * 2 MLE * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

class PalindromicNode:
    """A single node of a palindromic tree (eertree).

    Every node other than the two roots stands for one distinct palindromic
    substring. The roots are pseudo-nodes of length -1 and 0.
    """

    __slots__ = ['len', 'suffix_link', 'next', 'count', 'h1', 'h2', 'pow1', 'pow2']

    def __init__(self, length, suffix_link):
        """Create a node for a palindrome of ``length`` characters.

        Args:
            length: Length of the palindrome (-1 and 0 for the two roots).
            suffix_link: Node of the longest proper palindromic suffix, or
                ``None`` when the caller fills it in afterwards.
        """
        self.len = length
        self.suffix_link = suffix_link
        # Transitions keyed by character: next[ch] is the node for ch + P + ch.
        # A dict stores only the edges that exist (fewer than one per node on
        # average) and cannot alias characters the way a fixed 26-slot list
        # does once an index turns negative.
        self.next = {}
        self.count = 0  # occurrence counter
        self.h1 = 0     # polynomial hash of the palindrome modulo mod1
        self.h2 = 0     # polynomial hash of the palindrome modulo mod2
        self.pow1 = 0   # base ** len modulo mod1
        self.pow2 = 0   # base ** len modulo mod2


def _symbol_value(ch):
    """Return the non-zero integer used for ``ch`` inside the hashes.

    'A'-'Z' map to 1-26. Every other character maps to ``ord(ch) + 27``,
    which is always above 26, so distinct characters never share a value and
    no character hashes as zero.
    """
    offset = ord(ch) - ord('A')
    if 0 <= offset < 26:
        return offset + 1
    return ord(ch) + 27


def _longest_extendable_suffix(node, s, idx):
    """Follow suffix links from ``node`` until ``s[idx]`` can wrap the palindrome.

    Returns the first node P on the chain for which the character just before
    the occurrence of P ending at ``idx - 1`` equals ``s[idx]``. The walk
    always stops: the length -1 root compares ``s[idx]`` with itself.
    """
    ch = s[idx]
    while True:
        mirror = idx - node.len - 1
        if mirror >= 0 and s[mirror] == ch:
            return node
        node = node.suffix_link


def build_eertree(s, base, mod1, mod2):
    """Build the palindromic tree of ``s`` with hashes and occurrence counts.

    Args:
        s: Input string. Any characters are accepted.
        base: Base of the polynomial hash.
        mod1: Modulus of the first hash (``h1`` / ``pow1``).
        mod2: Modulus of the second hash (``h2`` / ``pow2``).

    Returns:
        The list of nodes in creation order: the length -1 root, the
        length 0 root, then one node per distinct palindromic substring.
        For those nodes ``count`` is the number of occurrences of the
        palindrome in ``s`` and ``(h1, h2)`` is its hash pair.
    """
    root_neg = PalindromicNode(-1, None)
    root_neg.suffix_link = root_neg
    root_0 = PalindromicNode(0, root_neg)

    # pow fields hold base ** len; for the length -1 root this is
    # base ** (mod - 2), the modular inverse of base when the modulus is prime.
    root_neg.pow1 = pow(base, mod1 - 2, mod1)
    root_neg.pow2 = pow(base, mod2 - 2, mod2)
    root_0.pow1 = 1
    root_0.pow2 = 1

    tree = [root_neg, root_0]
    last = root_0

    for idx, ch in enumerate(s):
        current = _longest_extendable_suffix(last, s, idx)

        existing = current.next.get(ch)
        if existing is not None:
            # The palindrome ending here is already known: one more occurrence.
            existing.count += 1
            last = existing
            continue

        new_node = PalindromicNode(current.len + 2, None)
        tree.append(new_node)
        c_val = _symbol_value(ch)

        if current.len == -1:
            # Single character palindrome.
            new_node.h1 = c_val % mod1
            new_node.pow1 = base % mod1
            new_node.h2 = c_val % mod2
            new_node.pow2 = base % mod2
            new_node.suffix_link = root_0
        else:
            # hash(ch + P + ch) = ch * base**(len(P) + 1) + hash(P) * base + ch
            new_node.h1 = (c_val * current.pow1 * base
                           + current.h1 * base + c_val) % mod1
            new_node.pow1 = (current.pow1 * (base * base)) % mod1
            new_node.h2 = (c_val * current.pow2 * base
                           + current.h2 * base + c_val) % mod2
            new_node.pow2 = (current.pow2 * (base * base)) % mod2

            link_parent = _longest_extendable_suffix(current.suffix_link, s, idx)
            new_node.suffix_link = link_parent.next.get(ch, root_0)

        current.next[ch] = new_node
        new_node.count = 1
        last = new_node

    # A node's suffix link always points at a node created earlier, so walking
    # the creation order backwards pushes each count to its suffix before that
    # suffix is visited. No sort by length is needed.
    for node in reversed(tree):
        link = node.suffix_link
        if link is not None and link is not node:
            link.count += node.count

    return tree

def compute_hash_counts(tree):
    """Sum the ``count`` of every positive-length node per ``(h1, h2)`` pair.

    Args:
        tree: Iterable of nodes exposing ``len``, ``h1``, ``h2`` and ``count``.

    Returns:
        A dict mapping each ``(h1, h2)`` tuple to the total ``count`` of the
        nodes carrying that pair. Nodes whose ``len`` is not positive are
        skipped.
    """
    totals = {}
    for entry in tree:
        if entry.len <= 0:
            continue
        signature = (entry.h1, entry.h2)
        if signature in totals:
            totals[signature] += entry.count
        else:
            totals[signature] = entry.count
    return totals

# Two moduli and one base for the (h1, h2) hash pair used to match palindromes
# between the two input strings.
MOD1 = 10**9 + 7
MOD2 = 10**18 + 3
BASE = 911382629

s = input().strip()
t = input().strip()

# Only the hash -> occurrence tables are needed afterwards. Passing each tree
# straight into compute_hash_counts keeps it from being bound to a module-level
# name, so the garbage collector can reclaim the first tree before the second
# one is built, which lowers peak memory.
hash_s = compute_hash_counts(build_eertree(s, BASE, MOD1, MOD2))
hash_t = compute_hash_counts(build_eertree(t, BASE, MOD1, MOD2))

# A palindrome occurring a times in s and b times in t contributes a * b pairs.
result = sum(
    occurrences * hash_t.get(key, 0)
    for key, occurrences in hash_s.items()
)

print(result)
0