結果

問題 No.1154 シュークリームゲーム(Hard)
ユーザー Kiri8128
提出日時 2020-07-13 08:00:35
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 108 ms / 2,000 ms
コード長 6,164 bytes
コンパイル時間 229 ms
コンパイル使用メモリ 82,560 KB
実行使用メモリ 76,672 KB
最終ジャッジ日時 2024-04-24 16:35:48
合計ジャッジ時間 4,524 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 43 ms
54,400 KB
testcase_01 AC 43 ms
54,784 KB
testcase_02 AC 43 ms
54,656 KB
testcase_03 AC 43 ms
54,784 KB
testcase_04 AC 44 ms
54,400 KB
testcase_05 AC 43 ms
54,400 KB
testcase_06 AC 42 ms
54,528 KB
testcase_07 AC 43 ms
54,912 KB
testcase_08 AC 45 ms
54,912 KB
testcase_09 AC 43 ms
54,912 KB
testcase_10 AC 47 ms
54,656 KB
testcase_11 AC 46 ms
54,912 KB
testcase_12 AC 47 ms
55,424 KB
testcase_13 AC 47 ms
55,424 KB
testcase_14 AC 47 ms
56,192 KB
testcase_15 AC 49 ms
55,680 KB
testcase_16 AC 48 ms
55,936 KB
testcase_17 AC 94 ms
76,672 KB
testcase_18 AC 95 ms
76,416 KB
testcase_19 AC 99 ms
76,032 KB
testcase_20 AC 97 ms
76,288 KB
testcase_21 AC 96 ms
76,032 KB
testcase_22 AC 94 ms
75,904 KB
testcase_23 AC 95 ms
76,032 KB
testcase_24 AC 93 ms
76,288 KB
testcase_25 AC 98 ms
76,416 KB
testcase_26 AC 95 ms
75,776 KB
testcase_27 AC 100 ms
76,416 KB
testcase_28 AC 104 ms
76,416 KB
testcase_29 AC 101 ms
76,544 KB
testcase_30 AC 108 ms
76,288 KB
testcase_31 AC 98 ms
76,672 KB
testcase_32 AC 100 ms
76,416 KB
testcase_33 AC 97 ms
76,288 KB
testcase_34 AC 92 ms
76,160 KB
testcase_35 AC 92 ms
76,032 KB
testcase_36 AC 70 ms
68,352 KB
testcase_37 AC 72 ms
68,352 KB
testcase_38 AC 43 ms
54,912 KB
testcase_39 AC 41 ms
54,400 KB
testcase_40 AC 72 ms
68,480 KB
testcase_41 AC 105 ms
76,672 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import deque

class BalancingTree:
    """Ordered set of integers stored in a pivot-guided binary search tree.

    Every node carries a ``value`` and a ``pivot``.  A new child's pivot is
    the parent's pivot moved down (left) or up (right) by half of the
    parent's lowest set bit.  Values are stored internally shifted by
    ``offset = 1 << n``; the root is a sentinel whose value and pivot are
    both ``1 << (n + 1)``.

    NOTE(review): presumably callers must keep ``0 < v + offset < 1 << (n + 1)``
    for every stored ``v`` -- confirm against the intended value range.
    """

    class node:
        """Tree node: a stored (shifted) value, its pivot and two child links."""

        def __init__(self, v, p):
            self.value = v
            self.pivot = p
            self.left = None
            self.right = None

    def __init__(self, n):
        sentinel = 1 << (n + 1)
        self.N = n
        self.offset = 1 << n
        self.root = self.node(sentinel, sentinel)
        self.count = 0

    def debug(self):
        """Print the count, the un-shifted root value and up to 50 values in order."""
        shifted = []
        pending = []
        cur = self.root
        # Explicit-stack in-order walk: left subtree, node, right subtree.
        while pending or cur:
            while cur:
                pending.append(cur)
                cur = cur.left
            cur = pending.pop()
            if cur.value:
                shifted.append(cur.value - self.offset)
            cur = cur.right
        # Adjust the slice length to print more or fewer values.
        print("Debug - count =", self.count, "root =", self.root.value - self.offset, shifted[:50])

    def append(self, v):
        """Insert ``v``; if the key met on the way down is already stored, remove it.

        Returns ``None`` after a plain insertion.  When the key being carried
        down equals a node's value, that value is deleted instead, the count
        ends up one lower than before the call, and ``0`` is returned.
        """
        key = v + self.offset
        self.count += 1
        cur = self.root
        while key != cur.value:
            if key < cur.value:
                small, large = key, cur.value
            else:
                small, large = cur.value, key
            pivot = cur.pivot
            half_step = (pivot & -pivot) // 2
            if small < pivot:
                # The larger key stays here; the smaller one continues left.
                cur.value = large
                if not cur.left:
                    cur.left = self.node(small, pivot - half_step)
                    return
                cur, key = cur.left, small
            else:
                # The smaller key stays here; the larger one continues right.
                cur.value = small
                if not cur.right:
                    cur.right = self.node(large, pivot + half_step)
                    return
                cur, key = cur.right, large
        # delete() decrements count once; the extra decrement below undoes the
        # increment made at the top of this call.
        self.delete(key - self.offset)
        self.count -= 1
        return 0

    def leftmost(self, nd):
        """Return the node reached by following ``left`` links from ``nd``."""
        while nd.left:
            nd = nd.left
        return nd

    def rightmost(self, nd):
        """Return the node reached by following ``right`` links from ``nd``."""
        while nd.right:
            nd = nd.right
        return nd

    def find_l(self, v):
        """Return the largest stored value strictly less than ``v``.

        The root sentinel takes part in the search.  When nothing smaller is
        found the result is ``-offset`` (the initial ``0`` minus the offset).
        """
        key = v + self.offset
        best = 0
        cur = self.root
        while cur:
            if cur.value < key:
                best = cur.value
                cur = cur.right
            else:
                cur = cur.left
        return best - self.offset

    def find_r(self, v):
        """Return the smallest stored value strictly greater than ``v``.

        The root sentinel takes part in the search, so when no real element
        qualifies but the root's value does, ``root.value - offset`` comes
        back; if nothing at all is greater the result is ``-offset``.
        """
        key = v + self.offset
        best = 0
        cur = self.root
        while cur:
            if key < cur.value:
                best = cur.value
                cur = cur.left
            else:
                cur = cur.right
        return best - self.offset

    @property
    def max(self):
        """Largest value below the root sentinel (``find_l`` at the sentinel key)."""
        return self.find_l((1 << (self.N + 1)) - self.offset)

    @property
    def min(self):
        """Smallest value whose shifted key is greater than zero."""
        return self.find_r(-self.offset)

    def delete(self, v, nd=None, prev=None):
        """Remove the node holding ``v`` if there is one; otherwise do nothing.

        ``nd`` is the node to start searching from (default: the root) and
        ``prev`` is treated as its parent (default: ``nd`` itself).  ``count``
        is decremented exactly once when a node is unlinked.
        """
        key = v + self.offset
        cur = nd if nd else self.root
        parent = prev if prev else cur
        while key != cur.value:
            parent = cur
            cur = cur.left if key < cur.value else cur.right
            if not cur:
                return
        if cur.left and cur.right:
            # Two children: take over the smallest value of the right subtree,
            # then remove that value from the right subtree.
            successor = self.leftmost(cur.right).value
            cur.value = successor
            self.delete(successor - self.offset, cur.right, cur)
            return
        # Zero or one child: splice the (possibly missing) child into the parent.
        replacement = cur.left if cur.left else cur.right
        self.count -= 1
        if cur.value < parent.value:
            parent.left = replacement
        else:
            parent.right = replacement

    def __contains__(self, v):
        """True when the smallest value greater than ``v - 1`` is ``v`` itself."""
        return self.find_r(v - 1) == v

# Read the input: vertex count N, one integer per vertex (A), then N - 1
# undirected edges given with 1-based endpoints (stored 0-based in E).
N = int(input())
A = list(map(int, input().split()))
E = [[] for _ in range(N)]
for _ in range(N - 1):
    u, w = (int(tok) - 1 for tok in input().split())
    E[u].append(w)
    E[w].append(u)

# Breadth-first walk from vertex 0.  R records the visiting order, P the
# parent of each vertex, and every adjacency list loses its parent entry so
# that E[v] ends up holding only the children of v.
X = [[] for _ in range(N)]
P = [-1] * N
R = []
Q = deque([0])
while Q:
    cur = Q.popleft()
    R.append(cur)
    for nxt in E[cur]:
        if nxt == P[cur]:
            continue
        P[nxt] = cur
        E[nxt].remove(cur)
        Q.append(nxt)

def merge(i):
    """Make ``D[i]`` the combined container of all children of vertex ``i``.

    The child whose container has the largest positive ``count`` (first one
    on ties) donates its container object to ``D[i]``; every element of the
    other children's containers is then moved into it, smallest first, via
    ``append``.  If no child holds any element, ``D[i]`` is left untouched.
    """
    children = E[i]
    heavy = max(children, key=lambda c: D[c].count, default=-1)
    if heavy < 0 or D[heavy].count <= 0:
        return

    # Share the object rather than copying it: D[i] and D[heavy] now alias.
    target = D[i] = D[heavy]
    for child in children:
        if child == heavy:
            continue
        source = D[child]
        while source.count:
            smallest = source.min
            target.append(smallest)
            source.delete(smallest)

# One container per vertex; merge() may later make several entries share
# the same object.
D = [BalancingTree(50) for _ in range(N)]
ans = 0

# Handle vertices in reverse BFS order, so each vertex comes after all of
# its children.
for v in reversed(R):
    merge(v)
    pool = D[v]  # looked up after merge(), which may rebind D[v]
    total = A[v]
    sign = 1
    # Repeatedly fold the current maximum into `total` with alternating
    # sign; stop once the sign is positive and the maximum is below `total`.
    while pool.count:
        top = pool.max
        if sign > 0 and top < total:
            break
        sign = -sign
        total += top * sign
        pool.delete(top)
    if sign == -1:
        # Ended on a negative sign: settle `total` into the answer right
        # away, with its sign chosen by the parity of N.
        if N % 2 == 0:
            ans += total
        else:
            ans -= total
    else:
        pool.append(total)

# Drain what is left at the root from largest to smallest, adding and
# subtracting alternately.
sign = 1
rest = D[0]
while rest.count:
    top = rest.max
    rest.delete(top)
    ans += top * sign
    sign = -sign
print(ans)
0