結果
| 問題 | No.1031 いたずら好きなお姉ちゃん |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-04-16 00:04:35 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 603 ms / 3,500 ms |
| コード長 | 2,593 bytes |
| コンパイル時間 | 287 ms |
| コンパイル使用メモリ | 82,316 KB |
| 実行使用メモリ | 157,692 KB |
| 最終ジャッジ日時 | 2025-04-16 00:06:53 |
| 合計ジャッジ時間 | 22,079 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 53 |
ソースコード
import bisect
class SegmentTree:
    """Merge-sort tree over a fixed list of comparable values.

    Every node stores the sorted multiset of the values in its segment, so
    ``query_less`` can count values below a threshold in any index range by
    binary-searching O(log n) node lists.
    """

    def __init__(self, data):
        """Build the tree over ``data`` (leaves padded to a power of two)."""
        self.n = len(data)
        size = 1
        while size < self.n:
            size *= 2
        self.size = size
        # Leaves occupy tree[size:2*size]; padding leaves are empty lists.
        leaves = [[value] for value in data]
        leaves += [[] for _ in range(size - self.n)]
        # Internal nodes tree[1:size] are filled bottom-up below; tree[0] is unused.
        self.tree = [[] for _ in range(size)] + leaves
        for node in reversed(range(1, size)):
            self.tree[node] = sorted(self.tree[2 * node] + self.tree[2 * node + 1])

    def query_less(self, l, r, x):
        """Return how many of ``data[l..r]`` (both ends inclusive) are ``< x``."""
        tree = self.tree
        count = 0
        lo = l + self.size
        hi = r + self.size + 1  # half-open node range [lo, hi)
        while lo < hi:
            if lo & 1:
                count += bisect.bisect_left(tree[lo], x)
                lo += 1
            if hi & 1:
                hi -= 1
                count += bisect.bisect_left(tree[hi], x)
            lo //= 2
            hi //= 2
        return count
def _nearest_strict(p, order, default, greater):
    """Monotonic-stack scan for the nearest strictly greater/smaller element.

    Visits the indices of ``p`` in the sequence given by ``order``.  For each
    visited index ``i`` it records the closest previously visited index ``k``
    with ``p[k] > p[i]`` (when ``greater`` is true) or ``p[k] < p[i]``
    (otherwise), or ``default`` when no such index exists.

    Passing ``range(n - 1, -1, -1)`` gives "next to the right"; passing
    ``range(n)`` gives "previous to the left".

    Returns:
        A list of length ``len(p)``, indexed by position in ``p``.
    """
    result = [default] * len(p)
    stack = []
    for i in order:
        v = p[i]
        if greater:
            # Discard candidates that are not strictly greater than v.
            while stack and p[stack[-1]] <= v:
                stack.pop()
        else:
            # Discard candidates that are not strictly smaller than v.
            while stack and p[stack[-1]] >= v:
                stack.pop()
        result[i] = stack[-1] if stack else default
        stack.append(i)
    return result


def main():
    """Read ``n`` and ``p_1..p_n`` from stdin and print the pair count.

    Counts pairs ``i < j`` such that ``p[i]`` and ``p[j]`` are the maximum
    and the minimum (in either order) of the slice ``p[i..j]``.  For each
    ``i`` the valid ``j`` form a contiguous range bounded by the next
    greater/smaller element, and within that range a merge-sort tree counts
    the ``j`` whose previous smaller/greater element lies before ``i``.
    """
    import sys

    # Local name `tokens` avoids shadowing the builtin `input`.
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    p = list(map(int, tokens[1:n + 1]))

    right_to_left = range(n - 1, -1, -1)
    left_to_right = range(n)
    r_max = _nearest_strict(p, right_to_left, n, greater=True)    # next greater to the right
    r_min = _nearest_strict(p, right_to_left, n, greater=False)   # next smaller to the right
    l_min = _nearest_strict(p, left_to_right, -1, greater=False)  # previous smaller to the left
    l_max = _nearest_strict(p, left_to_right, -1, greater=True)   # previous greater to the left

    st_l_min = SegmentTree(l_min)
    st_l_max = SegmentTree(l_max)

    ans = 0
    for i in range(n):
        # Case 1: p[i] is the max of p[i..j] (j < r_max[i]) and p[j] is the
        # min (no strictly smaller value in p[i..j-1], i.e. l_min[j] < i).
        hi = r_max[i] - 1
        if i + 1 <= hi:
            ans += st_l_min.query_less(i + 1, hi, i)
        # Case 2: p[i] is the min (j < r_min[i]) and p[j] is the max
        # (l_max[j] < i).
        hi = r_min[i] - 1
        if i + 1 <= hi:
            ans += st_l_max.query_less(i + 1, hi, i)
    print(ans)


if __name__ == "__main__":
    main()
lam6er