結果
問題 |
No.3078 Difference Sum Query
|
ユーザー |
![]() |
提出日時 | 2025-03-28 22:55:22 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,699 bytes |
コンパイル時間 | 189 ms |
コンパイル使用メモリ | 82,836 KB |
実行使用メモリ | 182,928 KB |
最終ジャッジ日時 | 2025-03-28 22:55:27 |
合計ジャッジ時間 | 4,873 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 6 TLE * 1 -- * 19 |
ソースコード
class segtree(): def __init__(self, n, op, e): self.n = n self.size = 1 << ((n - 1).bit_length()) self.e = e self.seg = [e for _ in range(2 * self.size)] self.op = op def set(self, i, x): i += self.size self.seg[i] = x i >>= 1 while i: self.seg[i] = self.op(self.seg[2 * i], self.seg[2 * i + 1]) i >>= 1 return def get(self, i): return self.seg[i + self.size] def add(self, i, x): self.set(i, self.get(i) + x) return def set_array(self, A): for i in range(self.n): self.set(i, A[i]) return def prod(self, l, r): L, R = self.e, self.e l, r = l + self.size, r + self.size while l < r: if l & 1: L = self.op(L, self.seg[l]) l += 1 if r & 1: r -= 1 R = self.op(R, self.seg[r]) l >>= 1 r >>= 1 return self.op(L, R) def max_right(self, l, check): assert (0 <= l <= self.n) assert (check(self.e) == True) if l == self.n: return self.n l, sm = l + self.size, self.e while True: while l % 2 == 0: l >>= 1 if not(check(self.op(sm, self.seg[l]))): while l < self.size: l <<= 1 if check(self.op(sm, self.seg[l])): sm = self.op(sm, self.seg[l]) l += 1 return l - self.size sm = self.op(sm, self.seg[l]) l += 1 if (l & -l) == l: break return self.n def min_left(self, r, check): assert (0 <= r <= self.n) assert (check(self.e) == True) if r == 0: return 0 r, sm = r + self.size, self.e while True: r -= 1 while (r > 1 and (r % 2)): r >>= 1 if not(check(self.op(self.seg[r], sm))): while (r < self.size): r = (r << 1) + 1 if self.op(sm, self.seg[r]): sm = self.op(sm, self.seg[r]) r -= 1 return r + 1 - self.size sm = self.op(sm, self.seg[r]) if (r & -r) == r: break return 0 def add(x, y): return x + y N, Q = map(int, input().split()) A = list(map(int, input().split())) S = set() query = [] for i in range(Q): l, r, x = map(int, input().split()) query.append((l - 1, r, x)) S.add(x) for a in A: S.add(a) S = list(S) S.sort() d = {} for i in range(len(S)): d[S[i]] = i M = len(S) B = [0 for _ in range(N)] for i in range(N): B[i] = d[A[i]] for i in range(Q): l, r, x = query[i] x = d[x] query[i] = (l, r, x) def Mo_algorithm(N, Query): Q = len(Query) # ブロックの長さを決める,(ブロック順, r, l)の順にソート W = int(max(1, N / (((2 * Q) / 3) ** 0.5))) data = [0 for i in range(Q)] query = [0 for i in range(Q)] X = [0 for _ in range(Q)] for i in range(Q): l, r, x = Query[i] block = l // W data[i] = (l << 20) | r query[i] = (block << 40) | (r << 20) | i if block & 1: query[i] = (block << 40) + ((-r) << 20) + i X[i] = x query.sort() # 必要な初期解,テーブルなど C = segtree(M, add, 0) seg = segtree(M, add, 00) def Mo_add(l): i = B[l] C.add(i, 1) seg.add(i, S[i]) def Mo_del(l): i = B[l] C.add(i, -1) seg.add(i, -S[i]) # query処理 nl, nr = 0, 0 mask = (1 << 20) - 1 ans = [0 for _ in range(Q)] for que in query: i = que & mask l = data[i] >> 20 r = data[i] & mask x = X[i] while nl > l: nl -= 1 Mo_add(nl) while nr < r: Mo_add(nr) nr += 1 while nl < l: Mo_del(nl) nl += 1 while nr > r: nr -= 1 Mo_del(nr) res = seg.prod(x, M) - seg.prod(0, x) - S[x] * (C.prod(x, M) - C.prod(0, x)) ans[i] = res return ans ans = Mo_algorithm(N, query) for res in ans: print(res)