結果

問題 No.5022 XOR Printer
ユーザー ra5anchor
提出日時 2025-07-31 23:56:01
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,868 ms / 2,000 ms
コード長 13,977 bytes
コンパイル時間 385 ms
コンパイル使用メモリ 82,364 KB
実行使用メモリ 92,692 KB
スコア 5,206,925,670
最終ジャッジ日時 2025-07-31 23:57:41
合計ジャッジ時間 95,987 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
純コード判定しない問題か言語
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 50
権限があれば一括ダウンロードができます

ソースコード

diff #

import copy
import random
import itertools
from time import perf_counter
import argparse
import sys
import math

# Large sentinel constant (not referenced by the code in this file).
MAX = 100_000_000

class TimeKeeper:
    """Stopwatch measuring wall-clock seconds (perf_counter) since construction."""

    def __init__(self):
        # Reference point for every later measurement.
        self.start_time = perf_counter()

    def time_now(self):
        """Return the number of seconds elapsed since this TimeKeeper was created."""
        return perf_counter() - self.start_time

    def is_time_over(self, LIMIT):
        """Return True once at least LIMIT seconds have elapsed."""
        return self.time_now() >= LIMIT

#
# ----- TSP 用ヘルパー関数(Christofides + 2-opt) -----
#

def manhattan_matrix(coords):
    """Return the n x n matrix of Manhattan (L1) distances between points.

    coords: list of (x, y) pairs. The result is symmetric with a zero diagonal.
    """
    return [
        [abs(xa - xb) + abs(ya - yb) for (xb, yb) in coords]
        for (xa, ya) in coords
    ]

def prim_mst(D):
    """
    Prim 法で MST の辺リストを返す
    戻り値は [(u, v), …] のタプルリスト
    """
    n = len(D)
    in_mst = [False]*n
    key = [float('inf')]*n
    parent = [-1]*n
    key[0] = 0
    edges = []
    for _ in range(n):
        u = min((i for i in range(n) if not in_mst[i]), key=lambda i: key[i])
        in_mst[u] = True
        if parent[u] != -1:
            edges.append((u, parent[u]))
        for v in range(n):
            if not in_mst[v] and D[u][v] < key[v]:
                key[v] = D[u][v]
                parent[v] = u
    return edges

def greedy_matching(odd, D):
    """Greedily pair up the given vertices, cheapest free pair first.

    odd: vertices to match (in solve_tsp, the odd-degree MST vertices).
    D: distance matrix indexed by vertex.
    Pairs (i, j) from itertools.combinations(odd, 2) are examined in ascending
    (distance, i, j) order. A pair is accepted whenever both endpoints are still
    unmatched. Returns the accepted (i, j) tuples in acceptance order.
    """
    free = set(odd)
    ordered = sorted(
        itertools.combinations(odd, 2),
        key=lambda pair: (D[pair[0]][pair[1]], pair[0], pair[1]),
    )
    matched = []
    for a, b in ordered:
        if a in free and b in free:
            free.remove(a)
            free.remove(b)
            matched.append((a, b))
            if not free:
                break
    return matched

def build_multigraph(n, mst_edges, matching_edges):
    """Merge MST edges and matching edges into one undirected multigraph.

    Returns a dict mapping every vertex 0..n-1 to a dict {neighbour: multiplicity}.
    Neighbours are inserted in edge order (mst_edges first), which determines
    the traversal order used later by eulerian_tour.
    """
    graph = {v: {} for v in range(n)}
    for a, b in mst_edges + matching_edges:
        forward = graph[a]
        forward[b] = forward.get(b, 0) + 1
        backward = graph[b]
        backward[a] = backward.get(a, 0) + 1
    return graph

def eulerian_tour(graph, start):
    """Walk every edge of a multigraph once using Hierholzer's algorithm.

    graph: dict vertex -> {neighbour: multiplicity}, as built by
    build_multigraph. It is copied, not modified.
    start: vertex the walk begins at.
    Returns vertices in the order they are popped off the walk stack (the
    circuit in reverse traversal order). At each step the first remaining
    neighbour in dict order is followed.
    """
    remaining = {v: dict(nbrs) for v, nbrs in graph.items()}

    def drop_edge(a, b):
        # Consume one copy of the directed half a->b; forget exhausted neighbours.
        left = remaining[a][b] - 1
        if left:
            remaining[a][b] = left
        else:
            del remaining[a][b]

    stack = [start]
    circuit = []
    while stack:
        top = stack[-1]
        nbrs = remaining[top]
        if not nbrs:
            circuit.append(stack.pop())
            continue
        nxt = next(iter(nbrs))
        drop_edge(top, nxt)
        drop_edge(nxt, top)
        stack.append(nxt)
    return circuit

def make_hamiltonian_cycle(euler_path):
    """Shortcut an Euler walk into a closed Hamiltonian cycle.

    Keeps only the first occurrence of each vertex (in walk order), then
    repeats the first vertex at the end to close the cycle. Returns a new
    list; raises IndexError if euler_path is empty.
    """
    cycle = list(dict.fromkeys(euler_path))
    return cycle + [cycle[0]]

def two_opt_cycle(cycle, D):
    """Improve a closed tour in place with first-improvement 2-opt.

    cycle: [v0, v1, ..., v_{n-1}, v0]; the two copies of v0 never move.
    D: distance matrix.
    (i, j) pairs are scanned in lexicographic order. The first reversal of
    cycle[i..j] that strictly shortens the tour is applied and the scan
    restarts, until no improving move remains. Returns the same, mutated list.
    """
    last = len(cycle) - 1

    def first_gain():
        # First (i, j) whose segment reversal strictly reduces the tour length.
        for i in range(1, last - 1):
            a, b = cycle[i - 1], cycle[i]
            for j in range(i + 1, last):
                c, d = cycle[j], cycle[j + 1]
                if D[a][b] + D[c][d] > D[a][c] + D[b][d]:
                    return i, j
        return None

    move = first_gain()
    while move is not None:
        i, j = move
        cycle[i:j + 1] = cycle[i:j + 1][::-1]
        move = first_gain()
    return cycle

def solve_tsp(coords, start=0):
    """Heuristic visiting order over coords: Christofides-style tour plus 2-opt.

    coords: list of (x, y) points; distances are Manhattan.
    start: index of the point the walk starts from.
    The odd-degree MST vertices are matched greedily (not optimally). 2-opt
    optimises the closed cycle, including the edge back to `start`, and that
    closing edge is then dropped. Returns a list of indices into coords.
    """
    dist = manhattan_matrix(coords)
    tree = prim_mst(dist)
    degree = [0] * len(coords)
    for a, b in tree:
        degree[a] += 1
        degree[b] += 1
    odd_vertices = [v for v in range(len(coords)) if degree[v] % 2]
    pairs = greedy_matching(odd_vertices, dist)
    multigraph = build_multigraph(len(coords), tree, pairs)
    tour = make_hamiltonian_cycle(eulerian_tour(multigraph, start))
    # Drop the repeated start vertex to turn the cycle into an open path.
    return two_opt_cycle(tour, dist)[:-1]

#
# ----- ここまで TSP 用ヘルパー関数 -----
#

def main(DEBUG):
    """Solve one XOR Printer instance from stdin and print the best command list.

    Input: a line "N T" followed by N rows of the grid A. Output (stdout): one
    command per line, using moves L/R/U/D, W (write s into the cell) and C (copy:
    s ^= cell). Progress messages go to stderr.

    Until the time budget runs out, ``solve`` builds a randomized command list
    bit by bit from the top, and ``replay`` rewrites its tail. The candidate with
    the best simulated grid sum is kept.

    DEBUG: when true, use a 1.0 s time budget instead of 1.7 s.
    """
    tk = TimeKeeper()
    LIMIT = 1.0 if DEBUG else 1.7

    BITS = 20      # the shift arithmetic below treats cell values as 20-bit numbers
    DIGITS = 10    # solve() processes only this many of the most significant bits
    MAX_LOOP = 10  # random trials per digit
    # Row/column displacement of each move command.
    MOVES = {"L": (0, -1), "R": (0, 1), "U": (-1, 0), "D": (1, 0)}

    def simulate(ANS):
        """Run ANS from (0, 0) with s = 0 on a copy of A.

        Returns (B, i, j, s), i.e. the final grid, position and register, or
        None when ANS has more than T commands or leaves the grid.
        """
        if len(ANS) > T:
            return None
        B = [row[:] for row in A]
        i = j = s = 0
        for c in ANS:
            if c in MOVES:
                di, dj = MOVES[c]
                i += di
                j += dj
                if not (0 <= i < N and 0 <= j < N):
                    return None
            elif c == "W":
                B[i][j] ^= s
            elif c == "C":
                s ^= B[i][j]
        return B, i, j, s

    def cal_score_sim(ANS):
        """Return the grid sum after executing ANS, or -1 if ANS is invalid."""
        state = simulate(ANS)
        return -1 if state is None else sum(map(sum, state[0]))

    def replay(ANS):
        """Rewrite the tail of ANS so it ends by raising the smallest cell.

        Commands are undone from the end until enough turns are free to walk
        to the largest cell near the minimum (optionally copying it into s),
        walk to the minimum cell, then C and W. C followed by W leaves the
        pre-C value of s in that cell. ANS is modified in place and returned.
        Returns None if ANS is invalid.
        """
        state = simulate(ANS)
        if state is None:
            return None
        B, nowi, nowj, s = state

        # Smallest cell; the first one in row-major order on ties.
        cells = [(i, j) for i in range(N) for j in range(N)]
        mini, minj = min(cells, key=lambda p: B[p[0]][p[1]], default=(0, 0))
        # Largest positive cell in the 7x7 window around it; defaults to the
        # minimum cell itself.
        maxi, maxj = mini, minj
        maxv = 0
        for ii in range(max(0, mini - 3), min(N, mini + 4)):
            for jj in range(max(0, minj - 3), min(N, minj + 4)):
                if B[ii][jj] > maxv:
                    maxv = B[ii][jj]
                    maxi, maxj = ii, jj

        # Undo commands from the end until the tail fits into the freed turns.
        dist2 = abs(maxi - mini) + abs(maxj - minj)
        remain = 0
        while ANS and abs(nowi - maxi) + abs(nowj - maxj) + dist2 + 3 > remain:
            c = ANS.pop()
            remain += 1
            if c in MOVES:
                di, dj = MOVES[c]
                nowi -= di
                nowj -= dj
            elif c == "W":
                B[nowi][nowj] ^= s  # W is self-inverse for a fixed s
            elif c == "C":
                s ^= B[nowi][nowj]  # C is self-inverse while the cell is unchanged

        if s < 2 ** (BITS - 1):
            # The top bit of s is not set: fold the nearby large value into s first.
            ANS.extend(goto(nowi, nowj, maxi, maxj))
            nowi, nowj = maxi, maxj
            ANS.append("C")
        ANS.extend(goto(nowi, nowj, mini, minj))
        ANS.extend(("C", "W"))
        return ANS

    def goto(nowi, nowj, toi, toj):
        """Return the moves from (nowi, nowj) to (toi, toj): vertical first, then horizontal."""
        di, dj = toi - nowi, toj - nowj
        vertical = ["D"] * di if di > 0 else ["U"] * (-di)
        horizontal = ["R"] * dj if dj > 0 else ["L"] * (-dj)
        return vertical + horizontal

    def get_order(w_pos, nowi, nowj):
        """Order w_pos for visiting from (nowi, nowj) using the TSP heuristic (solve_tsp)."""
        if not w_pos:
            return []
        coords = [(nowi, nowj)] + list(w_pos)
        path = solve_tsp(coords, start=0)
        # path[0] is the current position; the rest are the targets in visit order.
        return [coords[i] for i in path[1:]]

    def get_order2(w_pos):
        """Alternative ordering (currently unused): boustrophedon (zigzag) row scan."""
        idx = {}
        for i in range(N):
            cols = range(N) if i % 2 == 0 else reversed(range(N))
            for j in cols:
                idx[(i, j)] = len(idx)
        return sorted(w_pos, key=lambda x: idx[x])

    def solve(tk, LIMIT):
        """Build one randomized command list, raising one bit per digit from the top.

        For digit k (bit BITS-1-k), each of MAX_LOOP trials picks a random cell
        whose copy sets that bit in s. It then writes s to every cell the write
        would increase, except the largest one, which is copied into s
        instead. The trial with the largest grid sum is kept. Stops early on
        time-out, when the list exceeds T commands, or when a digit yields no
        usable trial.
        """
        X = [row[:] for row in A]
        nowi, nowj, s = 0, 0, 0
        actions = []
        for k in range(DIGITS):
            if tk.is_time_over(LIMIT) or len(actions) > T:
                break
            shift = BITS - 1 - k
            prefix = (1 << k) - 1  # the k bits above bit `shift` must all be 1
            nowi0, nowj0, s0 = nowi, nowj, s
            bestv = 0
            best = None
            for loop in range(MAX_LOOP):
                X1 = [row[:] for row in X]
                actions1 = []
                nowi, nowj, s = nowi0, nowj0, s0

                # Cells whose bit `shift` differs from s's, so that copying one
                # makes that bit of s equal to 1.
                xnow = (s >> shift) & 1
                candidates = sorted(
                    (abs(nowi - i) + abs(nowj - j), i, j)
                    for i in range(N)
                    for j in range(N)
                    if (X1[i][j] >> (shift + 1)) == prefix
                    and ((X1[i][j] >> shift) & 1) == 1 - xnow
                )
                if not candidates:
                    print(f"not found kouho {k=} -> break", file=sys.stderr)
                    break
                _, ti, tj = candidates[random.randrange(len(candidates))]

                # Walk to the chosen cell and copy it into s.
                actions1.extend(goto(nowi, nowj, ti, tj))
                nowi, nowj = ti, tj
                actions1.append("C")
                s ^= X1[nowi][nowj]

                # Cells that writing s would increase.
                w_pos = [(i, j) for i in range(N) for j in range(N)
                         if (X1[i][j] ^ s) > X1[i][j]]
                if not w_pos:
                    print(f"not found w_pos {k=} -> break", file=sys.stderr)
                    break

                # Keep the largest such cell out of the tour; it is copied into s
                # at the end instead of being written.
                mxi, mxj = max(w_pos, key=lambda p: X1[p[0]][p[1]])
                w_pos.remove((mxi, mxj))

                # Visit the others in a short order, writing s at each.
                for toi, toj in get_order(w_pos, nowi, nowj):
                    actions1.extend(goto(nowi, nowj, toi, toj))
                    nowi, nowj = toi, toj
                    actions1.append("W")
                    X1[nowi][nowj] ^= s

                actions1.extend(goto(nowi, nowj, mxi, mxj))
                nowi, nowj = mxi, mxj
                actions1.append("C")
                s ^= X1[nowi][nowj]

                v = sum(map(sum, X1))
                turn = len(actions1)
                if v > bestv:
                    print(f"    best: {v=} {turn=} loop: {loop}", file=sys.stderr)
                    bestv = v
                    # actions1 and X1 are fresh per trial, so no copies are needed.
                    best = (actions1, X1, s, nowi, nowj)

            if best is None:
                # No trial produced a result for this digit. Appending anything
                # here would replay moves recorded from a different position.
                print(f"no candidate for {k=} -> stop", file=sys.stderr)
                break
            best_actions, X, s, nowi, nowj = best
            actions.extend(best_actions)
        return actions

    # Read the input.
    N, T = map(int, input().split())
    A = [list(map(int, input().split())) for _ in range(N)]

    best_sc = -1
    best_ans = []
    while not tk.is_time_over(LIMIT):
        ANS = replay(solve(tk, LIMIT)[:T])
        if ANS is None:
            continue
        sc1 = cal_score_sim(ANS)
        if sc1 > best_sc:
            print(f"  BEST: {sc1}", file=sys.stderr)
            best_sc = sc1
            best_ans = ANS[:]
    print(f"SC: {best_sc}", file=sys.stderr)
    print(*best_ans, sep='\n')


if __name__ == '__main__':
    # Command-line entry point; `--debug` selects the shorter local time budget.
    cli = argparse.ArgumentParser(description='Debug mode')
    cli.add_argument('--debug', action='store_true', help='Enable debug mode')
    main(cli.parse_args().debug)
0