結果

問題 No.5007 Steiner Space Travel
ユーザー tabae326tabae326
提出日時 2023-04-30 13:23:26
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 953 ms / 1,000 ms
コード長 21,666 bytes
コンパイル時間 5,825 ms
コンパイル使用メモリ 285,956 KB
実行使用メモリ 4,376 KB
スコア 8,980,632
最終ジャッジ日時 2023-04-30 13:24:04
合計ジャッジ時間 37,651 ms
ジャッジサーバーID
(参考情報)
judge16 / judge12
純コード判定しない問題か言語
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 952 ms
4,372 KB
testcase_01 AC 952 ms
4,376 KB
testcase_02 AC 952 ms
4,372 KB
testcase_03 AC 953 ms
4,372 KB
testcase_04 AC 953 ms
4,372 KB
testcase_05 AC 953 ms
4,368 KB
testcase_06 AC 952 ms
4,372 KB
testcase_07 AC 953 ms
4,368 KB
testcase_08 AC 953 ms
4,368 KB
testcase_09 AC 952 ms
4,372 KB
testcase_10 AC 952 ms
4,372 KB
testcase_11 AC 953 ms
4,368 KB
testcase_12 AC 953 ms
4,372 KB
testcase_13 AC 952 ms
4,372 KB
testcase_14 AC 953 ms
4,368 KB
testcase_15 AC 952 ms
4,368 KB
testcase_16 AC 952 ms
4,368 KB
testcase_17 AC 952 ms
4,372 KB
testcase_18 AC 952 ms
4,372 KB
testcase_19 AC 952 ms
4,372 KB
testcase_20 AC 952 ms
4,368 KB
testcase_21 AC 953 ms
4,368 KB
testcase_22 AC 952 ms
4,368 KB
testcase_23 AC 953 ms
4,368 KB
testcase_24 AC 953 ms
4,372 KB
testcase_25 AC 953 ms
4,368 KB
testcase_26 AC 952 ms
4,368 KB
testcase_27 AC 952 ms
4,372 KB
testcase_28 AC 953 ms
4,372 KB
testcase_29 AC 953 ms
4,368 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <sys/time.h>
#include <atcoder/all>

using namespace std;
int Loop = 0;  // counter of main()'s outer restart loop; also reported on stderr at exit

#pragma region prototype_declaration
/* ============================================== 
    プロトタイプ宣言はここから
   ============================================== */

/* Random source shared by the solver (global instance `ryuka`).
   Both engines are seeded from std::random_device, so runs are not reproducible. */
struct RandGenerator {
    random_device seed_gen;                // declared first: it seeds the engines below
    mt19937 engine;                        // 32-bit engine (also handed to std::shuffle elsewhere)
    mt19937_64 engine64;                   // 64-bit engine used by randll()
    static const int pshift = 1000000000;  // fixed-point scale used by pjudge()

    RandGenerator() : engine(seed_gen()), engine64(seed_gen()) {}

    /* Returns a value in [0, mod) drawn from the 32-bit engine (modulo bias ignored). */
    int rand(int mod) {
        const auto draw = engine();
        return draw % mod;
    }

    /* Returns a value in [0, mod) drawn from the 64-bit engine (modulo bias ignored). */
    long long randll(long long mod) {
        const auto draw = engine64();
        return draw % mod;
    }

    /* Returns true with probability p (p > 1 behaves like 1), resolved to steps of 1 / pshift. */
    bool pjudge(double p) {
        int threshold;
        if (p > 1) {
            threshold = pshift;
        } else {
            threshold = static_cast<int>(p * pshift);
        }
        return rand(pshift) < threshold;
    }
} ryuka;

/* Stopwatch based on wall-clock time from gettimeofday() (global instance `toki`). */
struct Timer {
    double global_start;  // gettime() value captured by init()

    /* Seconds since the Unix epoch, with microsecond resolution. */
    double gettime() {
        timeval now;
        gettimeofday(&now, nullptr);
        const double fraction = now.tv_usec * 1e-6;
        return now.tv_sec + fraction;
    }

    /* Marks the current instant as the reference point for elapsed(). */
    void init() { global_start = gettime(); }

    /* Seconds elapsed since the last init(). */
    double elapsed() {
        const double now = gettime();
        return now - global_start;
    }
} toki;

/* A route vertex: either a planet (is_planet != 0, id = index into input.planets)
   or a space station (is_planet == 0, id = index into the station list). */
struct Node {
    // Zero-initialised so default-constructed Nodes (e.g. from vector::resize) never hold garbage.
    int x = 0, y = 0, id = 0, is_planet = 0;
    Node() {}
    Node(int, int, int, int);
};

/* Problem input, populated once by read() (global instance `input`). */
struct Input {
    int n;                 // number of planets
    int m;                 // number of stations to place
    const int a = 5;       // cost multiplier applied per planet endpoint of an edge
    vector<Node> planets;  // planet i has id i and is_planet = 1
    void read();
} input;

/* Answer container: station positions and the visiting order, written by print(). */
struct Output {
    vector<Node> stations;  // printed in order; the solver gives station k the id k
    vector<Node> route;     // visiting order; each vertex is a planet or a station
    Output();
    void print();
};

/* A candidate answer together with its cached cost. */
struct State {
    Output output;
    long long length;  // sum of Utils::calcWeightedSquareDist over consecutive route vertices
    long long score;   // score derived from length (see Utils::calcScoreFromLength)
    // Both caches start at 0; previously `length` was left uninitialised.
    State() : length(0), score(0) {}
    static State initState();                              // fresh planets-only tour
    static State generateState(const State& input_state);  // copy with one random 2-opt move
    void changeState(double, int&, int, int);              // annealing 2-opt step
    void changeStateClimb(int&, int, int);                 // hill-climbing 2-opt step
};

/* Free helpers shared by the solver: distance/score computation and the
   tour-construction and station-placement heuristics used by main(). */
namespace Utils {
    // Squared planet-to-planet distances indexed by planet id; filled by initPlanetsDist().
    vector<vector<int>> planetsDist;
    void initPlanetsDist();

    // Node predicates.
    bool isPlanet(const Node& node);
    bool isStart(const Node& node);

    // Distances and scoring.
    int calcSquareDist(const Node& a, const Node& b);
    int calcSquareDistOnlyPlanets(const Node& a, const Node& b);
    int calcWeightedSquareDist(const Node& a, const Node& b);
    pair<long long, long long> calcScore(const Output& output);
    long long calcScoreFromLength(long long length);

    // Planets-only tour construction.
    vector<Node> solveInsertedTSP();
    vector<Node> solveNNTSP();

    // Station placement and route refinement.
    vector<Node> initStationsKMeans();
    vector<Node> initStationsGreedy(const vector<Node>& route);
    vector<Node> goThroughStations(vector<Node> route, const vector<Node>& stations, int iter_max);
    pair<vector<Node>, vector<Node>> optimizeStations(const vector<Node>& route);
}

/* Drives local search over a STATE. STATE must expose `output.route` and the
   move methods changeStateClimb / changeState. Counters accumulate over all calls. */
template<class STATE>
struct IterationControl {
    long long iteration_counter;  // neighbour evaluations, summed over all calls
    int swap_counter;             // accepted moves; passed by reference to the STATE move methods
    double start_time;            // toki.gettime() at the start of the latest anneal()
    IterationControl() : iteration_counter(0), swap_counter(0), start_time(0) {}

    /* Hill climbing: iter_max full sweeps over every edge pair (i, j) with
       i + 2 <= j < route.size() - 1, applying each improving 2-opt move immediately.
       Returns the resulting state. */
    STATE climb(int iter_max, STATE initial_state) {
        // The by-value parameter is already a private copy: move it instead of copying again.
        STATE best_state = std::move(initial_state);
        #ifdef DEBUG
        cerr << "[INFO] - IterationControl::climb - Starts climbing...\n";
        #endif
        const int rsize = best_state.output.route.size();
        for(int it = 0; it < iter_max; it++) {
            for(int i = 0; i < rsize-1; i++) {
                for(int j = i+2; j < rsize-1; j++) {
                    iteration_counter++;
                    best_state.changeStateClimb(swap_counter, i, j);
                }
            }
        }
        #ifdef DEBUG
        cerr << "[INFO] - IterationControl::climb - Iterated " << iteration_counter << " times and swapped " << swap_counter << " times.\n";
        #endif
        return best_state;
    }

    /* Simulated annealing for time_limit seconds with a geometric temperature
       schedule from temp_start down to temp_end. The clock is only checked after a
       full sweep over all (i, j) pairs. Returns the final state, which is not
       necessarily the best one visited. */
    STATE anneal(double time_limit, double temp_start, double temp_end, STATE initial_state) {
        assert(temp_start >= temp_end);
        start_time = toki.gettime();
        STATE best_state = std::move(initial_state);
        double elapsed_time = 0;
        #ifdef DEBUG
        cerr << "[INFO] - IterationControl::anneal - Starts annealing...\n";
        #endif
        const int rsize = best_state.output.route.size();
        while(elapsed_time < time_limit) {
            double normalized_time = elapsed_time / time_limit;
            // Geometric interpolation: temp_start^(1-t) * temp_end^t.
            double temp_current = pow(temp_start, 1.0 - normalized_time) * pow(temp_end, normalized_time);
            for(int i = 0; i < rsize-1; i++) {
                for(int j = 0; j < rsize-1; j++) {
                    iteration_counter++;
                    best_state.changeState(temp_current, swap_counter, i, j);
                }
            }
            elapsed_time = toki.gettime() - start_time;
        }
        #ifdef DEBUG
        cerr << "[INFO] - IterationControl::anneal - Iterated " << iteration_counter << " times and swapped " << swap_counter << " times.\n";
        #endif
        return best_state;
    }
};


/* ============================================== 
    プロトタイプ宣言はここまで
   ============================================== */

#pragma endregion prototype_declaration

/* Member-wise constructor: each argument initialises the field of the same name. */
Node::Node(int x, int y, int id, int is_planet)
    : x(x), y(y), id(id), is_planet(is_planet) {}

/*TODO: ここで入力を受け取る*/
void Input::read() {
    cin >> n >> m;
    planets.resize(n);
    for(int i = 0; i < n; i++) {
        cin >> planets[i].x >> planets[i].y;
        planets[i].id = i;
        planets[i].is_planet = 1;
    }
}

/* Nothing to pre-allocate: the solver fills stations and route later. */
Output::Output() {}

/* Writes the answer to stdout:
   - one "x y" line per station (exactly input.m of them);
   - the route length;
   - one "<type> <1-based id>" line per route vertex (type 1 = planet, 2 = station). */
void Output::print() {
    assert(stations.size() == static_cast<size_t>(input.m));
    for (const auto& e : stations) {
        cout << e.x << ' ' << e.y << '\n';
    }
    cout << route.size() << '\n';
    for (const auto& e : route) {
        cout << (Utils::isPlanet(e) ? 1 : 2) << ' ' << e.id + 1 << '\n';
    }
    cout.flush();  // single flush instead of one per line (std::endl)
}

/* Builds a planets-only tour with the randomized insertion heuristic and
   caches its score and weighted length. */
State State::initState() {
    State fresh;
    fresh.output.route = Utils::solveInsertedTSP();
    // Alternative constructor: Utils::solveNNTSP();
    const auto scored = Utils::calcScore(fresh.output);
    fresh.score = scored.first;
    fresh.length = scored.second;
    return fresh;
}

void State::changeState(double temp_current, int &swap_counter, int i, int j) {
    if(i > j) swap(i, j);
    int i2 = i+1;
    int j2 = j+1;
    bool chk = (i != j) && (i != j2) && (j != i2); 
    if(!chk) return;
    long long org_score = score;
    long long new_len = length;
    new_len -= Utils::calcWeightedSquareDist(output.route[i], output.route[i2]);
    new_len -= Utils::calcWeightedSquareDist(output.route[j], output.route[j2]);
    new_len += Utils::calcWeightedSquareDist(output.route[i], output.route[j]);
    new_len += Utils::calcWeightedSquareDist(output.route[j2],output.route[i2]);
    long long new_score = Utils::calcScoreFromLength(new_len);
    long long delta = new_score - org_score;
    if(delta > 0 || ryuka.pjudge(exp(1.0 * delta / temp_current)) ) {
        swap_counter++;
        reverse(output.route.begin() + i2, output.route.begin() + j2);
        length = new_len;
        score = Utils::calcScoreFromLength(length);        
    }
}

void State::changeStateClimb(int &swap_counter, int i, int j) {
    const int i2 = i+1;
    const int j2 = j+1;
    const long long org_score = score;
    long long new_len = length;
    new_len -= Utils::calcWeightedSquareDist(output.route[i], output.route[i2]);
    new_len -= Utils::calcWeightedSquareDist(output.route[j], output.route[j2]);
    new_len += Utils::calcWeightedSquareDist(output.route[i], output.route[j]);
    new_len += Utils::calcWeightedSquareDist(output.route[j2],output.route[i2]);
    //const long long new_score = Utils::calcScoreFromLength(new_len);
    //const long long delta = new_score - org_score;
    const long long delta = new_len - length;
    if(delta < 0) {
        swap_counter++;
        reverse(output.route.begin() + i2, output.route.begin() + j2);
        length = new_len;
    }
}


/* Returns a copy of input_state with one random 2-opt move applied unconditionally.
   If the two sampled edges coincide or are adjacent, the copy is returned unchanged.
   The copy's length and score are updated to match. */
State State::generateState(const State& input_state) {
    State res = input_state;
    vector<Node>& path = res.output.route;
    const int first = ryuka.rand(path.size() - 1);
    const int second = ryuka.rand(path.size() - 1);
    const int lo = min(first, second);
    const int hi = max(first, second);
    if (hi - lo < 2) return res;

    const long long dropped =
        static_cast<long long>(Utils::calcWeightedSquareDist(path[lo], path[lo + 1]))
        + Utils::calcWeightedSquareDist(path[hi], path[hi + 1]);
    const long long added =
        static_cast<long long>(Utils::calcWeightedSquareDist(path[lo], path[hi]))
        + Utils::calcWeightedSquareDist(path[hi + 1], path[lo + 1]);
    res.length = res.length - dropped + added;
    reverse(path.begin() + lo + 1, path.begin() + hi + 1);
    res.score = Utils::calcScoreFromLength(res.length);
    return res;
}


/* Squared Euclidean distance between two nodes (unweighted). */
int Utils::calcSquareDist(const Node& a, const Node& b) {
    const int diff_x = b.x - a.x;
    const int diff_y = b.y - a.y;
    return diff_x * diff_x + diff_y * diff_y;
}

/* Cached squared distance between two planets. Both nodes must be planets, because
   their ids index Utils::planetsDist (filled by initPlanetsDist). */
int Utils::calcSquareDistOnlyPlanets(const Node& a, const Node& b) {
    const vector<int>& row = Utils::planetsDist[a.id];
    return row[b.id];
}

/* Edge cost: squared distance times input.a for each planet endpoint
   (a^2 planet-planet, a planet-station, 1 station-station). */
int Utils::calcWeightedSquareDist(const Node& a, const Node& b) {
    const int ddx = b.x - a.x;
    const int ddy = b.y - a.y;
    const int sq = ddx * ddx + ddy * ddy;
    const bool a_planet = Utils::isPlanet(a);
    const bool b_planet = Utils::isPlanet(b);
    if (a_planet != b_planet) return sq * input.a;
    return a_planet ? sq * input.a * input.a : sq;
}

/* True for planet 0. */
bool Utils::isStart(const Node& node) {
    return node.id == 0 && isPlanet(node);
}

/* True when the node is a planet, false for a station. */
bool Utils::isPlanet(const Node& node) {
    return node.is_planet != 0;
}

/*TODO: ここでスコアを計算する*/
pair<long long, long long> Utils::calcScore(const Output& output) {
    long long sum = 0;
    const auto& route = output.route;
    for(int i = 0; i < route.size() - 1; i++) {
        const Node& cur = route[i];
        const Node& nxt = route[i+1];
        const long long d2 = Utils::calcSquareDist(cur, nxt);
        if(isPlanet(cur) && isPlanet(nxt)) {
            sum += input.a * input.a * d2; 
        } else if(!isPlanet(cur) && !isPlanet(nxt)) {
            sum += d2;
        } else {
            sum += input.a * d2;
        }
    }
    long long res = (long long)(1e9 / (1e3 + sqrt(sum)));
    return {res, sum};
}

/* Problem score for a weighted length: truncation of 1e9 / (1000 + sqrt(length)). */
long long Utils::calcScoreFromLength(long long length) {
    const double denominator = 1e3 + sqrt(length);
    return static_cast<long long>(1e9 / denominator);
}

/* Fills the symmetric n x n table Utils::planetsDist with squared planet distances. */
void Utils::initPlanetsDist() {
    const int n = input.n;
    Utils::planetsDist.resize(n, vector<int>(n));
    for (int r = 0; r < n; ++r) {
        for (int c = r; c < n; ++c) {
            const int d = Utils::calcSquareDist(input.planets[r], input.planets[c]);
            Utils::planetsDist[r][c] = d;
            Utils::planetsDist[c][r] = d;
        }
    }
}

/* Nearest-neighbour tour over all planets, starting and ending at planet 0.
   Ties go to the lowest planet index. */
vector<Node> Utils::solveNNTSP() {
    const int n = input.n;
    vector<Node> route;
    route.reserve(n + 1);
    vector<bool> visited(n, false);
    int current = 0;
    visited[current] = true;
    route.emplace_back(input.planets.front());
    for (int step = 1; step < n; ++step) {
        int best_id = -1;
        int best_dist = 1 << 30;
        for (int cand = 0; cand < n; ++cand) {
            if (visited[cand]) continue;
            const int d = Utils::calcSquareDistOnlyPlanets(input.planets[current], input.planets[cand]);
            if (d < best_dist) {
                best_dist = d;
                best_id = cand;
            }
        }
        visited[best_id] = true;
        current = best_id;
        route.emplace_back(input.planets[best_id]);
    }
    route.emplace_back(input.planets.front());
    return route;
}

vector<Node> Utils::solveInsertedTSP() {
    vector<Node> route;
    route.reserve(input.n+1);
    vector<Node> nodes(input.n-1);
    for(int i = 0; i < input.n-1; i++) nodes[i] = input.planets[i+1]; 
    route.emplace_back(input.planets.front());
    route.emplace_back(input.planets.front());
    shuffle(nodes.begin(), nodes.end(), ryuka.engine);
    for(auto e: nodes) {
        int min_dist = 1<<30, min_id = -1;
        for(int i = 0; i < route.size() - 1; i++) {
            int dist = Utils::calcSquareDistOnlyPlanets(e, route[i]) + Utils::calcSquareDistOnlyPlanets(e, route[i+1]);
            if(min_dist > dist) {
                min_dist = dist;
                min_id = i + 1;
            }
        }
        assert(min_id != -1);
        route.insert(route.begin() + min_id, e);
    }
    return route;
}

vector<Node> Utils::initStationsGreedy(const vector<Node>& route) {
    vector<Node> res;
    vector<bool> seen(route.size(), false);
    vector<int> unchecked_path(route.size()-1);
    iota(unchecked_path.begin(), unchecked_path.end(), 0);
    for(int v = 0; v < input.m; v++) {
        long long max_dist = 0, max_x = -1, max_y = -1;
        for(int x = 25; x <= 975; x += 50) {
            for(int y = 25; y <= 975; y += 50) {
                long long dist = 0;
                for(int i : unchecked_path) {
                    int cur_dist = Utils::calcSquareDistOnlyPlanets(route[i], route[i+1]) * input.a;
                    int tmp_dist = Utils::calcSquareDist(route[i], Node(x, y, v, false)) + Utils::calcSquareDist(route[i+1], Node(x, y, v, false));
                    if(cur_dist > tmp_dist) {
                        dist += cur_dist - tmp_dist;
                    }
                }       
                if(dist > max_dist) {
                    max_dist = dist;
                    max_x = x;
                    max_y = y;
                }
            }
        }
        assert(max_x != -1);
        for(int i = 0; i < route.size() - 1; i++) {
            int cur_dist = Utils::calcSquareDistOnlyPlanets(route[i], route[i+1]) * input.a;
            int tmp_dist = Utils::calcSquareDist(route[i], Node(max_x, max_y, v, false)) + Utils::calcSquareDist(route[i+1], Node(max_x, max_y, v, false));        
            if(cur_dist > tmp_dist) {
                seen[i] = true;
            }
        }
        unchecked_path.clear();
        for(int i = 0; i < route.size()-1; i++) if(!seen[i]) unchecked_path.push_back(i);
        res.push_back(Node(max_x, max_y, v, false));
    }
    return res;
}

vector<Node> Utils::initStationsKMeans() {
    vector<int> cluster(input.n, 0);
    for(int i = 0; i < input.n; i++) {
        cluster[i] = ryuka.rand(input.m);
    }
    vector<int> prev;
    int count = 0;
    const int iter_max = 100;
    for(int it = 0; it < iter_max; it++) {
        vector<long long> w_x(input.m, 0);
        vector<long long> w_y(input.m, 0);
        vector<int> num(input.m, 0);
        for(int i = 0; i < input.n; i++) {
            w_x[cluster[i]] += input.planets[i].x;
            w_y[cluster[i]] += input.planets[i].y;
            num[cluster[i]]++;
        }
        for(int i = 0; i < input.m; i++) {
            if(num[i] > 0) {
                w_x[i] /= num[i];
                w_y[i] /= num[i];
            }
        }
        for(int i = 0; i < input.n; i++) {
            int min_dist = 1<<30, min_id = -1;
            for(int j = 0; j < input.m; j++) {
                int dist = Utils::calcSquareDist(input.planets[i], Node(w_x[j], w_y[j], -1, 0));
                if(dist < min_dist) {
                    min_dist = dist;
                    min_id = j;
                }
            }
            assert(min_id != -1);
            cluster[i] = min_id;
        }
        count++;
    }
    vector<Node> res(input.m);
    vector<long long> w_x(input.m, 0);
    vector<long long> w_y(input.m, 0);
    vector<int> num(input.m, 0);
    for(int i = 0; i < input.n; i++) {
        w_x[cluster[i]] += input.planets[i].x;
        w_y[cluster[i]] += input.planets[i].y;
        num[cluster[i]]++;
    }
    int zid = 0;
    for(int i = 0; i < input.m; i++) {
        if(num[i] > 0) {
            w_x[i] /= num[i];
            w_y[i] /= num[i];
        } else {
            w_x[i] = input.planets[zid].x;
            w_y[i] = input.planets[zid].y;
            zid++;
            /*
            int z = ryuka.rand(input.n);
            w_x[i] = clamp(input.planets[z].x + ryuka.rand(100) - 50, 1, 999);
            w_y[i] = clamp(input.planets[z].y + ryuka.rand(100) - 50, 1, 999);
            */
        }
    }
    for(int i = 0; i < input.m; i++) {
        res[i] = Node(w_x[i], w_y[i], i, 0);
    }   
    return res;
}

vector<Node> Utils::goThroughStations(vector<Node> route, const vector<Node>& stations, int iter_max) {
    vector<Node> nodes;
    nodes.reserve(input.n + input.m);
    for(auto e: input.planets) nodes.emplace_back(e);
    for(auto e: stations) nodes.emplace_back(e);
    for(int it = 0; it < iter_max; it++) {
        vector<Node> res;
        for(int i = 0; i < route.size() - 1; i++) {
            const Node& cur = route[i];
            const Node& nxt = route[i+1];
            res.emplace_back(cur);
            int min_dist = Utils::calcWeightedSquareDist(cur, nxt);
            int min_id = -1;
            for(int j = 0; j < nodes.size(); j++) {
                int dist = Utils::calcWeightedSquareDist(cur, nodes[j])
                            + Utils::calcWeightedSquareDist(nodes[j], nxt);
                if(min_dist > dist) {
                    min_dist = dist;
                    min_id = j;
                }
            }
            if(min_id >= 0) {
                res.emplace_back(nodes[min_id]);
            }
        }
        res.emplace_back(route.back());
        route = res;
    }
    return route;
}

pair<vector<Node>, vector<Node>> Utils::optimizeStations(const vector<Node>& route) {
    vector<int> w_x(input.m, 0);
    vector<int> w_y(input.m, 0);
    vector<int> num(input.m, 0);
    for(int i = 0; i < route.size(); i++) {
        if(Utils::isPlanet(route[i])) {
            if(i - 1 >= 0 && !Utils::isPlanet(route[i-1])) {
                w_x[route[i-1].id] += route[i].x;
                w_y[route[i-1].id] += route[i].y;
                num[route[i-1].id]++;
            } 
            if(i + 1 < route.size() && !Utils::isPlanet(route[i+1])){
                w_x[route[i+1].id] += route[i].x;
                w_y[route[i+1].id] += route[i].y;
                num[route[i+1].id]++;
            }
        }
    }
    vector<Node> stations;
    stations.reserve(input.m);
    for(int i = 0; i < input.m; i++) {
        if(num[i] > 0) {
            w_x[i] /= num[i];
            w_y[i] /= num[i];
        }
        stations.emplace_back(Node(w_x[i], w_y[i], i, 0));
    }
    vector<Node> ret_route = route;
    for(int i = 0; i < ret_route.size(); i++) {
        if(!Utils::isPlanet(ret_route[i])) {
            ret_route[i].x = stations[ret_route[i].id].x;
            ret_route[i].y = stations[ret_route[i].id].y;
        }
    }
    return {ret_route, stations};
} 

int main(int argc, char* argv[]) {
    toki.init();
    input.read();   
    Utils::initPlanetsDist();
    long long best_score = 0, best_pre_score = 0;
    State best, best_pre;
    IterationControl<State> sera;
    for(Loop = 0; Loop < 10000000; Loop++) {
        if(toki.elapsed() > 0.8) break;
        //State ans = sera.anneal(0.01, 1e5, 1, State::initState());
        //State ans = sera.climb(0.0005, State::initState());
        //State ans = State::initState();
        State ans  = sera.climb(6, State::initState());
        ans.output.stations = Utils::initStationsGreedy(ans.output.route);
        ans.output.route = Utils::goThroughStations(ans.output.route, ans.output.stations, 2);
        auto [_route, _stations] = Utils::optimizeStations(ans.output.route);
        ans.output.route = move(_route);
        ans.output.stations = move(_stations);
        ans.score = Utils::calcScore(ans.output).first;
        if(ans.score > best_pre_score) {
            best_pre_score = ans.score;
        }
        //ans = sera.anneal(0.01, 1e5, 1, ans);
        //ans = sera.climb(0.0005, ans);
        auto f = [&](State& ans) -> bool {
            ans = sera.climb(3, ans);
            ans.output.route = Utils::goThroughStations(ans.output.route, ans.output.stations, 2);
            auto [route, stations] = Utils::optimizeStations(ans.output.route);
            ans.output.route = move(route);
            ans.output.stations = move(stations);
            ans.score = Utils::calcScore(ans.output).first;
            if(ans.score > best_score) {
                best_score = ans.score;
                best = ans;
                return true;
            }
            return false;
        };
        int prev_ret = f(ans);
        if(prev_ret && Loop > 50) {
            for(int chal = 0; chal < 5; chal++) {
                f(best);
            }
        }
    }
    cerr << "[DEBUG] - main - best_score = " << best_score << "\n";
    while(toki.elapsed() < 0.95) {
        State ans = sera.climb(3, best);
        ans.output.route = Utils::goThroughStations(ans.output.route, ans.output.stations, 2);
        auto [route, stations] = Utils::optimizeStations(ans.output.route);
        ans.output.route = move(route);
        ans.output.stations = move(stations);
        ans.score = Utils::calcScore(ans.output).first;
        if(ans.score > best_score) {
            best_score = ans.score;
            best = ans;
        }
    }
    best.output.print();
    cerr << "[INFO] - main - Loop = " << Loop << "\n";
    cerr << "[INFO] - main - MyScore = " << best.score << "\n";
}
0