結果

問題 No.1526 Sum of Mex 2
ユーザー Rheo Tommy
提出日時 2021-04-30 17:18:59
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 776 ms / 3,000 ms
コード長 8,443 bytes
コンパイル時間 218 ms
コンパイル使用メモリ 81,972 KB
実行使用メモリ 117,104 KB
最終ジャッジ日時 2024-04-26 09:38:21
合計ジャッジ時間 12,158 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 36 ms
55,524 KB
testcase_01 AC 35 ms
54,532 KB
testcase_02 AC 41 ms
55,464 KB
testcase_03 AC 81 ms
76,264 KB
testcase_04 AC 45 ms
62,860 KB
testcase_05 AC 35 ms
55,004 KB
testcase_06 AC 83 ms
76,492 KB
testcase_07 AC 50 ms
66,040 KB
testcase_08 AC 87 ms
76,188 KB
testcase_09 AC 55 ms
67,540 KB
testcase_10 AC 77 ms
73,964 KB
testcase_11 AC 79 ms
76,772 KB
testcase_12 AC 88 ms
76,480 KB
testcase_13 AC 191 ms
83,476 KB
testcase_14 AC 212 ms
87,004 KB
testcase_15 AC 211 ms
87,352 KB
testcase_16 AC 449 ms
113,612 KB
testcase_17 AC 364 ms
109,532 KB
testcase_18 AC 169 ms
81,772 KB
testcase_19 AC 181 ms
80,680 KB
testcase_20 AC 290 ms
97,000 KB
testcase_21 AC 421 ms
113,324 KB
testcase_22 AC 313 ms
96,892 KB
testcase_23 AC 420 ms
113,684 KB
testcase_24 AC 450 ms
115,188 KB
testcase_25 AC 433 ms
114,132 KB
testcase_26 AC 468 ms
115,100 KB
testcase_27 AC 471 ms
114,712 KB
testcase_28 AC 451 ms
114,700 KB
testcase_29 AC 448 ms
114,640 KB
testcase_30 AC 477 ms
115,436 KB
testcase_31 AC 456 ms
114,916 KB
testcase_32 AC 472 ms
113,408 KB
testcase_33 AC 776 ms
117,104 KB
evil_largest AC 1,924 ms
225,720 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# Sentinels: also used as "no value" / "no pending assignment" markers, so
# every stored value must stay strictly between them.
pINF = 10 ** 18
nINF = -10 ** 18


class SegmentTreeBeats:
    """Segment tree ("Segment Tree Beats" style) over ``n`` integer slots.

    All ranges are half-open ``[a, b)``.  Supported operations are range
    chmax / chmin / add / assign and range max / min / sum queries.

    Per-node arrays (heap layout, root at 1, leaf ``i`` at ``size + i``):
        fmax, smax, maxc -- largest value, strictly second largest value
                            (``nINF`` if there is none), count of the largest
        fmin, smin, minc -- the mirror image for the smallest value
        sum              -- sum of the values below the node
        add              -- pending lazy addition for the children
        upd              -- pending lazy assignment (``pINF`` = none)
        lt, rt           -- index range ``[lt, rt)`` covered by the node

    ``up`` and ``down`` are scratch stacks used by the iterative traversals.

    The leaf level is padded to a power of two.  Padding leaves keep the
    identity values (``fmax = nINF``, ``fmin = pINF``, counts and sum 0), so
    chmax / chmin and the queries ignore them.  ``range_add`` and
    ``range_update`` depend on node widths and therefore clip their range to
    ``[0, n)`` so that padding is never counted or modified.
    """

    def __init__(self, n):
        """Allocate a tree for ``n`` slots; all slots start out empty."""
        self.n = n
        self.log = (n - 1).bit_length()
        self.size = 1 << self.log
        self.fmax = [nINF] * (2 * self.size)
        self.fmin = [pINF] * (2 * self.size)
        self.smax = [nINF] * (2 * self.size)
        self.smin = [pINF] * (2 * self.size)
        self.maxc = [0] * (2 * self.size)
        self.minc = [0] * (2 * self.size)
        self.sum = [0] * (2 * self.size)
        self.add = [0] * (2 * self.size)
        self.upd = [pINF] * (2 * self.size)
        self.up = []
        self.down = []
        self.lt = [0] * (2 * self.size)
        self.rt = [0] * (2 * self.size)
        for i in range(self.size):
            self.lt[self.size + i] = i
            self.rt[self.size + i] = i + 1
        for i in range(self.size)[::-1]:
            self.lt[i] = self.lt[i << 1]
            self.rt[i] = self.rt[(i << 1) + 1]

    def build(self, arr):
        """Load ``arr`` into slots ``0 .. len(arr) - 1`` and rebuild all nodes."""
        for i, a in enumerate(arr):
            self.fmax[self.size + i] = a
            self.fmin[self.size + i] = a
            self.maxc[self.size + i] = 1
            self.minc[self.size + i] = 1
            self.sum[self.size + i] = a
        for i in range(1, self.size)[::-1]:
            self.merge(i)

    def _clamp(self, a, b):
        """Clip ``[a, b)`` to the real index range ``[0, n)``."""
        return max(a, 0), min(b, self.n)

    def merge(self, k):
        """Recompute the aggregates of node ``k`` from its two children."""
        lc = 2 * k
        rc = lc + 1
        self.sum[k] = self.sum[lc] + self.sum[rc]
        if self.fmax[lc] < self.fmax[rc]:
            self.fmax[k] = self.fmax[rc]
            self.maxc[k] = self.maxc[rc]
            self.smax[k] = max(self.fmax[lc], self.smax[rc])
        elif self.fmax[lc] > self.fmax[rc]:
            self.fmax[k] = self.fmax[lc]
            self.maxc[k] = self.maxc[lc]
            self.smax[k] = max(self.smax[lc], self.fmax[rc])
        else:
            self.fmax[k] = self.fmax[lc]
            self.maxc[k] = self.maxc[lc] + self.maxc[rc]
            self.smax[k] = max(self.smax[lc], self.smax[rc])
        if self.fmin[lc] > self.fmin[rc]:
            self.fmin[k] = self.fmin[rc]
            self.minc[k] = self.minc[rc]
            self.smin[k] = min(self.fmin[lc], self.smin[rc])
        elif self.fmin[lc] < self.fmin[rc]:
            self.fmin[k] = self.fmin[lc]
            self.minc[k] = self.minc[lc]
            self.smin[k] = min(self.smin[lc], self.fmin[rc])
        else:
            self.fmin[k] = self.fmin[lc]
            self.minc[k] = self.minc[lc] + self.minc[rc]
            self.smin[k] = min(self.smin[lc], self.smin[rc])

    def propagate(self, k):
        """Push the pending changes of node ``k`` down to its children."""
        if self.size <= k:
            return  # leaves have no children to push to
        if self.upd[k] != pINF:
            # A pending assignment overrides everything below it.
            self.update_(2 * k, self.upd[k])
            self.update_(2 * k + 1, self.upd[k])
            self.upd[k] = pINF
            return
        if self.add[k]:
            self.add_(2 * k, self.add[k])
            self.add_(2 * k + 1, self.add[k])
            self.add[k] = 0
        # chmax / chmin carry no explicit tag: a child whose extreme lies
        # outside the parent's [fmin, fmax] is stale and gets pulled inside.
        if self.fmax[k] < self.fmax[2 * k]:
            self.chmax_(2 * k, self.fmax[k])
        if self.fmin[2 * k] < self.fmin[k]:
            self.chmin_(2 * k, self.fmin[k])
        if self.fmax[k] < self.fmax[2 * k + 1]:
            self.chmax_(2 * k + 1, self.fmax[k])
        if self.fmin[2 * k + 1] < self.fmin[k]:
            self.chmin_(2 * k + 1, self.fmin[k])

    def up_merge(self):
        """Re-merge every node recorded in ``up``, most recently visited first."""
        while self.up:
            self.merge(self.up.pop())

    def down_propagate(self, k):
        """Propagate node ``k`` and schedule both children for a visit."""
        self.propagate(k)
        self.down.append(2 * k)
        self.down.append(2 * k + 1)

    def update_(self, k, x):
        """Node-level assignment: every slot below ``k`` becomes ``x``."""
        width = self.rt[k] - self.lt[k]
        self.fmax[k] = x
        self.smax[k] = nINF
        self.fmin[k] = x
        self.smin[k] = pINF
        self.maxc[k] = width
        self.minc[k] = width
        self.sum[k] = x * width
        self.add[k] = 0
        self.upd[k] = x

    def add_(self, k, x):
        """Node-level addition: every slot below ``k`` grows by ``x``."""
        self.fmax[k] += x
        if self.smax[k] != nINF:
            self.smax[k] += x
        self.fmin[k] += x
        if self.smin[k] != pINF:
            self.smin[k] += x
        self.sum[k] += x * (self.rt[k] - self.lt[k])
        if self.upd[k] != pINF:
            # Fold the addition into the pending assignment.
            self.upd[k] += x
        else:
            self.add[k] += x

    def chmax_(self, k, x):
        """Lower the maximum of node ``k`` to ``x``.

        Despite its name this is the node-level step of ``range_chmin``; the
        callers only invoke it with ``smax[k] < x < fmax[k]``.
        """
        self.sum[k] += (x - self.fmax[k]) * self.maxc[k]
        if self.fmax[k] == self.fmin[k]:
            # single distinct value: the minimum moves as well
            self.fmax[k] = x
            self.fmin[k] = x
        elif self.fmax[k] == self.smin[k]:
            # two distinct values: the old maximum was the second minimum
            self.fmax[k] = x
            self.smin[k] = x
        else:
            self.fmax[k] = x
        if self.upd[k] != pINF and x < self.upd[k]:
            self.upd[k] = x

    def chmin_(self, k, x):
        """Raise the minimum of node ``k`` to ``x``.

        Despite its name this is the node-level step of ``range_chmax``; the
        callers only invoke it with ``fmin[k] < x < smin[k]``.
        """
        self.sum[k] += (x - self.fmin[k]) * self.minc[k]
        if self.fmin[k] == self.fmax[k]:
            # single distinct value: the maximum moves as well
            self.fmin[k] = x
            self.fmax[k] = x
        elif self.fmin[k] == self.smax[k]:
            # two distinct values: the old minimum was the second maximum
            self.fmin[k] = x
            self.smax[k] = x
        else:
            self.fmin[k] = x
        if self.upd[k] != pINF and self.upd[k] < x:
            self.upd[k] = x

    def range_chmax(self, a, b, x):
        """Set ``v = max(v, x)`` for every slot in ``[a, b)``."""
        self.down.append(1)
        while self.down:
            k = self.down.pop()
            # Disjoint, or nothing below k is smaller than x.
            if b <= self.lt[k] or self.rt[k] <= a or x <= self.fmin[k]:
                continue
            # Only the minimum is affected: handle it at this node.
            if a <= self.lt[k] and self.rt[k] <= b and x < self.smin[k]:
                self.chmin_(k, x)
                continue
            self.down_propagate(k)
            self.up.append(k)
        self.up_merge()

    def range_chmin(self, a, b, x):
        """Set ``v = min(v, x)`` for every slot in ``[a, b)``."""
        self.down.append(1)
        while self.down:
            k = self.down.pop()
            # Disjoint, or nothing below k is larger than x.
            if b <= self.lt[k] or self.rt[k] <= a or self.fmax[k] <= x:
                continue
            # Only the maximum is affected: handle it at this node.
            if a <= self.lt[k] and self.rt[k] <= b and self.smax[k] < x:
                self.chmax_(k, x)
                continue
            self.down_propagate(k)
            self.up.append(k)
        self.up_merge()

    def range_add(self, a, b, x):
        """Add ``x`` to every slot in ``[a, b)`` (range clipped to ``[0, n)``)."""
        # Without clipping, b > n could fully cover a node that contains
        # padding leaves and the padding would be counted in the sum.
        a, b = self._clamp(a, b)
        self.down.append(1)
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                self.add_(k, x)
                continue
            self.down_propagate(k)
            self.up.append(k)
        self.up_merge()

    def range_update(self, a, b, x):
        """Assign ``x`` to every slot in ``[a, b)`` (range clipped to ``[0, n)``)."""
        # Same reason as in range_add: keep padding leaves untouched.
        a, b = self._clamp(a, b)
        self.down.append(1)
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                self.update_(k, x)
                continue
            self.down_propagate(k)
            self.up.append(k)
        self.up_merge()

    def get_max(self, a, b):
        """Return the maximum over ``[a, b)``, or ``nINF`` if the range is empty."""
        self.down.append(1)
        v = nINF
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                v = max(v, self.fmax[k])
                continue
            self.down_propagate(k)
        return v

    def get_min(self, a, b):
        """Return the minimum over ``[a, b)``, or ``pINF`` if the range is empty."""
        self.down.append(1)
        v = pINF
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                v = min(v, self.fmin[k])
                continue
            self.down_propagate(k)
        return v

    def get_sum(self, a, b):
        """Return the sum over ``[a, b)`` (0 if the range is empty)."""
        self.down.append(1)
        v = 0
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                v += self.sum[k]
                continue
            self.down_propagate(k)
        return v


if __name__ == '__main__':
    # Input: n on the first line, then the n integers of a.
    n = int(input())
    a = list(map(int, input().split()))

    # occ[v] holds the indices i with a[i] == v together with the sentinel n
    # (value 0 additionally gets index 0).  Each list is kept in descending
    # order so that pop() always returns the smallest index still stored.
    occ = [[n] for _ in range(n + 1)]
    occ[0].append(0)
    for idx in range(n):
        occ[a[idx]].append(idx)
    for bucket in occ:
        bucket.sort(reverse=True)

    tree = SegmentTreeBeats(n + 1)
    tree.build([0] * (n + 1))
    # Initial state: slots v.. are raised to the smallest stored index of v.
    for v in range(n + 1):
        tree.range_chmax(v, n + 2, occ[v].pop())

    total = 0
    for idx in range(n):
        # Contribution for this idx, taken before the tree is advanced.
        total += n * (n + 1) - tree.get_sum(0, n + 2)
        value = a[idx]
        # Move on to the next stored index of a[idx] ...
        tree.range_chmax(value, n + 2, occ[value].pop())
        # ... and lift every slot to at least idx + 1.
        tree.range_chmax(0, n + 2, idx + 1)

    print(total)
0