結果

問題 No.924 紲星
ユーザー lumc_lumc_
提出日時 2019-09-06 00:32:29
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 3,550 ms / 4,000 ms
コード長 2,871 bytes
コンパイル時間 1,437 ms
コンパイル使用メモリ 87,160 KB
実行使用メモリ 265,720 KB
最終ジャッジ日時 2023-10-13 08:21:37
合計ジャッジ時間 29,078 ms
ジャッジサーバーID
(参考情報)
judge15 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 69 ms
71,428 KB
testcase_01 AC 69 ms
71,104 KB
testcase_02 AC 72 ms
71,204 KB
testcase_03 AC 153 ms
77,244 KB
testcase_04 AC 97 ms
77,260 KB
testcase_05 AC 107 ms
77,252 KB
testcase_06 AC 106 ms
77,296 KB
testcase_07 AC 87 ms
76,332 KB
testcase_08 AC 3,347 ms
263,048 KB
testcase_09 AC 3,193 ms
265,720 KB
testcase_10 AC 3,438 ms
263,980 KB
testcase_11 AC 3,340 ms
264,716 KB
testcase_12 AC 3,550 ms
262,300 KB
testcase_13 AC 1,454 ms
165,724 KB
testcase_14 AC 1,287 ms
153,776 KB
testcase_15 AC 1,286 ms
150,640 KB
testcase_16 AC 1,926 ms
212,020 KB
testcase_17 AC 1,812 ms
202,524 KB
testcase_18 AC 71 ms
71,484 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

class BinaryIndexedTree:
    def __init__(self, n):
        self.identity = 0
        self.n = n
        self.data = [0] * n
        self.m = 1
        while(self.m < n):
            self.m <<= 1

    def add(self, i, x):
        assert(0 <= i < self.n)
        i += 1
        while(i <= self.n):
            self.data[i - 1] = self.data[i - 1] + x
            i += i & -i

    def sum(self, i):
        if i < 0:
            return self.identity
        if i >= self.n:
            i = self.n - 1
        i += 1
        s = self.identity
        while(i > 0):
            s = s + self.data[i - 1]
            i -= i & -i
        return s

    def get(self, i):
        return self.sum(i) - self.sum(i - 1)

    def range(self, a, b):
        return self.sum(b) - self.sum(a - 1)

    def lower_bound(self, w):
        i = 0
        k = self.m
        while(k > 0):
            if i + k <= self.n and self.data[i + k - 1] < w:
                i += k
                w -= self.data[i - 1]
            k >>= 1


BIT = BinaryIndexedTree


def main():
    n, q = map(int, input().split())

    # assert(1 <= n and n <= int(1e5))
    # assert(1 <= q and q <= int(1e5))

    a = list(map(int, input().split()))
    v = zip(a, range(n))
    v = sorted(v)

    L = [0] * q
    R = [0] * q

    for i in range(q):
        L[i], R[i] = map(int, input().split())
        L[i] -= 1
        R[i] -= 1

    ok = [n-1] * q
    ng = [-1] * q
    if n > 1:
        mid = [[] for _ in range(n)]
        mid[n//2] = list(range(q))
        # パラサーチ O(Q log N) * O(log N)
        rest = q
        while(rest >= 1):
            bit = BIT(n)
            for i in range(n):
                idx = v[i][1]
                bit.add(idx, 1)
                for j in mid[i]:
                    if bit.range(L[j], R[j]) >= (R[j] - L[j] + 2) // 2:
                        ok[j] = i
                    else:
                        ng[j] = i

                    if(abs(ok[j] - ng[j]) > 1):
                        mid[(ok[j] + ng[j])//2].append(j)
                    else:
                        rest -= 1
                mid[i] = []

    qs = [[] for _ in range(n)]
    for i in range(q):
        qs[ok[i]].append(i)
    # print(ok)
    # print(qs)

    d1 = BIT(n)
    d2 = BIT(n)
    ans = [0] * q
    for i in range(n):
        d2.add(i, a[i])
    for i in range(n):
        idx = v[i][1]
        d2.add(idx, -a[idx])
        d1.add(idx, a[idx])
        for j in qs[i]:
            sz = R[j] - L[j] + 1
            ans[j] = - d1.range(L[j], R[j]) + d2.range(L[j], R[j])
            # print()
            # print([d2.get(i) for i in range(n)])
            # print("j")
            # print(j, L[j], R[j])
            # print(ans[j], - d1.range(L[j], R[j]), + d2.range(L[j], R[j]))
            if(sz % 2 == 1):
                ans[j] += a[idx]
    print("\n".join(map(str, ans)))


main()
0