結果

問題 No.924 紲星
ユーザー lumc_lumc_
提出日時 2019-09-06 00:43:33
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
RE  
実行時間 -
コード長 2,869 bytes
コンパイル時間 96 ms
コンパイル使用メモリ 13,056 KB
実行使用メモリ 49,588 KB
最終ジャッジ日時 2024-09-15 05:31:33
合計ジャッジ時間 6,793 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 32 ms
11,136 KB
testcase_01 AC 33 ms
10,880 KB
testcase_02 AC 31 ms
11,008 KB
testcase_03 AC 54 ms
11,136 KB
testcase_04 AC 39 ms
11,008 KB
testcase_05 AC 67 ms
11,136 KB
testcase_06 AC 52 ms
11,136 KB
testcase_07 AC 40 ms
10,880 KB
testcase_08 RE -
testcase_09 RE -
testcase_10 RE -
testcase_11 RE -
testcase_12 RE -
testcase_13 TLE -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

class BinaryIndexedTree:
    def __init__(self, n):
        self.identity = 0
        self.n = n
        self.data = [0] * n
        self.m = 1
        while(self.m < n):
            self.m <<= 1

    def add(self, i, x):
        assert(0 <= i < self.n)
        i += 1
        while(i <= self.n):
            self.data[i - 1] = self.data[i - 1] + x
            i += i & -i

    def sum(self, i):
        if i < 0:
            return self.identity
        if i >= self.n:
            i = self.n - 1
        i += 1
        s = self.identity
        while(i > 0):
            s = s + self.data[i - 1]
            i -= i & -i
        return s

    def get(self, i):
        return self.sum(i) - self.sum(i - 1)

    def range(self, a, b):
        return self.sum(b) - self.sum(a - 1)

    def lower_bound(self, w):
        i = 0
        k = self.m
        while(k > 0):
            if i + k <= self.n and self.data[i + k - 1] < w:
                i += k
                w -= self.data[i - 1]
            k >>= 1


BIT = BinaryIndexedTree


def main():
    n, q = map(int, input().split())

    assert(1 <= n and n <= int(1e5))
    # assert(1 <= q and q <= int(1e5))

    a = list(map(int, input().split()))
    v = zip(a, range(n))
    v = sorted(v)

    L = [0] * q
    R = [0] * q

    for i in range(q):
        L[i], R[i] = map(int, input().split())
        L[i] -= 1
        R[i] -= 1

    ok = [n-1] * q
    ng = [-1] * q
    if n > 1:
        mid = [[] for _ in range(n)]
        mid[n//2] = list(range(q))
        # パラサーチ O(Q log N) * O(log N)
        rest = q
        while(rest >= 1):
            bit = BIT(n)
            for i in range(n):
                idx = v[i][1]
                bit.add(idx, 1)
                for j in mid[i]:
                    if bit.range(L[j], R[j]) >= (R[j] - L[j] + 2) // 2:
                        ok[j] = i
                    else:
                        ng[j] = i

                    if(abs(ok[j] - ng[j]) > 1):
                        mid[(ok[j] + ng[j])//2].append(j)
                    else:
                        rest -= 1
                mid[i] = []

    qs = [[] for _ in range(n)]
    for i in range(q):
        qs[ok[i]].append(i)
    # print(ok)
    # print(qs)

    d1 = BIT(n)
    d2 = BIT(n)
    ans = [0] * q
    for i in range(n):
        d2.add(i, a[i])
    for i in range(n):
        idx = v[i][1]
        d2.add(idx, -a[idx])
        d1.add(idx, a[idx])
        for j in qs[i]:
            sz = R[j] - L[j] + 1
            ans[j] = - d1.range(L[j], R[j]) + d2.range(L[j], R[j])
            # print()
            # print([d2.get(i) for i in range(n)])
            # print("j")
            # print(j, L[j], R[j])
            # print(ans[j], - d1.range(L[j], R[j]), + d2.range(L[j], R[j]))
            if(sz % 2 == 1):
                ans[j] += a[idx]
    print("\n".join(map(str, ans)))


main()
0