結果
問題 | No.1526 Sum of Mex 2 |
ユーザー | |
提出日時 | 2021-04-30 17:18:59 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 869 ms / 3,000 ms |
コード長 | 8,443 bytes |
コンパイル時間 | 361 ms |
コンパイル使用メモリ | 82,560 KB |
実行使用メモリ | 116,808 KB |
最終ジャッジ日時 | 2024-11-08 23:59:02 |
合計ジャッジ時間 | 14,243 ms |
ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 32 |
ソースコード
pINF = 10 ** 18  # "+infinity" sentinel; in upd it also means "no pending assignment"
nINF = -10 ** 18  # "-infinity" sentinel


class SegmentTreeBeats:
    """Segment Tree Beats over positions 0 .. n-1; all ranges are half-open [a, b).

    Updates: range_chmax (v = max(v, x)), range_chmin (v = min(v, x)),
    range_add (v += x), range_update (v = x).
    Queries: get_max, get_min, get_sum.

    For n >= 1 the tree has ``size`` leaves, the smallest power of two >= n.
    Leaves n .. size-1 are padding: fmax = nINF, fmin = pINF, sum = 0 and
    counts = 0, so they are neutral for every query. range_add/range_update
    clamp b to n so that padding is never assigned to or added to.

    Per-node arrays, indexed by node k (root 1, children 2k and 2k+1):
      fmax / fmin : largest / smallest value under k
      smax / smin : strictly second largest / smallest (nINF / pINF if none)
      maxc / minc : number of leaves equal to fmax / fmin
      sum         : sum of the values under k
      add         : pending addition not yet pushed to the children
      upd         : pending assignment not yet pushed (pINF if none)
      lt / rt     : k covers leaf indices [lt, rt)
    Traversals are iterative and use the explicit stacks ``down``
    (nodes still to visit) and ``up`` (nodes to re-merge, parents first).
    """

    def __init__(self, n):
        self.n = n
        self.log = (n - 1).bit_length()
        self.size = 1 << self.log
        m = 2 * self.size
        self.fmax = [nINF] * m
        self.fmin = [pINF] * m
        self.smax = [nINF] * m
        self.smin = [pINF] * m
        self.maxc = [0] * m
        self.minc = [0] * m
        self.sum = [0] * m
        self.add = [0] * m
        self.upd = [pINF] * m
        self.up = []
        self.down = []
        self.lt = [0] * m
        self.rt = [0] * m
        for i in range(self.size):
            self.lt[self.size + i] = i
            self.rt[self.size + i] = i + 1
        # Internal nodes span from their left child's start to their right
        # child's end (index 0 is unused; writing it is harmless).
        for i in reversed(range(self.size)):
            self.lt[i] = self.lt[2 * i]
            self.rt[i] = self.rt[2 * i + 1]

    def build(self, arr):
        """Load arr into leaves 0 .. len(arr)-1 and rebuild every internal node.

        arr is expected to have at most n elements.
        """
        for i, a in enumerate(arr):
            leaf = self.size + i
            self.fmax[leaf] = a
            self.fmin[leaf] = a
            self.maxc[leaf] = 1
            self.minc[leaf] = 1
            self.sum[leaf] = a
        for k in reversed(range(1, self.size)):
            self.merge(k)

    def merge(self, k):
        """Recompute node k's aggregates from its two children."""
        l, r = 2 * k, 2 * k + 1
        self.sum[k] = self.sum[l] + self.sum[r]

        if self.fmax[l] < self.fmax[r]:
            self.fmax[k] = self.fmax[r]
            self.maxc[k] = self.maxc[r]
            self.smax[k] = max(self.fmax[l], self.smax[r])
        elif self.fmax[l] > self.fmax[r]:
            self.fmax[k] = self.fmax[l]
            self.maxc[k] = self.maxc[l]
            self.smax[k] = max(self.smax[l], self.fmax[r])
        else:
            self.fmax[k] = self.fmax[l]
            self.maxc[k] = self.maxc[l] + self.maxc[r]
            self.smax[k] = max(self.smax[l], self.smax[r])

        if self.fmin[l] > self.fmin[r]:
            self.fmin[k] = self.fmin[r]
            self.minc[k] = self.minc[r]
            self.smin[k] = min(self.fmin[l], self.smin[r])
        elif self.fmin[l] < self.fmin[r]:
            self.fmin[k] = self.fmin[l]
            self.minc[k] = self.minc[l]
            self.smin[k] = min(self.smin[l], self.fmin[r])
        else:
            self.fmin[k] = self.fmin[l]
            self.minc[k] = self.minc[l] + self.minc[r]
            self.smin[k] = min(self.smin[l], self.smin[r])

    def propagate(self, k):
        """Push node k's pending tags (assign, add, max/min clips) to its children."""
        if self.size <= k:
            return  # leaves have no children to push to
        l, r = 2 * k, 2 * k + 1
        if self.upd[k] != pINF:
            # A pending assignment supersedes everything. add[k] is 0 here:
            # update_ clears it, and add_ folds additions into upd instead.
            self.update_(l, self.upd[k])
            self.update_(r, self.upd[k])
            self.upd[k] = pINF
            return
        if self.add[k]:
            self.add_(l, self.add[k])
            self.add_(r, self.add[k])
            self.add[k] = 0
        # A child whose extreme lies beyond the parent's was clipped by an
        # earlier chmin/chmax tag on the parent; apply the same clip to it.
        if self.fmax[k] < self.fmax[l]:
            self.chmax_(l, self.fmax[k])
        if self.fmin[l] < self.fmin[k]:
            self.chmin_(l, self.fmin[k])
        if self.fmax[k] < self.fmax[r]:
            self.chmax_(r, self.fmax[k])
        if self.fmin[r] < self.fmin[k]:
            self.chmin_(r, self.fmin[k])

    def up_merge(self):
        """Re-merge every node on the ``up`` stack, deepest (last pushed) first."""
        while self.up:
            self.merge(self.up.pop())

    def down_propagate(self, k):
        """Push k's tags down and schedule both children for visiting."""
        self.propagate(k)
        self.down.append(2 * k)
        self.down.append(2 * k + 1)

    def update_(self, k, x):
        """Assign x to every leaf under node k and record it as a pending tag.

        The width rt-lt counts every leaf under k, so k must not cover padding.
        range_update guarantees this by clamping b to n.
        """
        width = self.rt[k] - self.lt[k]
        self.fmax[k] = x
        self.smax[k] = nINF
        self.fmin[k] = x
        self.smin[k] = pINF
        self.maxc[k] = width
        self.minc[k] = width
        self.sum[k] = x * width
        self.add[k] = 0
        self.upd[k] = x

    def add_(self, k, x):
        """Add x to every leaf under node k and record it as a pending tag.

        As with update_, k must not cover padding; range_add clamps b to n.
        """
        width = self.rt[k] - self.lt[k]
        self.fmax[k] += x
        if self.smax[k] != nINF:
            self.smax[k] += x
        self.fmin[k] += x
        if self.smin[k] != pINF:
            self.smin[k] += x
        self.sum[k] += x * width
        if self.upd[k] != pINF:
            # Assigned nodes are uniform: fold the addition into the assignment.
            self.upd[k] += x
        else:
            self.add[k] += x

    def chmax_(self, k, x):
        """Lower node k's maximum to x, i.e. apply v = min(v, x) under k.

        Despite the name, this implements the *chmin* operation. It assumes
        smax[k] < x, so only the maxc leaves equal to fmax change.
        """
        self.sum[k] += (x - self.fmax[k]) * self.maxc[k]
        if self.fmax[k] == self.fmin[k]:
            # All (real) leaves are equal: the minimum moves too.
            self.fmax[k] = x
            self.fmin[k] = x
        elif self.fmax[k] == self.smin[k]:
            # Exactly two distinct values: the old max was the second minimum.
            self.fmax[k] = x
            self.smin[k] = x
        else:
            self.fmax[k] = x
        if self.upd[k] != pINF and x < self.upd[k]:
            self.upd[k] = x

    def chmin_(self, k, x):
        """Raise node k's minimum to x, i.e. apply v = max(v, x) under k.

        Despite the name, this implements the *chmax* operation. It assumes
        x < smin[k], so only the minc leaves equal to fmin change.
        """
        self.sum[k] += (x - self.fmin[k]) * self.minc[k]
        if self.fmin[k] == self.fmax[k]:
            # All (real) leaves are equal: the maximum moves too.
            self.fmin[k] = x
            self.fmax[k] = x
        elif self.fmin[k] == self.smax[k]:
            # Exactly two distinct values: the old min was the second maximum.
            self.fmin[k] = x
            self.smax[k] = x
        else:
            self.fmin[k] = x
        if self.upd[k] != pINF and self.upd[k] < x:
            self.upd[k] = x

    def range_chmax(self, a, b, x):
        """Apply v = max(v, x) to every position in [a, b)."""
        self.down.append(1)
        while self.down:
            k = self.down.pop()
            # Disjoint, or nothing under k is below x: nothing to do.
            if b <= self.lt[k] or self.rt[k] <= a or x <= self.fmin[k]:
                continue
            # Fully covered and only the minimum elements change: tag here.
            if a <= self.lt[k] and self.rt[k] <= b and x < self.smin[k]:
                self.chmin_(k, x)
                continue
            self.down_propagate(k)
            self.up.append(k)
        self.up_merge()

    def range_chmin(self, a, b, x):
        """Apply v = min(v, x) to every position in [a, b)."""
        self.down.append(1)
        while self.down:
            k = self.down.pop()
            # Disjoint, or nothing under k is above x: nothing to do.
            if b <= self.lt[k] or self.rt[k] <= a or self.fmax[k] <= x:
                continue
            # Fully covered and only the maximum elements change: tag here.
            if a <= self.lt[k] and self.rt[k] <= b and self.smax[k] < x:
                self.chmax_(k, x)
                continue
            self.down_propagate(k)
            self.up.append(k)
        self.up_merge()

    def range_add(self, a, b, x):
        """Add x to every position in [a, b); b beyond n is clamped to n."""
        b = min(b, self.n)  # never add into padding leaves (would corrupt sums)
        self.down.append(1)
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                self.add_(k, x)
                continue
            self.down_propagate(k)
            self.up.append(k)
        self.up_merge()

    def range_update(self, a, b, x):
        """Assign x to every position in [a, b); b beyond n is clamped to n."""
        b = min(b, self.n)  # never assign into padding leaves (would corrupt sums)
        self.down.append(1)
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                self.update_(k, x)
                continue
            self.down_propagate(k)
            self.up.append(k)
        self.up_merge()

    def get_max(self, a, b):
        """Return the maximum over [a, b) (nINF if the range holds no values)."""
        self.down.append(1)
        v = nINF
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                v = max(v, self.fmax[k])
                continue
            self.down_propagate(k)
        return v

    def get_min(self, a, b):
        """Return the minimum over [a, b) (pINF if the range holds no values)."""
        self.down.append(1)
        v = pINF
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                v = min(v, self.fmin[k])
                continue
            self.down_propagate(k)
        return v

    def get_sum(self, a, b):
        """Return the sum over [a, b); padding leaves contribute 0."""
        self.down.append(1)
        v = 0
        while self.down:
            k = self.down.pop()
            if b <= self.lt[k] or self.rt[k] <= a:
                continue
            if a <= self.lt[k] and self.rt[k] <= b:
                v += self.sum[k]
                continue
            self.down_propagate(k)
        return v


def solve(n, a):
    """Return the sum, over all non-empty subarrays of a, of the smallest
    positive integer missing from that subarray.

    Expects len(a) == n and every a[i] in 1..n (values index into h).

    For a fixed left end l, let st[j] = max(l, next position >= l of each
    value 1..j), where "no next position" is n. Subarrays [l, r] with
    st[j] <= r < n contain all of 1..j, so there are n - st[j] of them.
    Summing n - st[j] over j = 0..n counts each subarray once per unit of
    its mex.
    """
    # h[v]: positions of value v in descending order, above the sentinel n,
    # so pop() yields the next occurrence at or after the current left end.
    h = [[n] for _ in range(n + 1)]
    # Registering position 0 for value 0 makes its first pop below yield 0.
    h[0].append(0)
    for i, v in enumerate(a):
        h[v].append(i)
    for positions in h:
        positions.sort(reverse=True)

    st = SegmentTreeBeats(n + 1)
    st.build([0] * (n + 1))
    for v in range(n + 1):
        st.range_chmax(v, n + 1, h[v].pop())

    ans = 0
    for i, v in enumerate(a):
        # Contribution of all subarrays starting at l = i.
        ans += n * (n + 1) - st.get_sum(0, n + 1)
        # Move the left end past i: value v's next occurrence changes, and
        # every st[j] must be at least the new left end i + 1.
        st.range_chmax(v, n + 1, h[v].pop())
        st.range_chmax(0, n + 1, i + 1)
    return ans


def main():
    n = int(input())
    a = list(map(int, input().split()))
    print(solve(n, a))


if __name__ == '__main__':
    main()