結果

問題 No.5007 Steiner Space Travel
ユーザー dna4_dna4_
提出日時 2023-04-24 17:05:40
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 7,703 bytes
コンパイル時間 475 ms
コンパイル使用メモリ 87,032 KB
実行使用メモリ 83,128 KB
スコア 7,572,692
最終ジャッジ日時 2023-04-24 17:06:12
合計ジャッジ時間 32,061 ms
ジャッジサーバーID (参考情報): judge13 / judge11
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
| 入力 | 結果 | 実行時間 | 実行使用メモリ |
|---|---|---|---|
| testcase_00 | AC | 949 ms | 81,792 KB |
| testcase_01 | AC | 950 ms | 81,916 KB |
| testcase_02 | AC | 948 ms | 81,968 KB |
| testcase_03 | AC | 954 ms | 82,868 KB |
| testcase_04 | AC | 952 ms | 82,540 KB |
| testcase_05 | AC | 948 ms | 81,828 KB |
| testcase_06 | AC | 946 ms | 82,348 KB |
| testcase_07 | AC | 948 ms | 82,304 KB |
| testcase_08 | AC | 948 ms | 83,128 KB |
| testcase_09 | AC | 947 ms | 82,360 KB |
| testcase_10 | AC | 950 ms | 81,920 KB |
| testcase_11 | AC | 948 ms | 82,440 KB |
| testcase_12 | AC | 949 ms | 82,000 KB |
| testcase_13 | AC | 948 ms | 82,092 KB |
| testcase_14 | AC | 950 ms | 82,324 KB |
| testcase_15 | AC | 951 ms | 82,068 KB |
| testcase_16 | AC | 950 ms | 82,572 KB |
| testcase_17 | AC | 949 ms | 82,584 KB |
| testcase_18 | AC | 948 ms | 81,560 KB |
| testcase_19 | AC | 947 ms | 81,988 KB |
| testcase_20 | AC | 947 ms | 81,132 KB |
| testcase_21 | AC | 948 ms | 82,920 KB |
| testcase_22 | AC | 954 ms | 81,972 KB |
| testcase_23 | AC | 950 ms | 81,912 KB |
| testcase_24 | AC | 946 ms | 81,812 KB |
| testcase_25 | WA | - | - |
| testcase_26 | AC | 949 ms | 81,760 KB |
| testcase_27 | AC | 951 ms | 82,992 KB |
| testcase_28 | AC | 948 ms | 82,452 KB |
| testcase_29 | AC | 950 ms | 82,168 KB |
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import time
import random
import math

# Fixed seed so that k-means initialisation and station jitter are reproducible.
random.seed(42)

# Sentinel "infinitely large" energy used as the starting minimum in searches.
INF = 10 ** 18

# Energy weights applied in State.cal_dist: alpha2 for planet-planet hops,
# alpha for planet-station hops (station-station hops use weight 1).
alpha = 5
alpha2 = alpha ** 2

def eprint(*args, **kwargs):
    """Like ``print``, but always writes to standard error (for debug output)."""
    stream = sys.stderr
    print(*args, file=stream, **kwargs)

class TimeKeeper:
    """
    Track elapsed time against a fixed budget.

    The budget is given in seconds; the clock starts when the instance is
    created.
    """

    def __init__(self, time_threshold) -> None:
        # perf_counter is monotonic, so wall-clock adjustments (NTP, DST, ...)
        # cannot make the budget expire early or never.
        self.start_time_ = time.perf_counter()
        self.time_threshold_ = time_threshold

    def isTimeOver(self) -> bool:
        """
        Return True once at least ``time_threshold`` seconds have elapsed
        since this instance was created.
        """
        return self.time_sec() >= self.time_threshold_

    def time_msec(self) -> int:
        """Return the elapsed time in whole milliseconds (truncated)."""
        return int(self.time_sec() * 1000)

    def time_sec(self) -> float:
        """Return the elapsed time in seconds as a float (time_msec gives an int)."""
        return time.perf_counter() - self.start_time_

class Kmeans:
    """
    Plain (Lloyd's) k-means clustering of 2-D points.

    Centroid coordinates are truncated to ``int`` after every update.
    """

    def __init__(self, X: list, n_data: int, k: int, max_iter: int = 300):
        """
        X: objects exposing ``.x`` and ``.y`` attributes (e.g. Transit)
        n_data: number of points; assumed to equal len(X) — TODO confirm at call sites
        k: number of clusters (random.sample raises ValueError if k > n_data)
        max_iter: upper bound on the number of assign/update epochs
        """
        self.x = [[t.x, t.y] for t in X]
        self.n_data = n_data
        self.k = k
        self.max_iter = max_iter

    def init_centroid(self):
        """Pick ``k`` distinct input points at random as the initial centroids."""
        idx = random.sample(range(self.n_data), self.k)
        # Copy so later centroid updates can never alias the input points.
        return [list(self.x[i]) for i in idx]

    def compute_distance(self, centroids):
        """Return, for every point, its Euclidean distance to each centroid."""
        return [
            [math.sqrt(sum((a - b) ** 2 for a, b in zip(p, c))) for c in centroids]
            for p in self.x
        ]

    def clustering(self):
        """
        Run k-means and return ``k`` centroids as ``[x, y]`` int lists.

        Stops as soon as the point-to-cluster assignment stops changing, or
        after ``max_iter`` epochs.
        """
        centroids = self.init_centroid()
        cluster = None  # assignment from the previous epoch
        for _ in range(self.max_iter):
            distances = self.compute_distance(centroids)
            new_cluster = [min(range(len(d)), key=d.__getitem__) for d in distances]
            for idx_centroid in range(self.k):
                members = [
                    self.x[i] for i in range(self.n_data) if new_cluster[i] == idx_centroid
                ]
                # An empty cluster keeps its previous centroid.
                if members:
                    centroids[idx_centroid] = [
                        int(sum(coord) / len(members)) for coord in zip(*members)
                    ]
            if new_cluster == cluster:
                break
            # Remember this epoch's labels inside the loop so the next epoch
            # can detect convergence; otherwise the loop never exits early.
            cluster = new_cluster
        return centroids


class Input:
    """Parsed problem input: planet count N, station count M and planet coordinates."""

    def __init__(self, N: int, M: int, ab: list) -> None:
        # ab[i] is the [x, y] pair read for planet i.
        self.N, self.M, self.ab = N, M, ab

class Parser:
    """
    Load a problem instance either from stdin or from a local test file.

    ``input_type == -1`` selects stdin; any other value is used as the case
    number of ``./in/<number zero-padded to 4 digits>.txt``.
    """

    def __init__(self, input_type: int):
        self.flag = input_type

    def parse(self):
        """Return the Input read from the source selected by ``self.flag``."""
        if self.flag != -1:
            return self.parse_input_file(self.flag)
        return self.parse_input()

    def parse_input(self) -> Input:
        """Read a header line ``N M`` followed by N coordinate lines from stdin."""
        n, m = map(int, input().split())
        rows = []
        for _ in range(n):
            rows.append(list(map(int, input().split())))
        return Input(n, m, rows)

    def parse_input_file(self, num) -> Input:
        """Read the same layout from a numbered file; every line after the header becomes a row."""
        path = f"./in/{str(num).zfill(4)}.txt"
        with open(path) as fh:
            lines = [raw.strip() for raw in fh.readlines()]
        n, m = map(int, lines[0].split())
        rows = [list(map(int, line.split())) for line in lines[1:]]
        return Input(n, m, rows)

class Transit:
    """A node of the route: a planet (type 1) or a space station (type 2)."""

    def __init__(self, id: int, x: int, y: int, type: int) -> None:
        """
        id: index of the planet or station
        x, y: coordinates of the node
        type: 1 for a planet, 2 for a station
        """
        self.id, self.x, self.y, self.type = id, x, y, type

    def __str__(self) -> str:
        fields = (self.id, self.x, self.y, self.type)
        return "(" + ",".join(map(str, fields)) + ")"

class State:
    """A candidate answer: the visit order plus the planet and station nodes."""

    def __init__(self, order: list, q_planets: list, q_stations: list) -> None:
        """
        order: Transit nodes in visiting order (Solver.solve appends to it)
        q_planets: Transit nodes of the planets (type 1)
        q_stations: Transit nodes of the space stations (type 2)
        """
        self.order = order
        self.q_planets = q_planets
        self.q_stations = q_stations

    def cal_dist(self, v1: Transit, v2: Transit) -> float:
        """
        Return the energy of a direct hop from v1 to v2.

        This is the *squared* Euclidean distance (no square root) multiplied
        by ``alpha2`` for planet-to-planet, ``1`` for station-to-station and
        ``alpha`` for any other pair.
        """
        sq = (v1.x - v2.x) ** 2 + (v1.y - v2.y) ** 2
        if v1.type == 1 and v2.type == 1:  # planet to planet
            coef = alpha2
        elif v1.type == 2 and v2.type == 2:  # station to station
            coef = 1
        else:
            coef = alpha
        return sq * coef

    def cal_score(self) -> int:
        """
        Return int(10**9 / (1000 + sqrt(E))), where E is the summed energy of
        every consecutive hop in ``order`` (E = 0 for fewer than two nodes).

        NOTE(review): the value is truncated; if the judge rounds instead, the
        reported score may differ by 1 — confirm against the problem statement.
        """
        energy = sum(self.cal_dist(a, b) for a, b in zip(self.order, self.order[1:]))
        return int(10**9 / (1000 + energy ** 0.5))
    
class Output:
    """Emit a State in the answer format printed to stdout."""

    def __init__(self, state: State) -> None:
        # Only the station list and the visit order are needed for printing.
        self.q_stations = state.q_stations
        self.order = state.order

    def ans(self):
        """
        Print one "x y" line per station, then the number of visited nodes,
        then one "type index" line per visited node with the index 1-based.
        """
        rows = [(s.x, s.y) for s in self.q_stations]
        rows.append((len(self.order),))
        rows.extend((node.type, node.id + 1) for node in self.order)
        for fields in rows:
            print(*fields)

class Solver:
    """Build a route greedily by always hopping to the cheapest allowed node."""

    def __init__(self, state: State) -> None:
        self.state = state

    def solve(self):
        """
        Append a nearest-neighbour route to ``self.state.order`` and return the state.

        The route starts at planet 0. Each step moves to the unvisited planet
        or, only when standing on a planet, the station with the smallest
        ``cal_dist`` energy; ties keep the earlier candidate, with planets
        scanned before stations. The route finally returns to planet 0.
        """
        st = self.state
        planets = st.q_planets
        start = planets[0]
        st.order.append(start)
        seen = [0] * len(planets)
        seen[0] = 1
        current = start
        chosen = Transit(-1, -1, -1, -1)  # placeholder until a candidate is found
        remaining = len(planets) - 1
        while remaining > 0:
            best = INF
            for cand in planets:
                if seen[cand.id] == 1:
                    continue
                cost = st.cal_dist(current, cand)
                if cost < best:
                    best, chosen = cost, cand
            # Station-to-station hops are not allowed (TODO: consider enabling).
            if current.type != 2:
                for cand in st.q_stations:
                    if cand is current:
                        continue
                    cost = st.cal_dist(current, cand)
                    if cost < best:
                        best, chosen = cost, cand
            current = chosen
            st.order.append(chosen)
            if chosen.type == 1 and seen[chosen.id] == 0:
                seen[chosen.id] = 1
                remaining -= 1

        st.order.append(planets[0])
        return st

def clamp_coord(v: int, lo: int = 0, hi: int = 1000) -> int:
    """Return ``v`` clamped into the closed interval [lo, hi]."""
    return max(lo, min(hi, v))


def main():
    """
    Solve one test case read from stdin and print the answer to stdout.

    Stations are seeded at k-means centroids of the planets. Afterwards the
    station set is repeatedly jittered and re-routed greedily, keeping the
    best-scoring variant until the time budget expires.
    """
    time_keeper = TimeKeeper(0.85)

    # Named ``inp`` so the builtin ``input`` is not shadowed.
    inp = Parser(-1).parse()

    q_planets = [
        Transit(id=i, x=inp.ab[i][0], y=inp.ab[i][1], type=1)
        for i in range(inp.N)
    ]

    # Use the actual input sizes instead of hard-coded 100 planets / 8 stations.
    centroids = Kmeans(q_planets, inp.N, inp.M).clustering()
    q_stations = [
        Transit(id=i, x=centroids[i][0], y=centroids[i][1], type=2)
        for i in range(inp.M)
    ]
    best_ans = Solver(State([], q_planets, q_stations)).solve()
    best_score = best_ans.cal_score()
    eprint(best_score)

    tmp_stations = best_ans.q_stations
    while not time_keeper.isTimeOver():
        # Jitter every station by up to +/-20 per axis (symmetric). The result
        # is clamped: an unclamped jitter can push a station outside the
        # allowed coordinate range and invalidate the whole answer.
        # NOTE(review): [0, 1000] is the problem's bound, not visible in this file.
        q_stations = [
            Transit(
                id=i,
                x=clamp_coord(tmp_stations[i].x + random.randint(-20, 20)),
                y=clamp_coord(tmp_stations[i].y + random.randint(-20, 20)),
                type=2,
            )
            for i in range(inp.M)
        ]
        ans = Solver(State([], q_planets, q_stations)).solve()
        score = ans.cal_score()
        if score > best_score:
            best_score = score
            best_ans = ans
            tmp_stations = q_stations
    eprint(best_score)
    Output(best_ans).ans()


# Script entry point: run the solver only when executed directly, not on import.
if __name__ == "__main__":
    main() 
0