結果
問題 |
No.1526 Sum of Mex 2
|
ユーザー |
![]() |
提出日時 | 2025-06-12 15:14:36 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 2,752 bytes |
コンパイル時間 | 217 ms |
コンパイル使用メモリ | 82,452 KB |
実行使用メモリ | 116,004 KB |
最終ジャッジ日時 | 2025-06-12 15:14:49 |
合計ジャッジ時間 | 8,296 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 30 TLE * 1 -- * 1 |
ソースコード
import bisect def main(): import sys input = sys.stdin.read().split() n = int(input[0]) A = list(map(int, input[1:n+1])) # Compute the mex of the entire array present = set() mex = 1 for num in A: present.add(num) while mex in present: mex += 1 # Precompute the positions for each number from collections import defaultdict pos = defaultdict(list) for idx, num in enumerate(A): pos[num].append(idx) total = 0 for x in range(1, mex + 1): # Get all positions of x, sorted x_pos = pos.get(x, []) regions = [] prev = -1 for p in x_pos: if prev + 1 <= p - 1: regions.append((prev + 1, p - 1)) prev = p if prev != n - 1: regions.append((prev + 1, n - 1)) # Check if x-1 is present in the entire array required = set(range(1, x)) if x > 1 and any(y not in pos or len(pos[y]) == 0 for y in required): continue cnt = 0 for (s, e) in regions: # Check if this region contains all required elements valid = True if x > 1: for y in required: ys = pos[y] idx = bisect.bisect_left(ys, s) if idx >= len(ys) or ys[idx] > e: valid = False break if not valid: continue # Now compute the number of intervals in [s, e] that contain all required elements # Using sliding window if x == 1: # All intervals in this region are valid L = e - s + 1 cnt += L * (L + 1) // 2 continue freq = {} required_count = len(required) current_count = 0 res = 0 left = s for right in range(s, e + 1): num = A[right] if num in required: if num not in freq or freq[num] == 0: current_count += 1 freq[num] = freq.get(num, 0) + 1 while current_count == required_count: res += e - right + 1 # Move left left_num = A[left] if left_num in required: freq[left_num] -= 1 if freq[left_num] == 0: current_count -= 1 left += 1 cnt += res total += x * cnt print(total) if __name__ == "__main__": main()