結果

問題 No.5022 XOR Printer
ユーザー ra5anchor
提出日時 2025-08-17 09:26:37
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,822 ms / 2,000 ms
コード長 13,655 bytes
コンパイル時間 435 ms
コンパイル使用メモリ 82,572 KB
実行使用メモリ 99,452 KB
スコア 5,207,899,019
最終ジャッジ日時 2025-08-17 09:28:13
合計ジャッジ時間 95,176 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
純コード判定しない問題か言語
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 50
権限があれば一括ダウンロードができます

ソースコード

diff #

import copy
import random
from time import perf_counter
import argparse
import sys
import math

MAX = 10**8

class TimeKeeper:
    def __init__(self):
        self.start_time = perf_counter()
    def is_time_over(self, LIMIT):
        return (perf_counter() - self.start_time) >= LIMIT
    def time_now(self):
        return (perf_counter() - self.start_time)



###########################################
def main(DEBUG):
    tk = TimeKeeper()
    if DEBUG == True:
        LIMIT = 1.0
    else:
        LIMIT = 1.7

    def cal_score(A):
        N = 10
        score = 0
        for i in range(N):
            for j in range(N):
                score += A[i][j]
        return score

    def cal_score_sim(ANS):
        N = 10
        nowi = 0
        nowj = 0
        s = 0
        B = [[0]*N for _ in range(N)]
        for i in range(N):
            for j in range(N):
                B[i][j] = A[i][j]
        if len(ANS) > 1000:
            return -1
        for c in ANS:
            if c == "L":
                nowj -= 1
            elif c == "R":
                nowj += 1
            elif c == "U":
                nowi -= 1
            elif c == "D":
                nowi += 1
            elif c == "W":
                B[nowi][nowj] ^= s
            elif c == "C":
                s ^= B[nowi][nowj]
            if nowi < 0 or nowi >= N or nowj < 0 or nowj >= N:
                print("outofrange", file=sys.stderr)
                return -1
        score = 0
        for i in range(N):
            for j in range(N):
                score += B[i][j]
        return score


    def replay(ANS):
        N = 10
        nowi = 0
        nowj = 0
        s = 0
        B = [[0]*N for _ in range(N)]
        for i in range(N):
            for j in range(N):
                B[i][j] = A[i][j]
        if len(ANS) > 1000:
            return -1
        for c in ANS:
            if c == "L":
                nowj -= 1
            elif c == "R":
                nowj += 1
            elif c == "U":
                nowi -= 1
            elif c == "D":
                nowi += 1
            elif c == "W":
                B[nowi][nowj] ^= s
            elif c == "C":
                s ^= B[nowi][nowj]
            if nowi < 0 or nowi >= N or nowj < 0 or nowj >= N:
                return -1
        score = 0
        minv = 10**18
        for i in range(N):
            for j in range(N):
                score += B[i][j]
                if B[i][j] < minv:
                    minv = B[i][j]
                    mini, minj = i, j
        maxi, maxj = mini, minj
        # minの近傍で大きな値を探す
        maxv = 0
        for di in range(-3, 4):
        # for di in [-1,0,1]:
            for dj in range(-3, 4):
            # for dj in [-1,0,1]:
                ii = mini + di
                jj = minj + dj
                if ii < 0 or ii >= N or jj < 0 or jj >= N:
                    continue
                if B[ii][jj] > maxv:
                    maxv = B[ii][jj]
                    maxi = ii
                    maxj = jj
        remain = 0
        dist1 = abs(nowi-maxi) + abs(nowj-maxj)
        dist2 = abs(maxi-mini) + abs(maxj-minj)
        while dist1 + dist2 + 3 > remain:
            c = ANS.pop()
            remain += 1
            if c == "L":
                nowj += 1
            elif c == "R":
                nowj -= 1
            elif c == "U":
                nowi += 1
            elif c == "D":
                nowi -= 1
            elif c == "W":
                B[nowi][nowj] ^= s
            elif c == "C":
                s ^= B[nowi][nowj]
            dist1 = abs(nowi-maxi) + abs(nowj-maxj)
            dist2 = abs(maxi-mini) + abs(maxj-minj)

        # 目的地へ向かう
        if s < 2**19:
            res = goto(nowi, nowj, maxi, maxj)
            ANS.extend(res)
            nowi = maxi
            nowj = maxj
            # Copy
            ANS.append("C")
        # 目的地へ向かう
        res = goto(nowi, nowj, mini, minj)
        ANS.extend(res)
        nowi = mini
        nowj = minj
        # Copy
        ANS.append("C")
        # Print
        ANS.append("W")
        # return mini, minj, nowi, nowj, s
        # return score
        return ANS

    def goto(nowi, nowj, toi, toj):
        res = []
        di = toi - nowi
        dj = toj - nowj
        if di > 0:
            for _ in range(di):
                res.append("D")
        else:
            for _ in range(abs(di)):
                res.append("U")
        if dj > 0:
            for _ in range(dj):
                res.append("R")
        else:
            for _ in range(abs(dj)):
                res.append("L")
        return res

    # --- 距離ユーティリティ(マンハッタン距離) ---
    def _md(a, b):
        return abs(a[0]-b[0]) + abs(a[1]-b[1])

    def _path_len(order, start):
        if not order:
            return 0
        d = _md(start, order[0])
        for i in range(len(order)-1):
            d += _md(order[i], order[i+1])
        return d

    # --- 最近近傍で初期解(オープンパス) ---
    def _nearest_neighbor_path(points, start):
        if not points:
            return []
        rem = points[:]
        # start に最も近い点から始めると安定しやすい
        cur_idx = min(range(len(rem)), key=lambda i: _md(start, rem[i]))
        path = [rem.pop(cur_idx)]
        while rem:
            cur = path[-1]
            nxt_idx = min(range(len(rem)), key=lambda i: _md(cur, rem[i]))
            path.append(rem.pop(nxt_idx))
        return path

    # --- 2-opt(オープンパス版) ---
    def _two_opt_path(order, start, max_iter=20000):
        n = len(order)
        if n < 3:
            return order
        def edgelen(p, q):  # None は辺なし
            return 0 if (p is None or q is None) else _md(p, q)

        it = 0
        improved = True
        while improved and it < max_iter:
            improved = False
            it += 1
            # 端の入れ替えは i=0 とか k=n-1 も許す(オープンパスなのでOK)
            for i in range(0, n-1):
                A = start if i == 0 else order[i-1]
                B = order[i]
                for k in range(i+1, n):
                    C = order[k]
                    D = None if k == n-1 else order[k+1]
                    # 2-opt で [i..k] を反転する改善量
                    delta = edgelen(A, C) + edgelen(B, D) - (edgelen(A, B) + edgelen(C, D))
                    if delta < 0:
                        order[i:k+1] = reversed(order[i:k+1])
                        improved = True
                # 早めに次の外側ループへ(小改善でも反映)
                if improved:
                    break
        return order


    def get_order2(w_pos, nowi, nowj):
        """
        現在位置 (nowi, nowj) から、w_pos の全点を少ない手数で訪問する順番を返す。
        1) 最近近傍で初期解
        2) 2-opt(オープンパス版)で改善
        """
        start = (nowi, nowj)
        pts = w_pos[:]  # [(i,j), ...]
        # 初期解(最近近傍)
        path = _nearest_neighbor_path(pts, start)
        # 局所改善(2-opt)
        path = _two_opt_path(path, start, max_iter=10000)
        return path

    def get_order(w_pos, nowi, nowj):
        # greedy
        w = []
        while len(w_pos)>0:
            w_pos.sort(key=lambda x: -abs(x[0]-nowi)-abs(x[1]-nowj))
            (toi, toj) = w_pos.pop()
            w.append((toi, toj))
            nowi = toi
            nowj = toj
        return w            


    def solve(tk, LIMIT):
        X = [[0]*N for _ in range(N)]
        for i in range(N):
            for j in range(N):
                X[i][j] = A[i][j]
        nowi = 0
        nowj = 0
        s = 0
        actions = []
        
        #* 2進数で考える:上の桁から順番に操作数の限り行う
        # for k in range(20): # k:keta
        for k in range(10): # k:keta
            if tk.is_time_over(LIMIT):
                break
            if len(actions) > 1000:
                break
            bestv = 0
            minturn = 10000
            nowi0, nowj0 = nowi, nowj
            s0 = s
            # max_loop = 100
            max_loop = 10
            for loop in range(max_loop):
                X1 = copy.deepcopy(X)
                actions1 = []
                nowi, nowj = nowi0, nowj0
                s = s0

                # sを設定する(色々試して
                xnow = s >> 20-1-k & 1 # 現在のスタンプの k桁目のbit
                kouho = []
                for i in range(N):
                    for j in range(N):
                        # if (X[i][j] >> (20-1-k)) & 1 == 1:
                        #* k桁目のbitが現在のスタンプの逆である かつ k桁目より大きい部分が全て1である マスを候補とする
                        if (X1[i][j] >> 20-1-k+1 == (1 << k) - 1) and (X1[i][j] >> 20-1-k & 1 == 1-xnow):
                            ti, tj = i, j
                            d = abs(nowi-i) + abs(nowj-j)
                            kouho.append((d, ti, tj))
                kouho.sort()
                if len(kouho)==0:
                    print(f"not found kouho {k=} -> break", file=sys.stderr)
                    break
                
                # d, ti, tj = kouho[0]
                #* 候補の中からランダムに選ぶ
                d, ti, tj = kouho[random.randrange(len(kouho))]

                # 目的地へ向かう
                res = goto(nowi, nowj, ti, tj)
                actions1.extend(res)
                nowi = ti
                nowj = tj

                # Copyして完成
                actions1.append("C")
                s ^= X1[nowi][nowj]
                
                # 書き込みする地点列挙
                w_pos = []
                for i in range(N):
                    for j in range(N):
                        if X1[i][j] ^ s > X1[i][j]:
                        # if (X[i][j] >> (20-1-k)) & 1 == 0:
                        # if (X[i][j] >> 20-1-k+1 == (1 << k) - 1) and (X[i][j] >> 20-1-k & 1 == 0):
                            w_pos.append((i, j))
                
                if len(w_pos) == 0:
                    print(f"not found w_pos {k=} -> break", file=sys.stderr)
                    break

                # 最大値の地点は使用せずに残す
                mxv = 0
                for (toi, toj) in w_pos:
                    v = X1[toi][toj]
                    if v > mxv:
                        mxv = v
                        mxi, mxj = toi, toj
                w_pos.remove((mxi, mxj))

                # 巡回する順番決め
                w_pos_ordered = get_order2(w_pos, nowi, nowj)
                # w_pos_ordered = get_order2(w_pos)

                # 一つ残して巡回
                for (toi, toj) in w_pos_ordered:
                    res = goto(nowi, nowj, toi, toj)
                    actions1.extend(res)
                    nowi = toi
                    nowj = toj
                    # Print
                    actions1.append("W")
                    X1[nowi][nowj] ^= s

                # 最後の一つはsに書き込む
                toi, toj = mxi, mxj
                res = goto(nowi, nowj, toi, toj)
                actions1.extend(res)
                nowi = toi
                nowj = toj
                # Copy
                actions1.append("C")
                s ^= X1[nowi][nowj]
            
                # 暫定スコア
                v = 0
                for i in range(N):
                    for j in range(N):
                        v += X1[i][j]
                turn = len(actions1)
                # if turn < minturn:
                if v > bestv:
                    # print(f"best: {v=} {turn=} loop: {loop}", file=sys.stderr)
                    bestv = v
                    minturn = turn
                    best_actions = actions1[:]
                    best_X = copy.deepcopy(X1)
                    best_s = s
                    best_nowi, best_nowj = nowi, nowj
                    
            X = copy.deepcopy(best_X)
            actions.extend(best_actions)
            s = best_s
            nowi, nowj = best_nowi, best_nowj
            
            # print(f"{k=} skip ({nowi} {nowj}) turn {len(actions)}", file=sys.stderr)
                # 
            # print(f"{k=} {s=} {bin(s)=}", file=sys.stderr)
            # print(X, file=sys.stderr)
        return actions


    N, T = map(int, input().split())
    # N=10, T=1000
    A = [list(map(int, input().split())) for _ in range(N)]

    best_sc = 0
    LOOP = 0
    while True:
        LOOP += 1
        if tk.is_time_over(LIMIT):
            break

        ANS = solve(tk, LIMIT)

        sc0 = cal_score_sim(ANS[:T])

        # 最小値のマスを修正
        # T=1000時点での最小値(i,j)および(nowi,nowj)、s値を再計算
        ANS = replay(ANS[:T])

        # print(f"T: {len(ANS)}", file=sys.stderr)
        # ANS = ANS[:T]
        sc1 = cal_score_sim(ANS)

        if sc1 > best_sc:
            print(f"  BEST {LOOP=} {sc1=}", file=sys.stderr)
            best_sc = sc1
            best_ans = ANS


    # print(f"SC: {sc1} <- {sc0}", file=sys.stderr)
    # print(*ANS, sep='\n')
    print(f"SC: {best_sc}", file=sys.stderr)
    print(*best_ans, sep='\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Debug mode')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode')
    args = parser.parse_args()
    main(args.debug)
0