結果

問題 No.5007 Steiner Space Travel
ユーザー titan23titan23
提出日時 2022-10-06 16:04:30
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 907 ms / 1,000 ms
コード長 3,178 bytes
コンパイル時間 546 ms
実行使用メモリ 86,436 KB
スコア 7,314,039
最終ジャッジ日時 2022-10-06 16:05:01
合計ジャッジ時間 30,054 ms
ジャッジサーバーID (参考情報) judge11 / judge12
純コード判定しない問題か言語
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 888 ms 86,436 KB
testcase_01 AC 887 ms 85,488 KB
testcase_02 AC 895 ms 85,328 KB
testcase_03 AC 897 ms 85,604 KB
testcase_04 AC 904 ms 85,516 KB
testcase_05 AC 889 ms 85,924 KB
testcase_06 AC 894 ms 85,188 KB
testcase_07 AC 895 ms 85,552 KB
testcase_08 AC 892 ms 85,664 KB
testcase_09 AC 895 ms 85,964 KB
testcase_10 AC 895 ms 85,440 KB
testcase_11 AC 900 ms 85,272 KB
testcase_12 AC 890 ms 85,624 KB
testcase_13 AC 891 ms 85,776 KB
testcase_14 AC 893 ms 85,600 KB
testcase_15 AC 895 ms 85,968 KB
testcase_16 AC 895 ms 85,992 KB
testcase_17 AC 907 ms 86,264 KB
testcase_18 AC 892 ms 85,732 KB
testcase_19 AC 889 ms 85,484 KB
testcase_20 AC 896 ms 86,188 KB
testcase_21 AC 894 ms 85,600 KB
testcase_22 AC 890 ms 85,644 KB
testcase_23 AC 897 ms 85,836 KB
testcase_24 AC 896 ms 85,516 KB
testcase_25 AC 891 ms 85,724 KB
testcase_26 AC 893 ms 85,692 KB
testcase_27 AC 891 ms 86,308 KB
testcase_28 AC 887 ms 85,744 KB
testcase_29 AC 888 ms 85,752 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from itertools import permutations
import sys
import random
import time
def input():
    # Read one line from standard input with trailing whitespace stripped.
    return sys.stdin.readline().rstrip()


random.seed(0)  # fixed RNG seed so the random clustering is reproducible

start = time.time()  # origin of the wall-clock time budget
N, M = (int(tok) for tok in input().split())
AB = [[int(tok) for tok in input().split()] for _ in range(N)]

def fit(X, k=8, iters=10):
    """Cluster 2-D integer points with a fixed number of k-means (Lloyd) rounds.

    Initial labels are dealt round-robin (0, 1, ..., k-1, 0, ...) and then
    shuffled with the global ``random`` generator, so results depend on its
    state.  Each round first recomputes every center as the floor of its
    members' mean (an empty cluster is re-seeded at a random point of the
    ``[0, 1000] x [0, 1000]`` square), then reassigns every point to its
    nearest center by squared distance (lowest index wins ties).

    Args:
        X: sequence of points; ``p[0]`` and ``p[1]`` are the integer x/y.
        k: number of clusters, must be >= 1 (default 8).
        iters: number of update/assign rounds (default 10).

    Returns:
        ``(labels, centers)``: ``labels[i]`` in ``range(k)`` is the cluster of
        ``X[i]``; ``centers`` is a list of ``k`` integer ``(x, y)`` tuples --
        the centers the final labels were assigned against.
    """
    labels = [i % k for i in range(len(X))]
    random.shuffle(labels)
    centers = [(0, 0)] * k
    for _ in range(iters):
        # Update step: floor-mean of each cluster's current members.
        groups = [[] for _ in range(k)]
        for label, point in zip(labels, X):
            groups[label].append(point)
        for c, members in enumerate(groups):
            if members:
                cx = sum(p[0] for p in members) // len(members)
                cy = sum(p[1] for p in members) // len(members)
            else:
                # Empty cluster: drop its center somewhere random instead.
                cx, cy = random.randint(0, 1000), random.randint(0, 1000)
            centers[c] = (cx, cy)
        # Assignment step: nearest center, first index wins on ties.
        for i, point in enumerate(X):
            px, py = point[0], point[1]
            best = 10**18
            for c, (cx, cy) in enumerate(centers):
                d = (px - cx) ** 2 + (py - cy) ** 2
                if d < best:
                    best = d
                    labels[i] = c
    return labels, centers


def main():
    """Build one candidate route from a fresh random k-means clustering.

    Stations are placed at the cluster centers returned by ``fit(AB)``.  The
    station visiting order is chosen by brute force over all permutations,
    and inside each cluster the planets are visited in index order, hopping
    planet-to-planet directly whenever that is cheaper than detouring through
    the cluster's station.

    Returns:
        ``(score, ans, centers)``: the route's energy as computed by the
        local ``calc``; the route as a list of ``(type, index)`` pairs
        (type 1 = planet, type 2 = station, 0-based index) starting and ending
        at planet 0; and the station coordinates.  If more than 0.9 s have
        elapsed since ``start`` once the order search finishes, the sentinel
        ``(1 << 30, -1, -1)`` is returned instead.
    """
    labels, centers = fit(AB)
    n_st = len(centers)

    # Energy multiplier of one hop, keyed by (from_type, to_type).
    weight = {(1, 1): 25, (1, 2): 5, (2, 1): 5, (2, 2): 1}

    def dist(a, b):
        # Squared Euclidean distance; calc() multiplies it by the hop weight.
        return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2

    def coords(node):
        kind, idx = node
        return AB[idx] if kind == 1 else centers[idx]

    def calc(ans):
        # Total weighted energy over consecutive hops of the route.
        return sum(
            weight[a[0], b[0]] * dist(coords(a), coords(b))
            for a, b in zip(ans, ans[1:])
        )

    # Precompute hop costs so each permutation is scored by table lookups.
    between = [[dist(a, b) for b in centers] for a in centers]
    # The legs between planet 0 and a station are planet<->station hops, so
    # they carry the same 5x weight calc() charges for them.
    home_leg = [5 * dist(AB[0], c) for c in centers]

    def tour_cost(order):
        cost = home_leg[order[0]] + home_leg[order[-1]]
        for i in range(n_st - 1):
            cost += between[order[i]][order[i + 1]]
        return cost

    # min() keeps the first minimum, matching a strict "<" scan.
    order = min(permutations(range(n_st)), key=tour_cost)
    if time.time() - start > 0.9:
        return 1 << 30, -1, -1

    members = [[] for _ in range(n_st)]
    for planet, cluster in enumerate(labels):
        members[cluster].append(planet)

    ans = [(1, 0)]
    for s in order:
        station = centers[s]
        group = members[s]
        ans.append((2, s))
        for j, planet in enumerate(group):
            ans.append((1, planet))
            if j + 1 < len(group):
                here, nxt = AB[planet], AB[group[j + 1]]
                direct = 25 * dist(here, nxt)
                # Going via the station costs BOTH planet<->station legs.
                detour = 5 * (dist(here, station) + dist(station, nxt))
                if direct < detour:
                    continue
            # Return to the station: the detour is no dearer, or this was the
            # cluster's last planet and the next hop leaves from the station.
            ans.append((2, s))
    ans.append((1, 0))
    return calc(ans), ans, centers

# Driver: re-run main() (each run re-clusters from a new random labelling)
# until 0.8 s of the time budget is used, keeping the cheapest route found.
ANS = []
CENTERS = []
vest = float("inf")  # best energy so far; inf so the first valid route wins

cnt = 0
while time.time() - start < 0.8:
    cnt += 1
    tmp, ans, centers = main()
    # main() returns ans == -1 when out of time; never adopt that sentinel.
    if ans != -1 and tmp < vest:
        vest = tmp
        ANS = ans[:]
        CENTERS = centers
        print(round(10**9 / (1000 + vest**.5)), file=sys.stderr)
print(cnt, file=sys.stderr)

# Station coordinates, route length, then the route with 1-based indices.
out = ["{} {}".format(c, d) for c, d in CENTERS]
out.append(str(len(ANS)))
out.extend("{} {}".format(a, b + 1) for a, b in ANS)
print("\n".join(out))

0