結果

問題 No.5007 Steiner Space Travel
ユーザー dna4_dna4_
提出日時 2023-04-24 16:52:50
言語 PyPy3 (7.3.15)
結果
RE  
実行時間 -
コード長 7,694 bytes
コンパイル時間 433 ms
コンパイル使用メモリ 87,468 KB
実行使用メモリ 79,924 KB
スコア 0
最終ジャッジ日時 2023-04-24 16:53:05
合計ジャッジ時間 11,351 ms
ジャッジサーバーID
(参考情報)
judge15 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 RE -
testcase_01 RE -
testcase_02 RE -
testcase_03 RE -
testcase_04 RE -
testcase_05 RE -
testcase_06 RE -
testcase_07 RE -
testcase_08 RE -
testcase_09 RE -
testcase_10 RE -
testcase_11 RE -
testcase_12 RE -
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 RE -
testcase_17 RE -
testcase_18 RE -
testcase_19 RE -
testcase_20 RE -
testcase_21 RE -
testcase_22 RE -
testcase_23 RE -
testcase_24 RE -
testcase_25 RE -
testcase_26 RE -
testcase_27 RE -
testcase_28 RE -
testcase_29 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import time
import random
import math

# Fixed seed so k-means initialisation and station perturbations are reproducible.
random.seed(42)

# Initial "best so far" value for the minimum searches in the route builder.
INF = 10 ** 18

# Edge-cost weights applied by State.cal_dist: alpha2 (= alpha squared) for
# planet<->planet hops, alpha for planet<->station hops, 1 for station<->station.
alpha = 5
alpha2 = alpha ** 2

def eprint(*args, **kwargs):
    """Print to standard error; takes the same arguments as print() except ``file``."""
    stream = sys.stderr
    print(*args, **kwargs, file=stream)

class TimeKeeper:
    """
    Track elapsed time against a fixed budget.

    Create an instance with the time limit in seconds; the clock starts at
    construction.
    """

    def __init__(self, time_threshold) -> None:
        # perf_counter is monotonic, so wall-clock adjustments cannot make
        # the elapsed time jump backwards or forwards mid-search.
        self.start_time_ = time.perf_counter()
        self.time_threshold_ = time_threshold

    def isTimeOver(self) -> bool:
        """Return True once the time elapsed since construction has reached the limit."""
        return self.time_sec() >= self.time_threshold_

    def time_msec(self) -> int:
        """Return the elapsed time in whole milliseconds."""
        return int(self.time_sec() * 1000)

    def time_sec(self) -> float:
        """Return the elapsed time in (fractional) seconds; prefer time_msec for logging."""
        return time.perf_counter() - self.start_time_

class Kmeans:
    """
    Minimal k-means (Lloyd's algorithm) on 2-D points.

    Used to choose starting positions for the space stations: each returned
    centroid is the integer-truncated mean of the points assigned to it.
    """

    def __init__(self, X:list, n_data:int, k:int):
        """
        X: objects exposing ``.x`` and ``.y`` (e.g. Transit instances)
        n_data: how many leading points of X are sampled / used in updates
        k: number of clusters (random.sample requires k <= n_data)
        """
        self.x = [[t.x, t.y] for t in X]
        self.n_data = n_data
        self.k = k

    def init_centroid(self):
        """Return k distinct data points, sampled at random, as initial centroids."""
        idx = random.sample(range(self.n_data), self.k)
        # Copy so a centroid that is never updated does not alias a stored point.
        return [list(self.x[i]) for i in idx]

    def compute_distance(self, centroids):
        """Return d where d[p][c] is the Euclidean distance from point p to centroid c."""
        distances = []
        for x in self.x:
            dist = [math.sqrt(sum((a - b) ** 2 for a, b in zip(x, centroid))) for centroid in centroids]
            distances.append(dist)
        return distances

    def clustering(self, max_iter:int = 300):
        """
        Run Lloyd iterations until the assignment stops changing (or max_iter
        rounds have run) and return the k centroids as [x, y] int lists.

        The centroids and the final assignment are also dumped to stderr.
        """
        centroids = self.init_centroid()
        # None (not [0]*n) so the first round can never look "converged".
        cluster = None
        for _ in range(max_iter):
            distances = self.compute_distance(centroids)
            new_cluster = [min(range(len(d)), key=d.__getitem__) for d in distances]
            for idx_centroid in range(self.k):
                members = [self.x[i] for i in range(self.n_data) if new_cluster[i] == idx_centroid]
                if members:  # an empty cluster keeps its previous centroid
                    centroids[idx_centroid] = [int(sum(coord) / len(members)) for coord in zip(*members)]
            if new_cluster == cluster:
                break
            # Remember this round's assignment for the convergence test above
            # (it used to be updated only after the loop, so the test never fired).
            cluster = new_cluster
        print(centroids, file=sys.stderr)
        print(cluster, file=sys.stderr)
        return centroids


class Input:
    """Problem instance: N planets, M stations, and the planet coordinates."""

    def __init__(self, N:int, M:int, ab:list) -> None:
        # ab[i] holds the [x, y] coordinates of planet i.
        self.N, self.M, self.ab = N, M, ab

class Parser:
    """Build an Input either from standard input or from a local test file."""

    def __init__(self, input_type:int = -1):
        """
        input_type: -1 reads the instance from standard input; any other
            value num reads ``./in/{num:04d}.txt`` (local testing only).
        """
        self.flag = input_type

    def parse(self):
        """Return the Input selected by the flag given at construction."""
        if self.flag == -1:
            return self.parse_input()
        return self.parse_input_file(self.flag)

    def parse_input(self) -> Input:
        """Read "N M" followed by N lines of "x y" from standard input."""
        N, M = map(int, input().split())
        ab = [list(map(int, input().split())) for _ in range(N)]
        return Input(N, M, ab)

    def parse_input_file(self, num) -> Input:
        """Read the same format as parse_input from ./in/<num zero-padded to 4>.txt."""
        path = f"./in/{str(num).zfill(4)}.txt"
        with open(path, encoding="utf-8") as f:
            lines = [s.strip() for s in f.readlines()]
        N, M = map(int, lines[0].split())
        # Take exactly N coordinate rows so trailing blank lines (or extra
        # content) in the file do not turn into bogus empty planets.
        ab = [list(map(int, s.split())) for s in lines[1:1 + N]]
        return Input(N, M, ab)

class Transit:
    """A stop on the route: a planet (type 1) or a space station (type 2)."""

    def __init__(self, id:int, x:int, y:int, type:int) -> None:
        """
        id: index of the planet or station (printed 1-based in the answer)
        x, y: coordinates
        type: 1 for a planet, 2 for a space station
        """
        self.id, self.x, self.y, self.type = id, x, y, type

    def __str__(self) -> str:
        fields = (self.id, self.x, self.y, self.type)
        return "(" + ",".join(f"{value}" for value in fields) + ")"

class State:
    """A candidate answer: the visiting order plus the planet and station sets."""

    def __init__(self, order:list, q_planets:list, q_stations:list) -> None:
        """
        order: Transit objects in visiting order
        q_planets: planet Transit objects
        q_stations: space-station Transit objects
        """
        self.order, self.q_planets, self.q_stations = order, q_planets, q_stations

    def cal_dist(self, v1:Transit, v2:Transit) -> float:
        """
        Return the weighted squared Euclidean distance between v1 and v2.

        Weight: alpha2 for planet-planet, 1 for station-station, alpha otherwise.
        """
        if v1.type == 1 and v2.type == 1:
            weight = alpha2
        elif v1.type == 2 and v2.type == 2:
            weight = 1
        else:
            weight = alpha
        dx = v1.x - v2.x
        dy = v1.y - v2.y
        return (dx ** 2 + dy ** 2) * weight

    def cal_score(self):
        """Return int(10**9 / (1000 + sqrt(total))), total = summed cal_dist along the order."""
        route = self.order
        total = sum(self.cal_dist(a, b) for a, b in zip(route, route[1:]))
        return int(pow(10, 9) / (1000 + total ** 0.5))
    
class Output:
    """Print a State in the answer format."""

    def __init__(self, state:State) -> None:
        self.order = state.order
        self.q_stations = state.q_stations

    def ans(self):
        """
        Write to stdout: one "x y" line per station, the number of stops,
        then one "type id" line per stop with ids shifted to 1-based.
        """
        for station in self.q_stations:
            print(station.x, station.y)
        print(len(self.order))
        for stop in self.order:
            kind, number = stop.type, stop.id + 1
            print(kind, number)

class Solver:
    """Greedy nearest-neighbour route builder over planets and stations."""

    def __init__(self, state:State) -> None:
        self.state = state

    def solve(self):
        """
        Append a greedy tour to self.state.order and return self.state.

        The tour starts and ends at q_planets[0]. Each step moves to the
        cheapest (by cal_dist) unvisited planet or, only while standing on a
        planet, possibly to a space station. Planet ids are used as indices
        into the visited list, so they are expected to be 0..len(q_planets)-1.
        """
        st = self.state
        planets = st.q_planets
        stations = st.q_stations
        start = planets[0]
        st.order.append(start)
        seen = [0] * len(planets)
        seen[0] = 1
        remaining = len(planets) - 1
        cur = start
        best = Transit(-1, -1, -1, -1)
        while remaining > 0:
            best_cost = INF
            for cand in planets:
                if seen[cand.id] == 1:
                    continue
                cost = st.cal_dist(cur, cand)
                if cost < best_cost:
                    best_cost, best = cost, cand
            # Station-to-station hops are not allowed: stations are only
            # candidates when standing on a planet. Dropping this condition
            # would enable them (marked as a possible improvement).
            if cur.type != 2:
                for cand in stations:
                    if cand is cur:
                        continue
                    cost = st.cal_dist(cur, cand)
                    if cost < best_cost:
                        best_cost, best = cost, cand
            cur = best
            st.order.append(best)
            if best.type == 1 and seen[best.id] == 0:
                seen[best.id] = 1
                remaining -= 1
        st.order.append(start)
        return st

def main():
    """
    Solve one instance and print the answer.

    Steps:
    - Read the instance from stdin.
    - Place the M stations at k-means centroids of the planets.
    - Build a greedy route.
    - Until the time budget runs out, randomly nudge the stations and keep only
      improving routes.
    - Print the best answer found.
    """
    # -1 = read from standard input. The previous Parser(0) opened the local
    # test file ./in/0000.txt, which does not exist on the judge, so every
    # test case ended in a runtime error.
    parser = Parser(-1)
    inp = parser.parse()  # not named `input`, to avoid shadowing the builtin

    q_planets = [
        Transit(id=i, x=inp.ab[i][0], y=inp.ab[i][1], type=1)
        for i in range(inp.N)
    ]

    # One cluster per station over all N planets (was hard-coded to 100 / 8).
    kmeans = Kmeans(q_planets, inp.N, inp.M)
    centroids = kmeans.clustering()
    q_stations = [
        Transit(id=i, x=centroids[i][0], y=centroids[i][1], type=2)
        for i in range(inp.M)
    ]
    best_ans = Solver(State([], q_planets, q_stations)).solve()
    best_score = best_ans.cal_score()
    eprint(best_score)

    def clamp(v):
        # Station coordinates must stay within the problem's [0, 1000] range;
        # an unclamped nudge near the border would make the answer invalid.
        return max(0, min(1000, v))

    tmp_stations = best_ans.q_stations
    timeKeeper2 = TimeKeeper(0.84)
    while not timeKeeper2.isTimeOver():
        q_stations = [
            Transit(
                id=i,
                x=clamp(tmp_stations[i].x + random.randrange(-20, 20)),
                y=clamp(tmp_stations[i].y + random.randrange(-20, 20)),
                type=2,
            )
            for i in range(inp.M)
        ]
        ans = Solver(State([], q_planets, q_stations)).solve()
        score = ans.cal_score()
        if score > best_score:
            best_score = score
            best_ans = ans
            tmp_stations = q_stations
    eprint(best_score)
    Output(best_ans).ans()


if __name__ == "__main__":
    # Solve only when executed as a script, not when imported.
    main()
0