結果

問題 No.1031 いたずら好きなお姉ちゃん
ユーザー qwewe
提出日時 2025-05-14 13:21:37
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,191 bytes
コンパイル時間 174 ms
コンパイル使用メモリ 82,088 KB
実行使用メモリ 79,336 KB
最終ジャッジ日時 2025-05-14 13:24:10
合計ジャッジ時間 14,667 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 8 TLE * 1 -- * 44
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def solve():
    N = int(sys.stdin.readline())
    p = list(map(int, sys.stdin.readline().split()))

    if N <= 1:
        print(0)
        return

    # Precompute logs for sparse table query
    # logs[i] will store floor(log2(i))
    logs = [0] * (N + 1)
    for i in range(2, N + 1):
        logs[i] = logs[i // 2] + 1
    
    # Max k for sparse table: logs[N]
    # Number of levels in sparse table: logs[N] + 1
    num_levels = logs[N] + 1

    st_min_idx = [[0] * N for _ in range(num_levels)]
    st_max_idx = [[0] * N for _ in range(num_levels)]

    for i in range(N):
        st_min_idx[0][i] = i
        st_max_idx[0][i] = i

    for k in range(1, num_levels):
        # For intervals of length 2^k
        # The second half starts at i + 2^(k-1)
        # The interval ends at i + 2^k - 1
        # So, i + 2^k - 1 < N  => i < N - 2^k + 1
        for i in range(N - (1 << k) + 1):
            # Min
            left_half_min_idx = st_min_idx[k-1][i]
            right_half_min_idx = st_min_idx[k-1][i + (1 << (k-1))]
            if p[left_half_min_idx] <= p[right_half_min_idx]: # In case of tie in value, smaller index is preferred (not strictly necessary for permutations)
                st_min_idx[k][i] = left_half_min_idx
            else:
                st_min_idx[k][i] = right_half_min_idx
            
            # Max
            left_half_max_idx = st_max_idx[k-1][i]
            right_half_max_idx = st_max_idx[k-1][i + (1 << (k-1))]
            if p[left_half_max_idx] >= p[right_half_max_idx]: # In case of tie, smaller index
                st_max_idx[k][i] = left_half_max_idx
            else:
                st_max_idx[k][i] = right_half_max_idx
    
    def query_min_idx_func(l, r): # inclusive [l,r]
        length = r - l + 1
        k = logs[length] # floor(log2(length))
        
        # Compare p[l ... l + 2^k - 1] and p[r - 2^k + 1 ... r]
        idx1 = st_min_idx[k][l]
        idx2 = st_min_idx[k][r - (1 << k) + 1]
        
        if p[idx1] <= p[idx2]:
            return idx1
        else:
            return idx2

    def query_max_idx_func(l, r): # inclusive [l,r]
        length = r - l + 1
        k = logs[length]

        idx1 = st_max_idx[k][l]
        idx2 = st_max_idx[k][r - (1 << k) + 1]

        if p[idx1] >= p[idx2]:
            return idx1
        else:
            return idx2

    distinct_swapped_idx_pairs = set()

    for l_idx in range(N):
        for r_idx in range(l_idx + 1, N): # ensures length r_idx - l_idx + 1 >= 2
            min_idx_in_range = query_min_idx_func(l_idx, r_idx)
            max_idx_in_range = query_max_idx_func(l_idx, r_idx)
            
            # Elements are distinct in a permutation, and length >= 2,
            # so min_idx_in_range will not be equal to max_idx_in_range.
            
            # Store canonical form of pair (sorted tuple)
            if min_idx_in_range < max_idx_in_range:
                pair = (min_idx_in_range, max_idx_in_range)
            else:
                pair = (max_idx_in_range, min_idx_in_range)
            distinct_swapped_idx_pairs.add(pair)
            
    sys.stdout.write(str(len(distinct_swapped_idx_pairs)) + "\n")

solve()
0