結果
| 問題 | 
                            No.1526 Sum of Mex 2
                             | 
                    
| ユーザー | 
                             gew1fw
                         | 
                    
| 提出日時 | 2025-06-12 18:48:11 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | 
                             
                                WA
                                 
                             
                            
                         | 
                    
| 実行時間 | - | 
| コード長 | 1,973 bytes | 
| コンパイル時間 | 281 ms | 
| コンパイル使用メモリ | 82,060 KB | 
| 実行使用メモリ | 114,816 KB | 
| 最終ジャッジ日時 | 2025-06-12 18:48:18 | 
| 合計ジャッジ時間 | 6,562 ms | 
| 
                            ジャッジサーバーID (参考情報)  | 
                        judge2 / judge3 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 8 WA * 2 TLE * 1 -- * 21 | 
ソースコード
def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    N = int(data[0])
    A = list(map(int, data[1:N+1]))
    
    max_x = max(A) + 2
    pos = {}
    for i, num in enumerate(A):
        if num not in pos:
            pos[num] = []
        pos[num].append(i)
    
    total = 0
    for x in range(1, max_x + 1):
        if x == 1:
            regions = []
            prev = -1
            for i in range(N):
                if A[i] == 1:
                    regions.append((prev + 1, i - 1))
                    prev = i
            regions.append((prev + 1, N - 1))
            
            count = 0
            for l, r in regions:
                if l > r:
                    continue
                k = r - l + 1
                count += k * (k + 1) // 2
            total += 1 * count
        else:
            if x not in pos:
                regions = [(0, N - 1)]
            else:
                regions = []
                prev = -1
                for p in pos[x]:
                    regions.append((prev + 1, p - 1))
                    prev = p
                regions.append((prev + 1, N - 1))
            
            count_x = 0
            for l, r in regions:
                if l > r:
                    continue
                required = set(range(1, x))
                left = l
                right = l
                current = set()
                while right <= r:
                    num = A[right]
                    if num in required:
                        current.add(num)
                    while len(current) == len(required) and left <= right:
                        count_x += (r - right + 1)
                        num_left = A[left]
                        if num_left in required:
                            current.remove(num_left)
                        left += 1
                    right += 1
            total += x * count_x
    print(total)
if __name__ == "__main__":
    main()
            
            
            
        
            
gew1fw