結果

問題 No.924 紲星
ユーザー lumc_lumc_
提出日時 2019-09-06 00:23:13
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
TLE  
実行時間 -
コード長 2,871 bytes
コンパイル時間 233 ms
コンパイル使用メモリ 11,132 KB
実行使用メモリ 8,732 KB
最終ジャッジ日時 2023-10-13 08:19:31
合計ジャッジ時間 6,775 ms
ジャッジサーバーID
(参考情報)
judge11 / judge15
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 17 ms
8,420 KB
testcase_01 AC 16 ms
8,492 KB
testcase_02 AC 17 ms
8,552 KB
testcase_03 AC 36 ms
8,608 KB
testcase_04 AC 23 ms
8,620 KB
testcase_05 AC 41 ms
8,732 KB
testcase_06 AC 32 ms
8,656 KB
testcase_07 AC 24 ms
8,628 KB
testcase_08 TLE -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

class BinaryIndexedTree:
    def __init__(self, n):
        self.identity = 0
        self.n = n
        self.data = [0] * n
        self.m = 1
        while(self.m < n):
            self.m <<= 1

    def add(self, i, x):
        assert(0 <= i < self.n)
        i += 1
        while(i <= self.n):
            self.data[i - 1] = self.data[i - 1] + x
            i += i & -i

    def sum(self, i):
        if i < 0:
            return self.identity
        if i >= self.n:
            i = self.n - 1
        i += 1
        s = self.identity
        while(i > 0):
            s = s + self.data[i - 1]
            i -= i & -i
        return s

    def get(self, i):
        return self.sum(i) - self.sum(i - 1)

    def range(self, a, b):
        return self.sum(b) - self.sum(a - 1)

    def lower_bound(self, w):
        i = 0
        k = self.m
        while(k > 0):
            if i + k <= self.n and self.data[i + k - 1] < w:
                i += k
                w -= self.data[i - 1]
            k >>= 1


BIT = BinaryIndexedTree


def main():
    n, q = map(int, input().split())

    # assert(1 <= n and n <= int(1e5))
    # assert(1 <= q and q <= int(1e5))

    a = list(map(int, input().split()))
    v = zip(a, range(n))
    v = sorted(v)

    L = [0] * q
    R = [0] * q

    for i in range(q):
        L[i], R[i] = map(int, input().split())
        L[i] -= 1
        R[i] -= 1

    ok = [n-1] * q
    ng = [-1] * q
    if n > 1:
        mid = [[] for _ in range(n)]
        mid[n//2] = list(range(q))
        # パラサーチ O(Q log N) * O(log N)
        rest = q
        while(rest >= 1):
            bit = BIT(n)
            for i in range(n):
                idx = v[i][1]
                bit.add(idx, 1)
                for j in mid[i]:
                    if bit.range(L[j], R[j]) >= (R[j] - L[j] + 2) // 2:
                        ok[j] = i
                    else:
                        ng[j] = i

                    if(abs(ok[j] - ng[j]) > 1):
                        mid[(ok[j] + ng[j])//2].append(j)
                    else:
                        rest -= 1
                mid[i] = []

    qs = [[] for _ in range(n)]
    for i in range(q):
        qs[ok[i]].append(i)
    # print(ok)
    # print(qs)

    d1 = BIT(n)
    d2 = BIT(n)
    ans = [0] * q
    for i in range(n):
        d2.add(i, a[i])
    for i in range(n):
        idx = v[i][1]
        d2.add(idx, -a[idx])
        d1.add(idx, a[idx])
        for j in qs[i]:
            sz = R[j] - L[j] + 1
            ans[j] = - d1.range(L[j], R[j]) + d2.range(L[j], R[j])
            # print()
            # print([d2.get(i) for i in range(n)])
            # print("j")
            # print(j, L[j], R[j])
            # print(ans[j], - d1.range(L[j], R[j]), + d2.range(L[j], R[j]))
            if(sz % 2 == 1):
                ans[j] += a[idx]
    print("\n".join(map(str, ans)))


main()
0