結果

問題 No.5007 Steiner Space Travel
ユーザー prussian_coder
提出日時 2023-04-25 10:23:09
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 922 ms / 1,000 ms
コード長 10,035 bytes
コンパイル時間 2,396 ms
コンパイル使用メモリ 86,860 KB
実行使用メモリ 98,044 KB
スコア 8,325,836
最終ジャッジ日時 2023-04-25 10:23:37
合計ジャッジ時間 28,248 ms
ジャッジサーバーID
(参考情報)
judge13 / judge12
純コード判定しない問題か言語
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 730 ms
93,692 KB
testcase_01 AC 714 ms
93,056 KB
testcase_02 AC 753 ms
94,820 KB
testcase_03 AC 690 ms
92,764 KB
testcase_04 AC 747 ms
93,996 KB
testcase_05 AC 734 ms
93,616 KB
testcase_06 AC 725 ms
92,788 KB
testcase_07 AC 808 ms
94,440 KB
testcase_08 AC 816 ms
94,440 KB
testcase_09 AC 747 ms
93,632 KB
testcase_10 AC 765 ms
93,968 KB
testcase_11 AC 761 ms
94,452 KB
testcase_12 AC 832 ms
94,108 KB
testcase_13 AC 771 ms
93,392 KB
testcase_14 AC 922 ms
98,044 KB
testcase_15 AC 773 ms
93,892 KB
testcase_16 AC 777 ms
94,304 KB
testcase_17 AC 768 ms
93,948 KB
testcase_18 AC 699 ms
92,460 KB
testcase_19 AC 789 ms
94,416 KB
testcase_20 AC 744 ms
93,908 KB
testcase_21 AC 766 ms
94,600 KB
testcase_22 AC 781 ms
96,280 KB
testcase_23 AC 777 ms
94,692 KB
testcase_24 AC 785 ms
94,888 KB
testcase_25 AC 741 ms
93,476 KB
testcase_26 AC 706 ms
93,504 KB
testcase_27 AC 735 ms
93,904 KB
testcase_28 AC 717 ms
92,756 KB
testcase_29 AC 746 ms
93,864 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#from matplotlib import pyplot as plt

def visualize(N,M,pos,center_pos,allocate):
    """Debug plot: planets as dots coloured by cluster, centers as stars.

    NOTE(review): uses ``plt`` whose import is commented out at the top of the
    file, so calling this as-is raises NameError; it is only referenced from a
    commented-out call. Colours come from color_ls, so M must not exceed its
    length.
    """
    plt.figure(figsize=(10,10))
    for idx in range(N):
        px, py = pos[idx][0], pos[idx][1]
        plt.plot(px, py, color=color_ls[allocate[idx]], marker=".")
    for cid in range(M):
        cx, cy = center_pos[cid][0], center_pos[cid][1]
        plt.plot(cx, cy, color=color_ls[cid], marker="*")
    plt.show()
    plt.close()


# Sentinel for unreachable DP states in the TSP solvers. It is far above any real
# tour cost: with coordinates in [0, 1000], one edge costs at most
# 25 * 2 * 1000**2 = 5e7.
INF=10**20

import random
from pathlib import Path
import time


LOCAL = False  # True: batch-run every *.txt under in_path; False: read one case from stdin
in_path = "./test"  # directory of local input files (LOCAL mode only)
img_path = "./image"  # not referenced by the visible code
color_ls = ["red","blue","green","orange","gray","pink","cyan","black"]  # per-cluster plot colours for visualize (8 entries)

def read_data(file):
    """Read an instance: a first line "N M", then N lines of integer coordinates.

    When LOCAL is set the data comes from ``file``. Otherwise it is read from
    stdin via input(), and ``file`` is ignored.

    Returns:
        (N, M, pos), where pos[i] is the list of ints on planet line i (used as [x, y]).
    """
    if not LOCAL:
        N, M = map(int, input().split())
        pos = [list(map(int, input().split())) for _ in range(N)]
        return N, M, pos

    with open(file, mode="r") as fh:
        lines = fh.readlines()
    N, M = map(int, lines[0].split())
    pos = [list(map(int, lines[1 + k].split())) for k in range(N)]
    return N, M, pos

# Weighted squared distance between two points (the energy of one edge).
def dist(p1,p2,a=25):
    """Return a * ((x1-x2)**2 + (y1-y2)**2).

    This is not a plain Euclidean distance. The callers pass the weight per edge
    type: a=25 planet-planet (the default), a=5 planet-station, and a=1
    station-station.
    """
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return a * (dx * dx + dy * dy)

# Compute the cost from every center to every planet and sort ascending.
def calc_dist_from_center(N,M,center_pos,pos):
    """Return all (cost, center m, planet n) triples sorted ascending.

    The cost is dist(pos[n], center_pos[m]) with dist's default weight. Ties are
    broken by center index, then by planet index (tuple ordering).
    """
    triples = [
        (dist(pos[n], center_pos[m]), m, n)
        for m in range(M)
        for n in range(N)
    ]
    triples.sort()
    return triples

# Assign planets to clusters, closest (center, planet) pairs first.
def allocate_cluster(N,M,center_pos,pos):
    """Assign every planet to its nearest center.

    The function walks the globally sorted (cost, m, n) list and takes the first
    entry seen for each planet. That entry is the planet's cheapest center
    (ties go to the lower m).

    Returns:
        (allocate, cluster_counts, dist_sum): the cluster index per planet, the
        number of planets per cluster, and the summed assignment cost.
    """
    allocate = [-1] * N
    cluster_counts = [0] * M
    dist_sum = 0
    remaining = N
    for d, m, n in calc_dist_from_center(N, M, center_pos, pos):
        if allocate[n] == -1:
            allocate[n] = m
            cluster_counts[m] += 1
            dist_sum += d
            remaining -= 1
        if remaining == 0:
            break
    return allocate, cluster_counts, dist_sum
    
# Recompute cluster centers k-means style from the current assignment.
# In split mode, large clusters get a second center so they can be divided.
def calc_center(N,M,pos,allocate,cluster_counts,max_size = 15,split_mode = True):
    """Recompute centers from an assignment, optionally splitting big clusters.

    Args:
        N, M: number of planets / number of centers.
        pos: planet coordinates, pos[i] = [x, y].
        allocate: allocate[i] is the cluster index of planet i.
        cluster_counts: number of planets in each cluster.
        max_size: in split mode, a cluster with at least this many planets gets
            an extra, randomly jittered center next to its mean.
        split_mode: if True, clusters are visited largest first and may be
            split. If False, there is one center per cluster, in index order.

    Returns:
        Exactly M centers [x, y]. Each is the integer (floor) mean of a cluster;
        an empty cluster gets a random point in [0, 1000]^2. In split mode the
        output order does not match the input cluster indices.
    """
    point_sums = [[0, 0] for _ in range(M)]
    for i in range(N):
        m = allocate[i]
        point_sums[m][0] += pos[i][0]
        point_sums[m][1] += pos[i][1]

    order = [(cluster_counts[i], i) for i in range(M)]
    if split_mode:
        # Visit the largest clusters first so they get priority for extra centers.
        order.sort(reverse=True)

    center_pos = []
    for _, m in order:
        cnt = cluster_counts[m]
        if cnt == 0:
            center_pos.append([random.randint(0,1000), random.randint(0,1000)])
        else:
            x = point_sums[m][0] // cnt
            y = point_sums[m][1] // cnt
            center_pos.append([x, y])
            if cnt >= max_size and split_mode:
                dx = random.randint(-20,20)
                dy = random.randint(-20,20)
                # Retry with a tighter jitter until the extra center lies on the
                # map. Both axes accept 0; the old code rejected y+dy == 0.
                while not (0 <= x+dx <= 1000 and 0 <= y+dy <= 1000):
                    dx = random.randint(-5,5)
                    dy = random.randint(-5,5)
                center_pos.append([x+dx, y+dy])
        # A split adds two centers at once, which can jump from M-1 to M+1.
        # Hence ">=" plus the truncation below, so exactly M centers come back.
        if len(center_pos) >= M:
            break

    return center_pos[:M]
    
# Group the planets with a size-capped variant of k-means.
def clustering(N,M,pos,max_size,max_loop=1000):
    """Split planets into M clusters, aiming for at most max_size planets each.

    The function starts from M random centers. It then alternates calc_center
    (split mode, which adds an extra center next to any cluster with
    >= max_size planets) with allocate_cluster.

    It stops after at least 20 rounds once no cluster exceeds max_size, or after
    max_loop rounds regardless. The size cap is met only by random chance, and it
    can never be met when N > M * max_size. Without a bound, the loop could
    therefore spin forever.

    Args:
        N, M: number of planets / clusters.
        pos: planet coordinates.
        max_size: target maximum cluster size.
        max_loop: hard cap on refinement rounds (new, keyword with default).

    Returns:
        (center_pos, allocate): center_pos[m] is the floor mean of cluster m (a
        random point if empty), and allocate[i] is the cluster of planet i.
        If max_loop is hit, some cluster may still exceed max_size.
    """
    center_pos = [[random.randint(0,1000), random.randint(0,1000)] for _ in range(M)]  # random init
    allocate, cluster_counts, _ = allocate_cluster(N,M,center_pos,pos)

    for loop in range(max_loop):
        center_pos = calc_center(N,M,pos,allocate,cluster_counts,max_size=max_size)
        allocate, cluster_counts, _ = allocate_cluster(N,M,center_pos,pos)
        # Refine for at least 20 rounds, then stop as soon as the size cap holds.
        if loop >= 20 and max(cluster_counts) <= max_size:
            break

    # Final centers: one plain mean per cluster, in cluster-index order.
    center_pos = calc_center(N,M,pos,allocate,cluster_counts,split_mode=False)
    return center_pos, allocate


def f(S,x,n):
    """Flatten a bit-DP state (visited set S, current node x) into a list index.

    Each subset row has n+1 slots (n planets plus the station), so the index is
    S*(n+1) + x.
    """
    width = n + 1
    return x + width * S


# Solve the TSP inside one cluster with bit DP, O(n^2 * 2^n).
def tsp(id_list,pos,center_pos):
    """Find the cheapest closed tour station -> every planet -> station.

    The station may also be revisited in the middle of the tour: returning to it
    is a DP node of its own.

    Args:
        id_list: the cluster's planet indices followed by one final entry for
            the station (main passes -m-1). Only the first len-1 entries are
            used to look up pos.
        pos: planet coordinates.
        center_pos: [x, y] of the station.

    Returns:
        The tour as entries of id_list, starting and ending with the station
        entry. Time is O(n^2 * 2^n) and memory O(n * 2^n) for n planets.

    Raises:
        RuntimeError: if the optimal path cannot be rebuilt from the DP table.
            This should be unreachable; the old code looped forever here.
    """
    n = len(id_list) - 1        # nodes 0..n-1 are planets, node n is the station
    width = n + 1               # same row layout as f(S, x, n) = S*(n+1) + x
    pts = [pos[id_list[i]] for i in range(n)]
    # Edge costs never change, so compute them once instead of inside the
    # O(n^2 * 2^n) loop. dist is symmetric, so one table per edge type is enough.
    d_pp = [[dist(pts[i], pts[j]) for j in range(n)] for i in range(n)]
    d_ps = [dist(center_pos, pts[t], a=5) for t in range(n)]

    # dp[S*width + x]: min cost having visited planet set S and now standing at x.
    dp = [INF] * (width << n)
    dp[n] = 0                   # start at the station with nothing visited
    for S in range(1 << n):
        base = S * width
        for s in range(width):
            cur = dp[base + s]  # only rows S2 > S are written below, so cur is stable
            if cur == INF:
                continue
            for t in range(n):
                if (S >> t) & 1:
                    continue
                base2 = (S | (1 << t)) * width
                step = d_ps[t] if s == n else d_pp[s][t]
                if cur + step < dp[base2 + t]:
                    dp[base2 + t] = cur + step
                # Option of hopping from planet t back to the station.
                back = dp[base2 + t] + d_ps[t]
                if back < dp[base2 + n]:
                    dp[base2 + n] = back

    # Walk the table backwards from (all visited, at station).
    # Costs are integers, so an exact ">=" test identifies an optimal predecessor.
    path_list = [n]
    state = (1 << n) - 1
    now = n
    remaining = dp[-1]
    while state != 0 or now != n:
        found = False
        if now == n:
            # Which planet did we reach the station from?
            for t in range(n):
                prev = dp[state * width + t]
                if prev == INF:
                    continue
                if remaining - prev >= d_ps[t]:
                    path_list.append(t)
                    now = t
                    remaining -= d_ps[t]
                    found = True
                    break
        else:
            state ^= 1 << now
            for t in range(width):
                prev = dp[state * width + t]
                if prev == INF:
                    continue
                if t != n:
                    if not (state >> t) & 1:
                        continue
                    step = d_pp[now][t]
                else:
                    step = d_ps[now]
                if remaining - prev >= step:
                    path_list.append(t)
                    now = t
                    remaining -= step
                    found = True
                    break
        if not found:
            raise RuntimeError("tsp: failed to reconstruct the optimal path")

    return [id_list[i] for i in path_list]


def tsp_between_space(center_pos):
    """Order the stations by solving a closed TSP over them with bit DP.

    The tour starts and ends at the last station (index n-1). The edge cost is
    the plain squared distance (weight 1, station-to-station).

    Args:
        center_pos: list of station coordinates [x, y]. Any length is
            accepted; the old code hard-coded 8.

    Returns:
        A permutation of range(len(center_pos)) that starts with n-1 and gives
        the visiting order. An empty input gives [].

    Raises:
        RuntimeError: if the optimal path cannot be rebuilt (should not happen).
    """
    n = len(center_pos)
    if n == 0:
        return []
    cost = [[dist(center_pos[i], center_pos[j], a=1) for j in range(n)] for i in range(n)]

    # dp[S][s]: min cost having visited set S and now at station s. The start
    # (0, n-1) means "at station n-1, not yet counted as visited".
    dp = [[INF] * n for _ in range(1 << n)]
    dp[0][n-1] = 0
    for S in range(1 << n):
        row = dp[S]
        for s in range(n):
            if row[s] == INF:
                continue
            for t in range(n):
                if (S >> t) & 1:
                    continue
                S2 = S | (1 << t)
                cand = row[s] + cost[s][t]
                if cand < dp[S2][t]:
                    dp[S2][t] = cand

    # Walk back from (all visited, at n-1). Integer costs make ">=" exact.
    state = (1 << n) - 1
    s = n - 1
    path_list = [s]
    v = dp[-1][-1]
    while True:
        state ^= 1 << s
        if state == 0:
            break  # back at the start; n-1 is already path_list[0]
        for t in range(n):
            if not (state >> t) & 1 or dp[state][t] == INF:
                continue
            d = cost[s][t]
            if v - dp[state][t] >= d:
                path_list.append(t)
                s = t
                v -= d
                break
        else:
            raise RuntimeError("tsp_between_space: failed to reconstruct the path")
    return path_list



def adjust_center_pos(path,pos,M,prev_center_pos=None):
    """Move each station to the mean of the planets adjacent to it in path.

    Every planet-station edge has the same weight, so the mean of a station's
    planet neighbours minimises its planet-edge energy (least squares). The
    mean is floored to ints. Station-station edges are ignored.

    Args:
        path: the tour. Entries >= 0 are planet indices; entry -m-1 is station m.
        pos: planet coordinates.
        M: number of stations.
        prev_center_pos: optional current station positions, kept for any
            station that has no planet neighbour in path.

    Returns:
        M station positions [x, y]. A station with no adjacent planet (for
        example, one whose cluster was empty) keeps prev_center_pos[m] if given.
        Otherwise it goes to [500, 500], the middle of the [0, 1000] range used by
        the random initialisation. The old code raised ZeroDivisionError here.
    """
    edge_count = [0] * M
    edge_x_sum = [0] * M
    edge_y_sum = [0] * M
    for a, b in zip(path, path[1:]):
        if a < 0 and b >= 0:
            p, q = -a - 1, b
        elif a >= 0 and b < 0:
            p, q = -b - 1, a
        else:
            continue  # planet-planet or station-station edge
        edge_count[p] += 1
        edge_x_sum[p] += pos[q][0]
        edge_y_sum[p] += pos[q][1]

    center_pos = []
    for m in range(M):
        cnt = edge_count[m]
        if cnt:
            center_pos.append([edge_x_sum[m] // cnt, edge_y_sum[m] // cnt])
        elif prev_center_pos is not None:
            center_pos.append(list(prev_center_pos[m]))
        else:
            center_pos.append([500, 500])
    return center_pos

def calc_score(N,M,path,pos,center_pos):
    """Return 10**9 / (1000 + sqrt(total energy)) for the tour path.

    Each consecutive pair in path is one edge. The weight factor is 5 per planet
    endpoint (25 planet-planet, 5 planet-station, 1 station-station). Entries
    >= 0 are planets; entry -m-1 is station m. N and M are unused but kept for
    the call signature.
    """
    def locate(node):
        # Map a path entry to (coordinates, weight factor).
        if node >= 0:
            return pos[node], 5
        return center_pos[-node - 1], 1

    total = 0
    for u, v in zip(path, path[1:]):
        p1, w1 = locate(u)
        p2, w2 = locate(v)
        total += dist(p1, p2, w1 * w2)
    return 10**9 / (1000 + total**0.5)
    
    



def main(N,M,pos):
    """Build one randomised candidate tour and return (center_pos, ans, score).

    1. Cluster the planets into M groups of at most 21 with clustering().
    2. Order the clusters with tsp_between_space over their centers.
    3. Inside each cluster, run tsp() from its station. Clusters with more than
       12 planets are first split in two with clustering(..., 2, ..., 12).
    4. Rotate the concatenated tour so that it starts and ends at planet 0.
    5. Re-center the stations with adjust_center_pos and score the result.
    """
    center_pos, allocate = clustering(N, M, pos, 21)
    ans = []

    for m in tsp_between_space(center_pos):
        station = -m - 1
        members = [i for i in range(N) if allocate[i] == m]
        size = len(members)
        if size > 12:
            # Too many planets for bit DP: split in two, and tour each half
            # from the same station.
            _, sub_alloc = clustering(size, 2, [pos[i] for i in members], 12)
            for half in range(2):
                part = [members[k] for k in range(size) if sub_alloc[k] == half]
                ans += tsp(part + [station], pos, center_pos[m])
        else:
            ans += tsp(members + [station], pos, center_pos[m])

    # Close the loop at planet 0 (the start is repeated at the end).
    start = ans.index(0)
    ans = ans[start:] + ans[:start + 1]

    center_pos = adjust_center_pos(ans, pos, M)
    return center_pos, ans, calc_score(N, M, ans, pos, center_pos)

def output(center_pos,ans):
    """Print the answer in one write.

    The output is one "x y" line per station, then the tour length, then one
    line per stop: "1 i" for planet i or "2 j" for station j (both 1-based).
    Path entries >= 0 are planets and -j are stations.
    """
    lines = [f"{x} {y}" for x, y in center_pos]
    lines.append(str(len(ans)))
    lines.extend(f"2 {-a}" if a < 0 else f"1 {a + 1}" for a in ans)
    print("\n".join(lines))

if LOCAL:
    # Local batch mode: solve every *.txt under in_path and print the result.
    file_ls = Path(in_path).glob("*.txt")
    for file in file_ls:
        print(file)
        N,M,pos = read_data(file)
        # main() returns three values; unpacking only two raised ValueError.
        center_pos,ans,score = main(N,M,pos)
        print(ans)
        output(center_pos,ans)
else:
    # Judge mode: read one instance from stdin, keep the best of 4 randomised
    # attempts, and print it.
    N,M,pos = read_data("")
    best_score = 0
    best_pos = best_ans = None
    for _ in range(4):
        center_pos,ans,score = main(N,M,pos)
        # Always keep the first attempt, so best_pos/best_ans are never unbound.
        if best_ans is None or score > best_score:
            best_score = score
            best_pos = center_pos
            best_ans = ans
    output(best_pos,best_ans)



0