結果

問題 No.5007 Steiner Space Travel
ユーザー jakujaku12jakujaku12
提出日時 2023-04-29 14:05:08
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 986 ms / 1,000 ms
コード長 5,580 bytes
コンパイル時間 721 ms
コンパイル使用メモリ 87,272 KB
実行使用メモリ 82,176 KB
スコア 7,968,764
最終ジャッジ日時 2023-04-29 14:05:41
合計ジャッジ時間 32,403 ms
ジャッジサーバーID
(参考情報)
judge15 / judge13
純コード判定しない問題か言語
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 986 ms
81,260 KB
testcase_01 AC 962 ms
81,220 KB
testcase_02 AC 967 ms
81,088 KB
testcase_03 AC 959 ms
81,548 KB
testcase_04 AC 960 ms
81,224 KB
testcase_05 AC 966 ms
81,488 KB
testcase_06 AC 963 ms
81,472 KB
testcase_07 AC 976 ms
81,668 KB
testcase_08 AC 970 ms
82,056 KB
testcase_09 AC 980 ms
82,176 KB
testcase_10 AC 965 ms
80,916 KB
testcase_11 AC 975 ms
81,168 KB
testcase_12 AC 964 ms
81,060 KB
testcase_13 AC 962 ms
81,916 KB
testcase_14 AC 962 ms
81,212 KB
testcase_15 AC 959 ms
81,516 KB
testcase_16 AC 963 ms
81,248 KB
testcase_17 AC 978 ms
82,100 KB
testcase_18 AC 959 ms
80,960 KB
testcase_19 AC 970 ms
81,616 KB
testcase_20 AC 960 ms
81,272 KB
testcase_21 AC 962 ms
81,424 KB
testcase_22 AC 964 ms
81,488 KB
testcase_23 AC 965 ms
81,220 KB
testcase_24 AC 961 ms
81,196 KB
testcase_25 AC 960 ms
81,232 KB
testcase_26 AC 961 ms
81,852 KB
testcase_27 AC 959 ms
81,696 KB
testcase_28 AC 956 ms
81,104 KB
testcase_29 AC 962 ms
80,724 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from time import time
from random import random, randrange, choice, choices
from heapq import heappop, heappush
from math import exp

START = time()
INF=10**9


def get_time(START):
    return time() - START

def dist(p1,p2):
    x1,y1=p1
    x2,y2=p2
    return (x1-x2)**2 + (y1-y2)**2

def cost(i: int, j: int):
    d = dist(Terminals[i], Terminals[j])
    if i<N: d*=5
    if j<N: d*=5
    return d

def calc(i: int, j: int, ans ):
    d = dist(Terminals[ans[i]], Terminals[ans[j]])
    if ans[i]<N: d*=5
    if ans[j]<N: d*=5
    return d

def dijkstra(s: int, g: int,mind: int):
    _dist = [INF] *(N+M)
    prev = [-1]*(N+M)

    que = [(0,s)]
    _dist[s]=0
    while que:
        d, v = heappop(que)
        if _dist[v] < d:
            continue
        if v==g:
            break
        for nv in range(T):
            nd = dist(Terminals[v],Terminals[nv])
            if v<N: nd*=5
            if nv<N: nd*=5
            nd += d
            # print(nv,nd,file = sys.stderr)s
            if nd < _dist[nv] and nd <= mind:
                prev[nv] = v
                _dist[nv] = nd
                heappush(que,(nd,nv))

    res=[]
    v = g
    while v!=s:
        res.append(v)
        v = prev[v]
    return res[::-1]


def kmeansplus(k,X_):
    #=======
    #k: クラスタ数
    #X_ : データ点の座標を格納した配列
    #=======
    X = X_[:]
    clusters = [] #初期重心を管理するリスト

    #1. データから一点を選択
    centroid_id = choice([i for i in range(len(X))])
    clusters.append(X[centroid_id])
    
    X.pop(centroid_id) #選択した点を入力データのリストから除外
    
    # 4. k 個のクラスタ中心を得られるまで計算を繰り返す。
    while len(clusters) < k:
        dists = []

        #2. 各データ点 と各クラスタ中心との距離を計算し、最も近いものを取り出す
        for i in range(len(X)):
            d = INF
            for j in range(len(clusters)):
                d = min(d, dist(X[i], clusters[j]) ) #今まで見た最短距離と今見てる距離との小さい方を選択
            dists.append(d)
        
        len_dist = len(dists)
        
        #3. 確率分布にしたがって、データ点を1点選択してクラスタ中心とする。
        new_c = choices([i for i in range(len_dist)], weights=dists, k=1)[0] #確率分布から点を一点取り出す
        clusters.append(X[new_c])
        
        X.pop(new_c) #取り出した要素を消去
  
    return clusters

def calc_score(ans):
    score = 0
    for i in range(len(ans)-1):
        score += calc(i,i+1,ans)
    return score

def calc_score2(ans):
    score = 0
    for i in range(len(ans)-1):
        score += Ecost[ans[i]][ans[i+1]]
    return score


def probability(diff):
    start_temp=50
    end_temp=10
    temp = start_temp + (end_temp - start_temp) * get_time(START) / 0.85
    return exp(diff/temp)

# インプット
N,M=map(int,input().split())
Terminals = [tuple(map(int,input().split())) for _ in range(N)]

Stations = kmeansplus(M,Terminals) # k-means++法で配置
Terminals += Stations
T=len(Terminals)
#各Terminalの移動コストをワーシャルフロイトで計算
Ecost=[[INF]*T for i in range(T)]
for i,t1 in enumerate(Terminals):
    for j,t2 in enumerate(Terminals):
        Ecost[i][j]=cost(i,j)
#ワーシャルフロイト
for k in range(T):
    for i in range(T):
        for j in range(T):
            Ecost[i][j]=min(Ecost[i][j],Ecost[i][k]+Ecost[k][j])

#初期解をgreedyに構築
v=0
ans = [v]
visited = [False]*N
visited[v]=True

for _ in range(N-1):
    best_v = -1
    best_dist = INF
    for nv in range(N):
        if visited[nv]: continue
        d = Ecost[v][nv]
        if d < best_dist:
            best_dist = d
            best_v= nv
    assert best_v!=-1
    
    ans.append(best_v)

    v = best_v
    visited[v]=True

ans.append(0)

# 山登り
n = len(ans)
loop_cnt = 0        
best_score = calc_score(ans)
dx=(10,-10,0,0,5,5,-5,-5)
dy=(0,0,10,-10,5,-5,5,-5)
print(get_time(START),file=sys.stderr)
while get_time(START) < 0.80:
    loop_cnt+=1
    v1 = randrange(1,n-1)
    v2 = randrange(1,n-1)
    while v1==v2:
        v2 = randrange(1,n-1)

    ans[v1],ans[v2] = ans[v2],ans[v1]
    current_score = calc_score2(ans)
    diff = best_score - current_score
    if diff>0:
        best_score = current_score
        #print("swap",current_score, file=sys.stderr)
    else:
        ans[v1],ans[v2] = ans[v2],ans[v1]




# アウトプット

output=[0]
for i in range(len(ans)-1):
    output+=dijkstra(ans[i],ans[i+1],Ecost[ans[i]][ans[i+1]])
loopcnt2=0
while get_time(START)<0.87:
    loopcnt2+=1
    v = randrange(N,N+M)
    original = Terminals[v][:]
    best_i = -1
    sub_best_score = best_score
    for i in range(8):
        Terminals[v] = (original[0]+dx[i],original[1]+dy[i])
        current_score = calc_score(output)
        if current_score < sub_best_score:
            sub_best_score = current_score
            best_i = i

    if best_i != -1:
        best_score = sub_best_score
        Terminals[v] = (original[0]+dx[best_i], original[1]+dy[best_i])
        #print("move",best_score, file=sys.stderr)
    else:
        Terminals[v]=original[:]
for pos in Terminals[N:]:
    print(*pos)
print(len(output))
for v in output:
    if v<N:
        print(1,v+1)
    else:
        print(2,v+1-N)

print("Score", 10**9//(1000+calc_score(output)**0.5), file=sys.stderr)
print("Loop", loop_cnt, loopcnt2, file=sys.stderr)
print("time:",get_time(START) * 1000, file=sys.stderr)
0