結果

問題 No.263 Common Palindromes Extra
ユーザー qwewe
提出日時 2025-05-14 13:20:08
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 5,836 bytes
コンパイル時間 256 ms
コンパイル使用メモリ 82,180 KB
実行使用メモリ 283,176 KB
最終ジャッジ日時 2025-05-14 13:21:12
合計ジャッジ時間 6,427 ms
ジャッジサーバーID (参考情報) judge3 / judge1
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other AC * 5 MLE * 3 -- * 4
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import array
import gc

def _count_common_palindromes(s_str, t_str):
    """Return the sum over distinct palindromes p of occ(p, S) * occ(p, T).

    Equivalently: the number of pairs (occurrence in ``s_str``, occurrence
    in ``t_str``) of the same non-empty palindromic substring.

    A single palindromic tree (eertree) is built over both strings, so a
    palindrome occurring in both maps to the very same node.  No hashing
    (and hence no collision risk) and no sorting are needed.  Children are
    stored as per-node linked lists instead of a dense ``nodes * alphabet``
    transition table.  Memory is therefore O(len(s_str) + len(t_str))
    regardless of the alphabet, and any characters are accepted.
    """
    new_array = array.array
    total_len = len(s_str) + len(t_str)
    # Each string of length n has at most n distinct non-empty palindromes;
    # add the two roots.
    max_nodes = total_len + 2

    # Character buffer laid out as [-1] + S + [-1] + T.  Codes are ord()
    # values (>= 0), so a -1 slot never matches and no palindrome can extend
    # past the start of its own string.
    # ``array(...) * n`` allocates directly, without first materializing an
    # n-element Python list.
    text = new_array('i', [-1]) * (total_len + 2)

    length = new_array('i', [0]) * max_nodes
    link = new_array('i', [0]) * max_nodes          # suffix link
    first_child = new_array('i', [0]) * max_nodes   # 0 = no children
    next_sibling = new_array('i', [0]) * max_nodes  # 0 = end of sibling list
    edge_char = new_array('i', [0]) * max_nodes     # char on edge parent->node
    count_s = new_array('i', [0]) * max_nodes
    count_t = new_array('i', [0]) * max_nodes

    # Node 0: imaginary root (length -1, link 0).
    # Node 1: empty-string root (length 0, link 0).
    # Neither root is ever a child, so 0 doubles as "no node" in child lists.
    length[0] = -1
    node_count = 2
    pos = 0  # index in ``text`` of the most recently written character

    for counts, string in ((count_s, s_str), (count_t, t_str)):
        last = 1  # longest palindromic suffix of a fresh string: empty
        for ch in string:
            c = ord(ch)
            pos += 1
            text[pos] = c

            # Walk suffix links until ``cur`` can be wrapped as c + cur + c.
            cur = last
            while text[pos - length[cur] - 1] != c:
                cur = link[cur]

            # Find the child of ``cur`` whose edge is labelled c.
            node = first_child[cur]
            while node and edge_char[node] != c:
                node = next_sibling[node]

            if not node:
                node = node_count
                node_count += 1
                length[node] = length[cur] + 2
                if length[node] == 1:
                    link[node] = 1
                else:
                    # The longest proper palindromic suffix of the new node
                    # is c + X + c for the next extendable X on the chain;
                    # that palindrome occurred earlier in this same string,
                    # so its node already exists.
                    cand = link[cur]
                    while text[pos - length[cand] - 1] != c:
                        cand = link[cand]
                    suffix = first_child[cand]
                    while suffix and edge_char[suffix] != c:
                        suffix = next_sibling[suffix]
                    link[node] = suffix
                # Attach only after the link is computed (standard order).
                edge_char[node] = c
                next_sibling[node] = first_child[cur]
                first_child[cur] = node

            last = node
            counts[node] += 1
        pos += 1  # leave text[pos] == -1 as the separator before the next string

    # Suffix links always point to earlier-created nodes.  Visiting nodes in
    # reverse creation order therefore sees each node's final occurrence
    # counts before they are pushed up to its link target.
    total = 0
    for node in range(node_count - 1, 1, -1):
        in_s = count_s[node]
        in_t = count_t[node]
        total += in_s * in_t
        parent = link[node]
        count_s[parent] += in_s
        count_t[parent] += in_t
    return total


def solve():
    """Read S and T (one per line) from stdin and print the answer.

    The answer is the number of pairs of equal palindromic substrings, one
    taken from S and one from T.
    """
    s_str = sys.stdin.readline().strip()
    t_str = sys.stdin.readline().strip()
    sys.stdout.write(str(_count_common_palindromes(s_str, t_str)) + "\n")

# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    solve()