結果

問題 No.962 LCPs
ユーザー lam6er
提出日時 2025-03-20 20:25:39
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 125 ms / 2,000 ms
コード長 1,974 bytes
コンパイル時間 274 ms
コンパイル使用メモリ 82,640 KB
実行使用メモリ 114,692 KB
最終ジャッジ日時 2025-03-20 20:27:06
合計ジャッジ時間 6,182 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 64
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def main():
    n = int(sys.stdin.readline())
    s = [sys.stdin.readline().strip() for _ in range(n)]
    if n == 0:
        print(0)
        return
    sum1 = sum(len(x) for x in s)
    if n == 1:
        print(sum1)
        return
    m = n - 1
    h = []
    for i in range(m):
        a = s[i]
        b = s[i+1]
        min_len = min(len(a), len(b))
        lcp = 0
        while lcp < min_len and a[lcp] == b[lcp]:
            lcp += 1
        h.append(lcp)
    
    # Compute left boundaries
    left = [-1] * m
    stack = []
    for i in range(m):
        while stack and h[stack[-1]] >= h[i]:
            stack.pop()
        if stack:
            left[i] = stack[-1]
        else:
            left[i] = -1
        stack.append(i)
    
    # Compute right boundaries
    right = [m] * m
    stack = []
    for i in range(m-1, -1, -1):
        while stack and h[stack[-1]] > h[i]:
            stack.pop()
        if stack:
            right[i] = stack[-1]
        else:
            right[i] = m
        stack.append(i)
    
    sum2 = 0
    for k in range(m):
        current_h = h[k]
        l = left[k]
        r = right[k]
        cnt_left = k - l
        cnt_right = r - k
        
        # Calculate sum_y: sum of y from k to r-1 (inclusive)
        lower = k
        upper = r - 1
        num_terms = upper - lower + 1
        if num_terms <= 0:
            sum_y = 0
        else:
            sum_y = (lower + upper) * num_terms // 2
        
        # Calculate sum_x: sum of x from l+1 to k (inclusive)
        lower_x = l + 1
        upper_x = k
        num_terms_x = upper_x - lower_x + 1
        if num_terms_x <= 0:
            sum_x = 0
        else:
            sum_x = (lower_x + upper_x) * num_terms_x // 2
        
        term1 = cnt_left * sum_y
        term2 = cnt_right * sum_x
        term = (term1 - term2) + 2 * cnt_left * cnt_right
        sum2 += current_h * term
    print(sum1 + sum2)

if __name__ == "__main__":
    main()
0