結果

問題 No.5007 Steiner Space Travel
ユーザー prussian_coderprussian_coder
提出日時 2023-04-24 17:36:22
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 8,313 bytes
コンパイル時間 489 ms
コンパイル使用メモリ 87,744 KB
実行使用メモリ 94,740 KB
スコア 804,190
最終ジャッジ日時 2023-04-24 17:36:27
合計ジャッジ時間 4,955 ms
ジャッジサーバーID
(参考情報)
judge15 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 362 ms
88,804 KB
testcase_01 AC 386 ms
89,140 KB
testcase_02 AC 366 ms
88,768 KB
testcase_03 TLE -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #


INF=10**20

import random
from pathlib import Path
import time


LOCAL = False
in_path = "./test"
img_path = "./image"
color_ls = ["red","blue","green","orange","gray","pink","cyan","black"]

def read_data(file):
    if LOCAL:
        with open(file,mode="r") as f:
            data = f.readlines()
        N,M = map(int,data[0].split())
        pos = [[int(x) for x in data[i+1].split()] for i in range(N)]
    else:
        N,M=map(int,input().split())
        pos = [[int(x) for x in input().split()] for i in range(N)]
    return N,M,pos

#2点間の距離を返す 
def dist(p1,p2,a=25): 
    return ((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2)*a

#中心点と各惑星の距離を全探索し、短い順に並びかえる
def calc_dist_from_center(N,M,center_pos,pos):
    dist_list = []
    for m in range(M):
        for n in range(N):
            d = dist(pos[n],center_pos[m])
            dist_list.append((d,m,n))
    return sorted(dist_list)

#中心点から近い惑星順にクラスわけする
def allocate_cluster(N,M,center_pos,pos):
    dist_list = calc_dist_from_center(N,M,center_pos,pos)
    allocate = [-1]*N
    cluster_counts = [0]*M
    total_count = 0
    for d,m,n in dist_list:
        if allocate[n]==-1:
            allocate[n]=m
            cluster_counts[m]+=1
            total_count+=1
        if total_count==N:
            break
    return allocate,cluster_counts
    
#クラスター分けされた点をもとに、k-means法に倣って中心点を算出する。
# グループカウントが多いものは中心点を2つにして、クラスターを分ける。
def calc_center(N,M,pos,allocate,cluster_counts,max_size = 15,split_mode = True):
    point_sums = [[0,0] for _ in range(M)]
    for i in range(N):
        m=allocate[i]
        point_sums[m][0]+=pos[i][0]
        point_sums[m][1]+=pos[i][1]
    center_pos = []

    #clsuter_countが大きい順に見ていき、中心点がM個を超えたら中断する
    if split_mode:
        order = sorted([(cluster_counts[i],i) for i in range(M)],reverse=True)
    else:
        order = [(cluster_counts[i],i) for i in range(M)]
    for _,m in order:  
        if cluster_counts[m]==0:
            center_pos.append([random.randint(0,1000),random.randint(0,1000)])
        else:
            center_pos.append([point_sums[m][0]//cluster_counts[m],point_sums[m][1]//cluster_counts[m]])
            if cluster_counts[m]>=max_size and split_mode:
                x = point_sums[m][0]//cluster_counts[m]
                y = point_sums[m][1]//cluster_counts[m]
                dx = random.randint(-20,20)
                dy = random.randint(-20,20)
                while not (0<=x+dx<=1000 and 0<y+dy<=1000):
                    dx = random.randint(-5,5)
                    dy = random.randint(-5,5)
                center_pos.append([x+dx,y+dy])
        if len(center_pos)==M:
            break

    return center_pos
    
#K-means法に倣って惑星の点をグループ分けする
def clustering(N,M,pos,max_size):
    center_pos = [[random.randint(0,1000),random.randint(0,1000)] for _ in range(M)] #ランダム初期化
    allocate,cluster_counts = allocate_cluster(N,M,center_pos,pos)
    loop = 0

    while True:
        center_pos = calc_center(N,M,pos,allocate,cluster_counts,max_size=max_size)
        allocate,cluster_counts = allocate_cluster(N,M,center_pos,pos)        
        #print(loop,cluster_counts)
        if loop>=20 and max(cluster_counts)<=max_size:
            break
        loop+=1    

    center_pos = calc_center(N,M,pos,allocate,cluster_counts,split_mode=False)
    #visualize(N,M,pos,center_pos,allocate)

    return center_pos,allocate,cluster_counts


def f(S,x,n):
    return S*(n+1)+x


#クラスター内にてBitDPでTSPを解く O(n^2*2^n)
def tsp(id_list,pos,center_pos):    
    #point 0~n-1が惑星、nが中心点
    n=len(id_list)-1
    dp = [INF]*(n+1)*(1<<n)
    dp[f(0,n,n)]=0
    for S in range(1<<n):
        for s in range(n+1):
            if dp[f(S,s,n)]==INF:
                continue
            for t in range(n):
                if (S>>t)&1:
                    continue
                S2 = S|(1<<t)
                if s!=n:
                    dp[f(S2,t,n)]=min(dp[f(S,s,n)] + dist(pos[id_list[s]],pos[id_list[t]]),dp[f(S2,t,n)])
                else:
                    dp[f(S2,t,n)]=min(dp[f(S,s,n)] + dist(center_pos,pos[id_list[t]],a=5),dp[f(S2,t,n)])
                dp[f(S2,n,n)]=min(dp[f(S2,t,n)] + dist(center_pos,pos[id_list[t]],a=5),dp[f(S2,n,n)])


    #BitDPから復元            
    path_list = [n]
    state = (1<<n)-1
    now = n
    v = dp[-1]
    e = 10**(-5)
    while state != 0 or now !=n:
        found = False
        if now == n:
            for t in range(n):
                d = dist(center_pos,pos[id_list[t]],a=5)
                if dp[f(state,t,n)]==INF:
                    continue
                if v - dp[f(state,t,n)] >= d - e:
                    path_list.append(t)
                    now = t
                    v -= d
                    found = True
                    break
        else:
            state = state ^ (1<<now)
            for t in range(n+1):
                if dp[f(state,t,n)]==INF:
                    continue
                if t!=n:
                    if not (state>>t)&1:
                        continue
                    d = dist(pos[id_list[now]],pos[id_list[t]])
                else:
                    d = dist(pos[id_list[now]],center_pos,a=5)
                if v - dp[f(state,t,n)] >= d - e:
                    path_list.append(t)
                    now = t
                    v -= d
                    found=True
                    break


    return [id_list[i] for i in path_list]


def tsp_between_space(center_pos,pos):
    n = 9
    dp = [[INF for _ in range(n)] for _ in range(1<<n)]
    dp[0][-1]=0
    for S in range(1<<n):
        for s in range(n):
            if dp[S][s]==INF:
                continue
            for t in range(n):
                if (S>>t)&1:
                    continue
                S2 = S|(1<<t)

                if s!=n-1 and t!=n-1:
                    d = dist(center_pos[s],center_pos[t],a=1)
                elif t!=n-1:
                    d = dist(pos[0],center_pos[t],a=5)
                elif s!=n-1:
                    d = dist(pos[0],center_pos[s],a=5)
                dp[S2][t]=min(dp[S][s] + d, dp[S2][t])


    #BitDPから復元            
    path_list = []
    state = (1<<n)-1
    s = n-1
    v = dp[-1][-1]
    e = 10**(-5)
    while state != 0:
        state ^=(1<<s)
        for t in range(n):
            if not (state>>t)&1:
                continue
            if dp[state][t]==INF:
                continue
            if s!=n-1 and t!=n-1:
                d = dist(center_pos[s],center_pos[t],a=1)
            elif t!=n-1:
                d = dist(pos[0],center_pos[t],a=5)
            elif s!=n-1:
                d = dist(pos[0],center_pos[s],a=5)
            if v - dp[state][t] >= d - e:
                path_list.append(t)
                s = t
                v -= d
                break
    return path_list        




def main(N,M,pos):
    center_pos,allocate,cluster_counts = clustering(N,M,pos,24)
    ans = [0]
    space_order = tsp_between_space(center_pos,pos)

    for m in space_order:
        id_list = [i for i in range(N) if allocate[i]==m]
        n = len(id_list)
        if n<=12:
            ans += tsp(id_list+[-m-1],pos,center_pos[m])
        else:
            pos2 = [pos[i] for i in id_list]
            _,allocate2,_ = clustering(n,2,pos2,10)
            for i2 in range(2):
                id_list2 = [id_list[i] for i in range(n) if allocate2[i]==i2]
                ans += tsp(id_list2+[-m-1],pos,center_pos[m])

        
    ans.append(0)
    return center_pos,ans

def output(center_pos,ans):
    for x,y in center_pos:
        print(x,y)
    print(len(ans))
    for a in ans:
        if a<0:
            print(2,-a)
        else:
            print(1,a+1)

if LOCAL:
    file_ls = Path(in_path).glob("*.txt")
    for file in file_ls:
        print(file)
        N,M,pos = read_data(file)
        center_pos,ans = main(N,M,pos)
        output(center_pos,ans)
else:
    N,M,pos = read_data("")
    center_pos,ans = main(N,M,pos)
    output(center_pos,ans)


# %%



0