結果

問題 No.1028 闇討ち
ユーザー lam6er
提出日時 2025-03-26 15:53:50
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,984 bytes
コンパイル時間 186 ms
コンパイル使用メモリ 82,312 KB
実行使用メモリ 89,684 KB
最終ジャッジ日時 2025-03-26 15:54:37
合計ジャッジ時間 9,173 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 3 WA * 6 TLE * 2 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def main():
    sys.setrecursionlimit(1 << 25)
    N = int(sys.stdin.readline())
    houses = [[] for _ in range(N+1)]
    for i in range(1, N+1):
        row = list(map(int, sys.stdin.readline().split()))
        for j in range(1, N+1):
            k = row[j-1]
            houses[k].append((i, j))
    
    sum_dy = [0] * (N + 1)
    for k in range(1, N+1):
        s = 0
        for (i, j) in houses[k]:
            s += abs(j - 1)
        sum_dy[k] = s
    
    cost = [[0] * (N + 1) for _ in range(N + 1)]
    for k in range(1, N+1):
        events = []
        left_cnt = 0
        left_sum_a = 0
        right_cnt = 0
        right_sum_b = 0
        for (i, j) in houses[k]:
            dy = abs(j - 1)
            a = i - dy
            b = i + dy
            if a >= 1:
                events.append((a + 1, 'left_end', a))
                if 1 <= a:
                    left_cnt += 1
                    left_sum_a += a
            if b <= N:
                events.append((b, 'right_start', b))
                if 1 >= b:
                    right_cnt += 1
                    right_sum_b += b
        events.sort()
        ptr = 0
        for t in range(1, N + 1):
            while ptr < len(events) and events[ptr][0] == t:
                typ = events[ptr][1]
                val = events[ptr][2]
                if typ == 'left_end':
                    left_cnt -= 1
                    left_sum_a -= val
                elif typ == 'right_start':
                    right_cnt += 1
                    right_sum_b += val
                ptr += 1
            sum_max = left_sum_a - left_cnt * t + right_cnt * t - right_sum_b
            cost[k][t] = sum_dy[k] + sum_max
    
    u = [0] * (N + 1)
    v = [0] * (N + 1)
    p = [0] * (N + 1)
    way = [0] * (N + 1)
    minv = [0] * (N + 1)
    used = [False] * (N + 1)
    for i in range(1, N + 1):
        p[0] = i
        j0 = 0
        minv = [float('inf')] * (N + 1)
        used = [False] * (N + 1)
        while True:
            used[j0] = True
            i0 = p[j0]
            delta = float('inf')
            j1 = 0
            for j in range(1, N + 1):
                if not used[j]:
                    cur = cost[i0][j] - u[i0] - v[j]
                    if cur < minv[j]:
                        minv[j] = cur
                        way[j] = j0
                    if minv[j] < delta:
                        delta = minv[j]
                        j1 = j
            if delta == float('inf'):
                break
            for j in range(N + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:
                break
        while True:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
            if j0 == 0:
                break
    total = -v[0]
    print(total)

if __name__ == '__main__':
    main()
0