結果

問題 No.1526 Sum of Mex 2
ユーザー lam6er
提出日時 2025-04-15 23:21:02
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 5,226 bytes
コンパイル時間 489 ms
コンパイル使用メモリ 81,768 KB
実行使用メモリ 108,808 KB
最終ジャッジ日時 2025-04-15 23:23:06
合計ジャッジ時間 7,526 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 30 TLE * 1 -- * 1
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

def main():
    sys.setrecursionlimit(1 << 25)
    n, *rest = map(int, sys.stdin.read().split())
    a = rest[:n]

    pos = defaultdict(list)
    for i, x in enumerate(a):
        pos[x].append(i)

    max_x = 0
    while True:
        if (max_x + 1) in pos and len(pos[max_x + 1]) > 0:
            max_x += 1
        else:
            break

    total = 0

    # For x in 1..max_x + 1
    for x in range(1, max_x + 2):
        required = x - 1
        if required == 0:
            # All subarrays that do not contain 1
            # Which is the total subarrays minus those containing 1
            if 1 not in pos or len(pos[1]) == 0:
                count = n * (n + 1) // 2
                total += x * count
                continue
            else:
                # Split the array into segments not containing 1
                prev = -1
                cnt = 0
                for p in pos[1]:
                    left = prev + 1
                    right = p - 1
                    if left <= right:
                        length = right - left + 1
                        cnt += length * (length + 1) // 2
                    prev = p
                # After the last occurrence of 1
                left = prev + 1
                right = n - 1
                if left <= right:
                    length = right - left + 1
                    cnt += length * (length + 1) // 2
                total += x * cnt
            continue

        # Check if all 1..required are present
        valid = True
        for i in range(1, required + 1):
            if i not in pos or len(pos[i]) == 0:
                valid = False
                break
        if not valid:
            continue

        # Now, x is valid. Compute the segments where x is not present
        if x not in pos or len(pos[x]) == 0:
            # The entire array is the segment
            # Compute the number of subarrays containing all 1..required
            cnt = 0
            freq = defaultdict(int)
            left = 0
            missing = required
            current_missing = required
            have = set()
            for right in range(n):
                num = a[right]
                if 1 <= num <= required:
                    if num not in have:
                        current_missing -= 1
                        have.add(num)
                while current_missing == 0:
                    cnt += n - right
                    num_left = a[left]
                    if 1 <= num_left <= required:
                        have.remove(num_left)
                        current_missing += 1
                        # Check if there are more occurrences
                        found = False
                        for i in range(left + 1, right + 1):
                            if 1 <= a[i] <= required and a[i] == num_left:
                                found = True
                                break
                        if not found:
                            pass
                        else:
                            current_missing -= 1
                            have.add(num_left)
                    left += 1
            total += x * cnt
            continue

        # Split the array into segments where x is not present
        segments = []
        prev = -1
        for p in pos[x]:
            if prev + 1 <= p - 1:
                segments.append((prev + 1, p - 1))
            prev = p
        if prev + 1 <= n - 1:
            segments.append((prev + 1, n - 1))

        cnt = 0
        for seg_start, seg_end in segments:
            if seg_start > seg_end:
                continue
            # Compute the number of subarrays in this segment that contain all 1..required
            # Using sliding window
            freq = defaultdict(int)
            left = seg_start
            current_missing = required
            have = set()
            temp_cnt = 0
            for right in range(seg_start, seg_end + 1):
                num = a[right]
                if 1 <= num <= required:
                    if num not in have:
                        current_missing -= 1
                        have.add(num)
                # Shrink left as much as possible
                while current_missing == 0 and left <= right:
                    temp_cnt += seg_end - right + 1
                    num_left = a[left]
                    if 1 <= num_left <= required:
                        have.remove(num_left)
                        current_missing += 1
                        # Check if there are more occurrences in [left+1, right]
                        found = False
                        for i in range(left + 1, right + 1):
                            if 1 <= a[i] <= required and a[i] == num_left:
                                found = True
                                break
                        if found:
                            have.add(num_left)
                            current_missing -= 1
                        else:
                            pass
                    left += 1
            cnt += temp_cnt
        total += x * cnt

    print(total)

if __name__ == '__main__':
    main()
0