結果

問題 No.3351 Towering Tower
コンテスト
ユーザー tassei903
提出日時 2025-11-11 07:39:17
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,531 ms / 3,000 ms
コード長 4,350 bytes
コンパイル時間 336 ms
コンパイル使用メモリ 82,344 KB
実行使用メモリ 88,912 KB
最終ジャッジ日時 2025-11-13 21:17:33
合計ジャッジ時間 29,728 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = lambda: sys.stdin.readline().rstrip()

from collections import deque

def solve(n, h, g):
    inf = 10 ** 18
    idx = sorted(range(n), key=lambda x: h[x])
    h = h + [-1]
    def check(x):
        h[n] = x
        N = 2 * (n + 1)
        """
        h[i] <= x の頂点の最短距離
        """
        dist = [-1] * N
        from collections import deque
        dq = deque()
        for i in g[n]:
            if h[i] <= x:
                dist[i * 2 + 1] = 1
                dq.append(i * 2 + 1)
        dist[n * 2] = 0

        while dq:
            ui, uj = divmod(dq.popleft(), 2)

            vj = uj ^ 1
            for vi in g[ui]:
                if vi == n:
                    continue
                if h[vi] <= x:
                    if h[ui] == h[vi]:
                        if dist[vi * 2 + vj] == -1:
                            dist[vi * 2 + vj] = dist[ui * 2 + uj] + 1
                            dq.append(vi * 2 + vj)
                    elif h[ui] < h[vi]:
                        if dist[vi * 2 + vj] == -1:
                            dist[vi * 2 + vj] = dist[ui * 2 + uj] + 1
                            dq.append(vi * 2 + vj)
                    else:
                        L = (x - h[ui]) // (h[ui] - h[vi])
                        # print(L)
                        if L >= dist[ui * 2 + uj]:
                            if dist[vi * 2 + vj] == -1:
                                dist[vi * 2 + vj] = dist[ui * 2 + uj] + 1
                                dq.append(vi * 2 + vj)
        
        final = dist[:]

        for ui in range(n):
            if h[ui] > x:
                continue
            for uj in range(2):
                if dist[ui * 2 + uj] == -1:
                    continue
                for vi in g[ui]:
                    if h[vi] > x:
                        continue
                    if vi == n:
                        continue
                    vj = uj ^ 1
                    # ui -> vi にできるだけ往復する
                    d1 = inf
                    d2 = inf
                    if h[ui] < h[vi]:
                        d2 = (x - h[vi]) // (h[vi] - h[ui])
                    elif h[ui] > h[vi]:
                        d1 = (x - h[ui]) // (h[ui] - h[vi])
                    
                    if d1 >= dist[ui * 2 + uj]:
                        final[vi * 2 + vj] = max(final[vi * 2 + vj], dist[ui * 2 + uj] + 1)

                    if d2 == inf and d1 == inf:
                        k = inf
                    else:
                        k = min((d1 - uj) // 2, (d2 - uj + 1) // 2)
                    
                    if uj + 2 * k >= dist[ui * 2 + uj]:
                        final[vi * 2 + vj] = max(final[vi * 2 + vj], min(uj + 2 * k + 1, inf))
                    

        def transition(ui):
            for uj in range(2):
                if final[ui * 2 + uj] == -1:
                    continue
                for vi in g[ui]:
                    if vi == n:
                        continue
                    # if h[vi] <= x:
                    #     continue
                    if h[ui] >= h[vi]:
                        continue
                    vj = uj ^ 1
                    d = (h[ui] - x - 1) // (h[vi] - h[ui]) + 1
                    if final[ui * 2 + uj] != -1 and final[ui * 2 + uj] >= d:
                        # print((ui+1, uj), "->", (vi+1, vj), d)
                        final[vi * 2 + vj] = max(final[vi * 2 + vj], min(inf, final[ui * 2 + uj] + 1))
                   
        for ui in idx:
            if h[ui] > x:
                break
            transition(ui)
        transition(n)
        for ui in idx:
            if h[ui] <= x:
                continue
            transition(ui)
        
        return all(final[i * 2] >= 0 or final[i * 2 + 1] >= 0 for i in range(n))

    ok = max(h) * n
    ng = -1

    while ok - ng > 1:
        mid = (ok + ng) // 2
        res = check(mid)
        if res:
            ok = mid
        else:
            ng = mid
    return ok

for _ in range(int(input())):
    n, m = map(int, input().split())
    h = list(map(int, input().split()))
    g = [[] for _ in range(n + 1)]
    for _ in range(m):
        u, v = map(int, input().split())
        u -= 1
        v -= 1
        g[u].append(v)
        g[v].append(u)
    print(solve(n, h, g))
0