結果

問題 No.5007 Steiner Space Travel
コンテスト
ユーザー jakujaku12
提出日時 2023-04-29 12:06:38
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 8,498 bytes
記録
コンパイル時間 321 ms
コンパイル使用メモリ 87,204 KB
実行使用メモリ 83,724 KB
スコア 273,060
最終ジャッジ日時 2023-04-29 12:06:53
合計ジャッジ時間 5,676 ms
ジャッジサーバーID (参考情報): judge14 / judge12
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 1 TLE * 1 -- * 28
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

from itertools import permutations
import sys
from time import time
from random import random, randrange, choice, choices, randint
from heapq import heappop, heappush
from math import exp

# Reference timestamp for every time budget in this script (used with get_time).
START = time()
# Large sentinel standing in for "infinite" distance.
INF = 1_000_000_000


def get_time(START):
    """Return the number of seconds elapsed since the timestamp ``START``."""
    now = time()
    elapsed = now - START
    return elapsed

def dist(p1, p2):
    """Return the squared Euclidean distance between 2-D points p1 and p2.

    No square root is taken; callers compare and sum squared distances.
    """
    (ax, ay), (bx, by) = p1, p2
    return (ax - bx) ** 2 + (ay - by) ** 2

def cost(i: int, j: int):
    """Return the direct-hop cost between entries i and j of Terminals.

    The base cost is the squared distance. It is multiplied by 5 once for
    each endpoint that is an input terminal (index < N). Station endpoints
    (index >= N) leave it unscaled.
    """
    d = dist(Terminals[i], Terminals[j])
    for endpoint in (i, j):
        if endpoint < N:
            d *= 5
    return d

def calc(i: int, j: int, ans):
    """Return the direct-hop cost between the vertices at positions i and j of ans.

    ans is a sequence of indices into Terminals; the cost rule is the one
    implemented by cost().
    """
    return cost(ans[i], ans[j])

def dijkstra(s: int, g: int, mind: int):
    """Return the cheapest hop sequence from vertex s to vertex g.

    Runs Dijkstra over the complete graph on all T entries of Terminals. Edge
    weights are the direct-hop costs from cost(). Tentative distances above
    ``mind`` are pruned, so ``mind`` must be at least the true shortest
    distance; the caller passes the Floyd-Warshall value Ecost[s][g].

    Returns the vertices visited after s, ending with g. The list is empty
    when s == g. Raises ValueError if g cannot be reached within ``mind``.
    """
    # Size the buffers by T so they match the relaxation loop below.
    _dist = [INF] * T
    prev = [-1] * T

    que = [(0, s)]
    _dist[s] = 0
    while que:
        d, v = heappop(que)
        if _dist[v] < d:
            continue  # stale heap entry
        if v == g:
            break
        for nv in range(T):
            nd = d + cost(v, nv)
            if nd < _dist[nv] and nd <= mind:
                prev[nv] = v
                _dist[nv] = nd
                heappush(que, (nd, nv))

    # Without this guard the walk below would follow prev[-1] (the last
    # vertex) and either return a bogus path or loop forever.
    if s != g and prev[g] == -1:
        raise ValueError(
            "vertex %s unreachable from %s within cost %s" % (g, s, mind))

    res = []
    v = g
    while v != s:
        res.append(v)
        v = prev[v]
    return res[::-1]



class KMEANS:
    """Lloyd-style k-means clustering of 2-D integer points.

    After ``Clustering(c_datas, k)`` returns:
      * ``reps[j]``     -- centre of cluster j (floor of the mean of its members),
      * ``clusters[i]`` -- index of the centre assigned to ``c_datas[i]``,
      * ``dists[i][j]`` -- squared distance from point i to centre j (last pass).
    """

    # Initial centres used when k == 8: a ring of points around (500, 500).
    DEFAULT_REPS = (
        (250, 250), (500, 250), (750, 250), (750, 500),
        (750, 750), (500, 750), (250, 750), (250, 500),
    )

    def __init__(self):
        self.reps = []
        self.dists = []
        self.clusters = []
        self.keep_flag = True  # True while the last assignment pass changed a cluster

    # c_datas: points to cluster, k: number of clusters
    def Clustering(self, c_datas, k):
        """Cluster the (x, y) points in c_datas into k groups.

        Alternates assignment and centre-update passes until no point changes
        cluster. Safe to call repeatedly on the same instance.
        NOTE(review): there is no iteration cap; convergence of the
        floor-rounded centres is assumed.
        """
        self.RepInit(c_datas, k)
        while self.keep_flag:
            self.ClusterDist(c_datas)
            self.ClusterUpdate()
            self.RepUpdate(c_datas)

    def RepInit(self, c_datas, k):
        """Reset the centres and per-point bookkeeping for a fresh run with k clusters.

        k == 8 uses DEFAULT_REPS. Any other k spreads exactly k centres evenly
        on a circle of radius 250 around (500, 500), so the number of centres
        always matches the width of ``dists``.
        """
        if k == len(self.DEFAULT_REPS):
            self.reps = list(self.DEFAULT_REPS)
        else:
            from math import cos, pi, sin
            self.reps = [
                (round(500 + 250 * cos(2 * pi * t / k)),
                 round(500 + 250 * sin(2 * pi * t / k)))
                for t in range(k)
            ]
        self.dists = [[-1] * k for _ in range(len(c_datas))]
        self.clusters = [-1] * len(c_datas)
        # Re-arm the loop flag; otherwise a second Clustering() call on this
        # instance would return immediately with the untouched initial centres.
        self.keep_flag = True

    def ClusterDist(self, c_datas):
        """Fill dists[i][j] with the squared distance from point i to centre j."""
        for row, (x, y) in zip(self.dists, c_datas):
            for j, (rx, ry) in enumerate(self.reps):
                row[j] = (x - rx) ** 2 + (y - ry) ** 2

    def ClusterUpdate(self):
        """Assign each point to its nearest centre (lowest index wins ties).

        Sets keep_flag to whether any assignment changed.
        """
        changed = False
        for i, row in enumerate(self.dists):
            # min() returns the first minimal index, matching a strict '<' scan.
            nearest = min(range(len(row)), key=row.__getitem__, default=-1)
            if self.clusters[i] != nearest:
                changed = True
            self.clusters[i] = nearest
        self.keep_flag = changed

    def RepUpdate(self, c_datas):
        """Move every centre to the floor of the mean of its members.

        A centre with no members keeps its current position.
        """
        # Group the points by cluster in a single pass.
        members = [[] for _ in self.reps]
        for c_num, point in zip(self.clusters, c_datas):
            if 0 <= c_num < len(members):
                members[c_num].append(point)
        for c_num, pts in enumerate(members):
            if pts:
                cnt = len(pts)
                self.reps[c_num] = (sum(x for x, _ in pts) // cnt,
                                    sum(y for _, y in pts) // cnt)
    

def calc_score(ans):
    """Return the total direct-hop cost of visiting ans in order.

    Sums calc() over every consecutive pair; 0 when ans has fewer than two entries.
    """
    return sum(calc(k, k + 1, ans) for k in range(len(ans) - 1))

def calc_score2(ans):
    """Return the cost of visiting ans in order using the all-pairs table Ecost.

    Sums Ecost over every consecutive pair; 0 when ans has fewer than two entries.
    """
    return sum(Ecost[a][b] for a, b in zip(ans, ans[1:]))


def probability(diff, start_temp=50, end_temp=10, time_limit=0.85):
    """Return the simulated-annealing acceptance factor exp(diff / temp).

    The temperature falls linearly from ``start_temp`` to ``end_temp`` over
    the first ``time_limit`` seconds since START. After that it stays at
    ``end_temp``. Without the clamp it would keep falling, reach 0 (division
    by zero) and then go negative.
    """
    elapsed = min(get_time(START), time_limit)
    temp = start_temp + (end_temp - start_temp) * elapsed / time_limit
    return exp(diff / temp)

# Input: first line "N M", then N lines with the coordinates of each terminal.
N,M=map(int,input().split())
Terminals = [tuple(map(int,input().split())) for _ in range(N)]
# Place the M stations at the k-means centres of the terminals.
KM=KMEANS()
KM.Clustering(Terminals,M)
Stations = KM.reps[:]
print(Stations,file= sys.stderr)
# From here on, indices < N in Terminals are input terminals and
# indices >= N are stations.
Terminals += Stations
T=len(Terminals)
# All-pairs cheapest travel cost between every terminal/station pair.
# Ecost[i][j] starts as the direct-hop cost cost(i, j) and, after
# Floyd-Warshall, holds the cheapest multi-hop cost from i to j.
Ecost = [[cost(i, j) for j in range(T)] for i in range(T)]
# Floyd-Warshall. Row lookups are hoisted out of the inner loop. Hoisting
# Ecost[i][k] is exact because Ecost[k][k] == 0 and all costs are
# non-negative, so that entry never changes while k is fixed.
for k in range(T):
    row_k = Ecost[k]
    for i in range(T):
        row_i = Ecost[i]
        d_ik = row_i[k]
        for j in range(T):
            alt = d_ik + row_k[j]
            if alt < row_i[j]:
                row_i[j] = alt

# Assign each input terminal to its nearest station (the lower index wins ties).
# Then choose the order in which to visit the stations by exhaustive search.
StationCluster=[[] for i in range(M)]  # StationCluster[j]: terminals assigned to station j
TerminalsStation=[0]*N  # TerminalsStation[i]: station index (0..M-1) of terminal i
for i in range(N):
    best_dist=INF
    best_st=-1
    for j in range(M):
        d = dist(Terminals[i],Terminals[j+N])
        if d<best_dist:
            best_dist=d
            best_st=j
    assert best_st!=-1
    StationCluster[best_st].append(i)
    TerminalsStation[i]=best_st
# Exhaustive search over all (M-1)! orders of the other stations. Each
# candidate tour starts and ends at the station of terminal 0, and the one
# with the lowest Ecost sum is kept in best_per (as Terminals indices).
best_dist = INF
best_per = []
for per in permutations(range(M-1),M-1):
    station_list=[TerminalsStation[0]+N]
    for s in per:
        # per covers 0..M-2; shift indices at or above the start station up
        # by one so that the start station itself is skipped.
        if s>=TerminalsStation[0]:s+=1
        station_list.append(s+N)
    station_list.append(TerminalsStation[0]+N)
    d = calc_score2(station_list)
    if d<best_dist:
        best_dist=d
        best_per=station_list[:]
print(best_per,file=sys.stderr)        


        


# Initial tour over the input terminals: start at terminal 0, visit the
# terminals of each station in the order chosen in best_per, then return
# to terminal 0.
ans = [0]
for i in range(M):
    for v in StationCluster[best_per[i]-N]:
        if v==0:
            continue  # terminal 0 is already the start/end of the tour
        ans.append(v)
ans.append(0)
print(ans, file = sys.stderr)
# print(StationCluster, file = sys.stderr)

# Tind[x]: position of terminal x in ans; Tinv[i]: terminal at position i.
# Built from ans[:N], i.e. every position except the closing 0.
Tind=[0]*N
Tinv=[0]*N
for i,x in enumerate(ans[:N]):
    Tind[x]=i
    Tinv[i]=x

# Hill climbing over the visiting order in ans, scored with calc_score2
# (Ecost). ans[0] and ans[n-1] are terminal 0 (start and goal) and are never
# moved; only the interior positions 1..n-2 are swapped.
n = len(ans)
loop_cnt = 0
update_cnt = 0
best_score = calc_score2(ans)
# Station nudge offsets used by the final optimisation phase.
dx = (10, -10, 0, 0, 5, 5, -5, -5)
dy = (0, 0, 10, -10, 5, -5, 5, -5)
# A swap needs at least two interior positions (n >= 4).
while n >= 4 and get_time(START) < 0.75:
    loop_cnt += 1
    if random() < 0.2:
        # Move 1: swap two distinct random interior positions.
        p1 = randrange(1, n - 1)
        p2 = randrange(1, n - 1)
        while p1 == p2:
            p2 = randrange(1, n - 1)
    else:
        # Move 2: swap ans[p1] with another terminal served by the same
        # station. Candidates are listed explicitly because rejection
        # sampling would loop forever when the cluster has no other
        # non-zero terminal.
        p1 = randrange(1, n - 1)
        here = ans[p1]
        candidates = [t for t in StationCluster[TerminalsStation[here]]
                      if t != 0 and t != here]
        if not candidates:
            continue
        p2 = Tind[choice(candidates)]
    ans[p1], ans[p2] = ans[p2], ans[p1]
    current_score = calc_score2(ans)
    if current_score < best_score:
        best_score = current_score
        update_cnt += 1
        # Keep the terminal -> position index in sync with the accepted swap.
        Tind[ans[p1]] = p1
        Tind[ans[p2]] = p2
    else:
        ans[p1], ans[p2] = ans[p2], ans[p1]  # revert





# Output phase.
# Expand every consecutive pair of the tour into its cheapest hop sequence.
ans.append(0)
output = [0]
for src, dst in zip(ans, ans[1:]):
    output += dijkstra(src, dst, Ecost[src][dst])

# Re-sync the baseline with the cost of the expanded route that is printed.
best_score = calc_score(output)

# Station nudging: repeatedly pick a random station, try each offset in
# (dx, dy) and keep the best one if it lowers the route cost.
# NOTE(review): the submission header reports TLE; these budgets are
# measured from START, which excludes interpreter start-up. Consider
# lowering them.
loopcnt2 = 0
while get_time(START) < 0.85:
    loopcnt2 += 1
    v = randrange(N, N + M)
    original = Terminals[v]
    best_i = -1
    sub_best_score = best_score
    for i in range(len(dx)):
        nx = original[0] + dx[i]
        ny = original[1] + dy[i]
        # Station coordinates must stay within [0, 1000] (output constraint
        # of this problem). Improving moves could otherwise step outside it.
        if not (0 <= nx <= 1000 and 0 <= ny <= 1000):
            continue
        Terminals[v] = (nx, ny)
        current_score = calc_score(output)
        if current_score < sub_best_score:
            sub_best_score = current_score
            best_i = i

    if best_i != -1:
        best_score = sub_best_score
        Terminals[v] = (original[0] + dx[best_i], original[1] + dy[best_i])
    else:
        Terminals[v] = original

# Station coordinates, route length, then one "type index" line per hop
# (type 1 = input terminal, type 2 = station; indices are 1-based).
lines = [" ".join(map(str, pos)) for pos in Terminals[N:]]
lines.append(str(len(output)))
for v in output:
    if v < N:
        lines.append("1 " + str(v + 1))
    else:
        lines.append("2 " + str(v + 1 - N))
print("\n".join(lines))

print("Score", 10**9//(1000+calc_score(output)**0.5), file=sys.stderr)
print("Loop", loop_cnt, loopcnt2, update_cnt, file=sys.stderr)
print("time:",get_time(START) * 1000, file=sys.stderr)
0