結果

問題 No.465 PPPPPPPPPPPPPPPPAPPPPPPPP
ユーザー qwewe
提出日時 2025-05-14 13:25:30
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 4,161 bytes
コンパイル時間 328 ms
コンパイル使用メモリ 82,368 KB
実行使用メモリ 110,920 KB
最終ジャッジ日時 2025-05-14 13:26:27
合計ジャッジ時間 4,125 ms
ジャッジサーバーID (参考情報) judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 2 TLE * 1 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def solve():
    """Read S from stdin and print the number of ways to write S = P1 P2 A P3.

    P1, P2 and P3 must be non-empty palindromes; A is any non-empty string.

    The count factors over the start index j of A (2 <= j <= N-2):
        ans = sum_j two_pal[j] * tail[j]
    where two_pal[j] is the number of ways to split S[:j] into two non-empty
    palindromes, and tail[j] is the number of palindromic suffixes of S that
    start strictly after j (so that A = S[j:k] is non-empty).

    A previous version enumerated every (len(P1P2), len(P1)) pair with hashed
    palindrome checks, which is O(N^2) and timed out. Here a palindromic tree
    (eertree) with series links computes every two_pal[p] in O(log N)
    amortised steps, for O(N log N) overall. No hashing is used, so there is
    no collision risk.
    """
    S = sys.stdin.readline().strip()
    N = len(S)

    # Compact integer codes so eertree transitions can live in one flat dict
    # keyed by node * sigma + code.
    char_id = {ch: k for k, ch in enumerate(sorted(set(S)))}
    sigma = max(len(char_id), 1)
    code = [char_id[ch] for ch in S]

    # Eertree node layout: 0 = imaginary root (length -1), 1 = empty root
    # (length 0), real palindromes get ids 2, 3, ... (at most N of them).
    size = N + 2
    length = [0] * size
    link = [0] * size    # node of the longest proper palindromic suffix
    diff = [0] * size    # length[v] - length[link[v]]; 0 for both roots
    slink = [0] * size   # first link-ancestor whose diff differs from diff[v]
    series = [0] * size  # per-node partial sums of pre_pal over a diff-series
    length[0] = -1
    trans = {}
    nodes = 2
    last = 1  # node of the longest palindromic suffix of the current prefix

    # pre_pal[p] == 1 iff S[:p] is a non-empty palindrome. pre_pal[0] stays 0
    # so the whole prefix is never counted as "P2 with an empty P1".
    pre_pal = [0] * (N + 1)
    # two_pal[p] = number of ways to write S[:p] as P1 P2, both non-empty.
    two_pal = [0] * (N + 1)

    for i in range(N):
        c = code[i]

        # Longest palindromic suffix X of S[:i] with S[i-1-|X|] == S[i].
        # The imaginary root (length -1) always matches, so this terminates.
        cur = last
        while True:
            j = i - 1 - length[cur]
            if j >= 0 and code[j] == c:
                break
            cur = link[cur]

        key = cur * sigma + c
        v = trans.get(key)
        if v is None:
            v = nodes
            nodes += 1
            length[v] = length[cur] + 2
            if length[v] == 1:
                link[v] = 1  # single characters link to the empty root
            else:
                u = link[cur]
                while True:
                    j = i - 1 - length[u]
                    if j >= 0 and code[j] == c:
                        break
                    u = link[u]
                link[v] = trans[u * sigma + c]
            trans[key] = v
            par = link[v]
            d = length[v] - length[par]
            diff[v] = d
            # Suffix palindromes with the same diff form an arithmetic
            # series; slink jumps straight past the whole series.
            slink[v] = slink[par] if d == diff[par] else par
        last = v

        p = i + 1
        if length[v] == p:
            pre_pal[p] = 1

        # two_pal[p] = sum of pre_pal[p - m] over palindromic suffix lengths m.
        # series[u] (reused from position p - diff[u] via link[u]) covers all
        # but the shortest member of u's series; add that one explicitly.
        total = 0
        u = v
        while length[u] > 0:
            par = link[u]
            acc = pre_pal[p - length[slink[u]] - diff[u]]
            if diff[u] == diff[par]:
                acc += series[par]
            series[u] = acc
            total += acc
            u = slink[u]
        two_pal[p] = total

    # suf_pal[k] == 1 iff S[k:] is a non-empty palindrome: follow suffix
    # links from the longest palindromic suffix of the whole string.
    suf_pal = [0] * (N + 1)
    u = last
    while length[u] > 0:
        suf_pal[N - length[u]] = 1
        u = link[u]

    ans = 0
    tail = 0  # number of k in (j, N-1] with S[k:] a palindrome
    for j in range(N - 2, 1, -1):
        tail += suf_pal[j + 1]
        ans += two_pal[j] * tail

    sys.stdout.write(str(ans) + "\n")

# Run only when executed as a script, so importing this module does not
# block waiting on stdin.
if __name__ == "__main__":
    solve()