結果

問題 No.2375 watasou and hibit's baseball
ユーザー lam6er
提出日時 2025-04-15 23:25:36
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 764 ms / 2,000 ms
コード長 2,514 bytes
コンパイル時間 443 ms
コンパイル使用メモリ 81,596 KB
実行使用メモリ 130,376 KB
最終ジャッジ日時 2025-04-15 23:27:08
合計ジャッジ時間 6,933 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 36
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

def main():
    n, A, B = map(int, sys.stdin.readline().split())
    coords = []
    K = []
    for _ in range(n):
        x, y, k = map(int, sys.stdin.readline().split())
        coords.append((x, y))
        K.append(k)
    
    # Precompute distance and K differences
    dist = [[0] * n for _ in range(n)]
    k_diff = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            xi, yi = coords[i]
            xj, yj = coords[j]
            dist[i][j] = abs(xi - xj) + abs(yi - yj)
            k_diff[i][j] = abs(K[i] - K[j])
    
    # Initialize DP: mask -> {(prev_prev, prev): max_length}
    dp = defaultdict(dict)
    for i in range(n):
        mask = 1 << i
        dp[mask][(-1, i)] = 1
    
    max_len = 1  # At least one ball can be thrown
    
    # Process masks in order of increasing number of bits
    masks = [mask for mask in range(1, 1 << n)]
    masks.sort(key=lambda x: bin(x).count('1'))
    
    for mask in masks:
        if mask not in dp:
            continue
        # Iterate over a copy to avoid runtime errors due to dict changes
        for (prev_prev, prev) in list(dp[mask].keys()):
            current_len = dp[mask][(prev_prev, prev)]
            # Try adding each possible next ball
            for w in range(n):
                if not (mask & (1 << w)):
                    new_mask = mask | (1 << w)
                    new_length = current_len + 1
                    # Check conditions based on new_length (j)
                    if new_length == 2:
                        # j=2: check condition 2 or 3
                        cond = (dist[prev][w] >= A) or (k_diff[prev][w] >= B)
                    else:
                        # j >=3: check condition 3 or 4
                        # prev_prev is valid (not -1)
                        cond = (k_diff[prev][w] >= B) or (dist[prev][w] + dist[prev_prev][w] >= A)
                    if cond:
                        new_prev_prev = prev
                        new_prev = w
                        key = (new_prev_prev, new_prev)
                        if new_mask not in dp:
                            dp[new_mask] = {}
                        if key not in dp[new_mask] or dp[new_mask].get(key, 0) < new_length:
                            dp[new_mask][key] = new_length
                            if new_length > max_len:
                                max_len = new_length
    print(max_len)

if __name__ == "__main__":
    # Solve only when run as a script; importing this module has no side effects.
    main()
0