結果

問題 No.2826 Earthwork
ユーザー 🦠みどりむし🦠みどりむし
提出日時 2024-06-22 13:58:53
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,706 bytes
コンパイル時間 469 ms
コンパイル使用メモリ 82,380 KB
実行使用メモリ 245,740 KB
最終ジャッジ日時 2024-06-26 11:38:02
合計ジャッジ時間 60,365 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 37 ms
53,504 KB
testcase_01 WA -
testcase_02 WA -
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 WA -
testcase_30 WA -
testcase_31 WA -
testcase_32 WA -
testcase_33 WA -
testcase_34 WA -
testcase_35 WA -
testcase_36 WA -
testcase_37 WA -
testcase_38 WA -
testcase_39 WA -
testcase_40 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import heapq
from itertools import product

# Constants and type definitions
DIRS4 = [(-1, 0), (0, 1), (1, 0), (0, -1)]
i64 = int

# Main function
def main():
    input = sys.stdin.read
    data = input().split()
    
    index = 0
    
    n = int(data[index])
    index += 1
    
    h = [i64(data[index + i]) for i in range(n * n)]
    index += n * n
    
    s = [data[index + i] for i in range(n)]
    index += n
    
    a = [i64(0)] * (n * n)
    b = [i64(0)] * (n * n)
    
    for i in range(n - 1):
        for j in range(n):
            a[i * n + j] = i64(data[index])
            index += 1

    for i in range(n):
        for j in range(n - 1):
            b[i * n + j] = i64(data[index])
            index += 1
    
    def parity(p):
        return ((p // n) + p % n) % 2
    
    sup = [-(1 << 60)] * (n * n)
    
    for t in range(2):
        def cost(p, q, k):
            r = min(p, q)
            w = a[r] if k % 2 == 0 else b[r]
            x = h[p] + h[q]
            return w * abs(x) + ((parity(p) + t) % 2 * 2 - 1) * x
        
        dist = [1 << 60] * (n * n)
        que = []
        
        for i in range(n * n):
            if s[i // n][i % n] == '=':
                heapq.heappush(que, (0, i))
                dist[i] = 0
            
            if parity(i) == t:
                if s[i // n][i % n] == '+':
                    heapq.heappush(que, (0, i))
                    dist[i] = 0
            else:
                if s[i // n][i % n] == '-':
                    heapq.heappush(que, (0, i))
                    dist[i] = 0
        
        while que:
            d, v = heapq.heappop(que)
            if d > dist[v]:
                continue
            
            i, j = divmod(v, n)
            
            for k in range(4):
                ni, nj = i + DIRS4[k][0], j + DIRS4[k][1]
                if ni < 0 or nj < 0 or ni >= n or nj >= n:
                    continue
                
                nv = ni * n + nj
                nd = d + cost(v, nv, k)
                if nd >= dist[nv]:
                    continue
                
                dist[nv] = nd
                heapq.heappush(que, (nd, nv))
        
        for i in range(n * n):
            v = h[i] + (parity(i) * 2 - 1) * dist[i]
            if sup[i] < v:
                sup[i] = v
    
    q = int(data[index])
    index += 1
    
    results = []
    for _ in range(q):
        r = int(data[index])
        c = int(data[index + 1])
        r -= 1
        c -= 1
        e = i64(data[index + 2])
        index += 3
        
        results.append("Yes" if sup[r * n + c] >= e else "No")
    
    print("\n".join(results))

if __name__ == "__main__":
    main()
0