結果

問題 No.1079 まお
ユーザー gew1fw
提出日時 2025-06-12 20:01:27
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 2,843 bytes
コンパイル時間 277 ms
コンパイル使用メモリ 82,048 KB
実行使用メモリ 124,636 KB
最終ジャッジ日時 2025-06-12 20:05:08
合計ジャッジ時間 10,143 ms
ジャッジサーバーID (参考情報) judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 26 TLE * 1 -- * 3
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from bisect import bisect_left, bisect_right
from collections import defaultdict

def _unique_min_windows(A):
    """Return lists ``lo``, ``hi`` bounding where each element is the unique minimum.

    For every index ``m``, ``[lo[m], hi[m]]`` is the widest window containing
    ``m`` in which every other element is strictly greater than ``A[m]``.
    ``lo[m] - 1`` is the nearest previous index whose value is ``<= A[m]``
    (or -1 if none). ``hi[m] + 1`` is the nearest following such index (or
    ``len(A)`` if none).
    """
    n = len(A)
    lo = [0] * n
    hi = [n - 1] * n

    stack = []
    for i, a in enumerate(A):
        # Drop strictly larger values. The survivor is the nearest previous
        # value <= a, i.e. the first element that would tie or undercut A[i].
        while stack and A[stack[-1]] > a:
            stack.pop()
        if stack:
            lo[i] = stack[-1] + 1
        stack.append(i)

    stack = []
    for i in range(n - 1, -1, -1):
        a = A[i]
        while stack and A[stack[-1]] > a:
            stack.pop()
        if stack:
            hi[i] = stack[-1] - 1
        stack.append(i)

    return lo, hi


def _weighted_pair_sum(A, K):
    """Return the sum of ``r - l + 1`` over qualifying pairs ``l <= r``.

    A pair qualifies when ``A[l] + A[r] == K`` and the minimum of ``A[l..r]``
    occurs exactly once in that window.

    Each qualifying window is counted once, at the position ``m`` of its
    unique minimum, with ``l`` in ``[lo[m], m]`` and ``r`` in ``[m, hi[m]]``.
    Only the shorter side is iterated. The other side is answered with binary
    searches on per-value sorted positions plus prefix sums of those
    positions. This is small-to-large over the Cartesian tree, so the total
    is about O(n log^2 n), versus O(n^2) for scanning both sides.
    """
    n = len(A)
    lo, hi = _unique_min_windows(A)

    # Sorted (ascending, because of enumerate order) positions of each value.
    positions = defaultdict(list)
    for i, a in enumerate(A):
        positions[a].append(i)

    # prefix[v][k] holds the sum of the first k positions of value v.
    prefix = {}
    for v, idxs in positions.items():
        acc = [0]
        for p in idxs:
            acc.append(acc[-1] + p)
        prefix[v] = acc

    def count_and_sum(value, left, right):
        """Return (count, sum) of positions p of ``value`` with left <= p <= right."""
        idxs = positions.get(value)  # .get avoids inserting into the defaultdict
        if not idxs:
            return 0, 0
        i = bisect_left(idxs, left)
        j = bisect_right(idxs, right)
        acc = prefix[value]
        return j - i, acc[j] - acc[i]

    total = 0
    for m in range(n):
        l_start, r_end = lo[m], hi[m]
        # Both ranges contain m, but only one side is iterated, so the pair
        # (m, m) is counted at most once.
        if m - l_start <= r_end - m:
            for l in range(l_start, m + 1):
                cnt, s = count_and_sum(K - A[l], m, r_end)
                if cnt:
                    # Sum over the matching r of (r - l + 1).
                    total += s - (l - 1) * cnt
        else:
            for r in range(m, r_end + 1):
                cnt, s = count_and_sum(K - A[r], l_start, m)
                if cnt:
                    # Sum over the matching l of (r - l + 1).
                    total += (r + 1) * cnt - s
    return total


def main():
    """Read ``N K`` and ``A_1 .. A_N`` from stdin and print the answer.

    The answer is the sum of ``r - l + 1`` over index pairs ``l <= r`` with
    ``A[l] + A[r] == K`` whose window ``A[l..r]`` has a unique minimum.
    Input is read token-wise, so it does not depend on the line layout.
    """
    data = sys.stdin.read().split()
    n, K = int(data[0]), int(data[1])
    A = list(map(int, data[2:2 + n]))
    print(_weighted_pair_sum(A, K))

if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()
0