結果

問題 No.1216 灯籠流し/Lanterns
ユーザー yosupot
提出日時 2020-08-30 15:21:16
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,337 ms / 4,500 ms
コード長 17,128 bytes
コンパイル時間 2,600 ms
コンパイル使用メモリ 136,412 KB
最終ジャッジ日時 2025-01-13 22:37:58
ジャッジサーバーID
(参考情報)
judge5 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 48
権限があれば一括ダウンロードができます

ソースコード

diff #

//#pragma GCC optimize("Ofast")
//#pragma GCC target("avx")
//#undef LOCAL




#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <cctype>
#include <complex>
#include <cstdio>
#include <cstring>
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;

// Short aliases used throughout the file.
using uint = unsigned int;
using ll = long long;
using ull = unsigned long long;

/// 10 raised to the n-th power, usable in constant expressions (n >= 0).
constexpr ll TEN(int n) {
    if (n == 0) return 1;
    return 10 * TEN(n - 1);
}

template <class T> using V = vector<T>;
template <class T> using VV = vector<vector<T>>;



#include <unistd.h>

/// Fast whitespace-separated token reader working directly on a file
/// descriptor.
///
/// Bytes are pulled with POSIX read() into a 32 KiB buffer. The unread window
/// is line[st, ed), and line[ed] is always '\0'; that terminator is what stops
/// the integer parser at end of data.
struct Scanner {
    int fd = -1;
    char line[(1 << 15) + 1];
    size_t st = 0, ed = 0;

    /// <cctype> requires an argument representable as unsigned char, so a
    /// plain (possibly signed) char is converted before classification.
    static bool is_space(char c) {
        return std::isspace(static_cast<unsigned char>(c)) != 0;
    }
    static bool is_digit(char c) {
        return std::isdigit(static_cast<unsigned char>(c)) != 0;
    }

    /// Moves the unread bytes to the front and appends as much new input as
    /// fits. A read() error (negative return) is treated like end of input
    /// instead of being added to `ed`, which used to corrupt the buffer.
    void reread() {
        std::memmove(line, line + st, ed - st);
        ed -= st;
        st = 0;
        const ssize_t got = ::read(fd, line + ed, (1 << 15) - ed);
        if (got > 0) ed += static_cast<size_t>(got);
        line[ed] = '\0';
    }
    /// Skips whitespace up to the next token; returns false at end of input.
    /// If at most 50 bytes remain and none of them is whitespace, the token
    /// may be cut by the buffer end, so more input is pulled in first.
    bool succ() {
        while (true) {
            if (st == ed) {
                reread();
                if (st == ed) return false;
            }
            while (st != ed && is_space(line[st])) st++;
            if (st != ed) break;
        }
        if (ed - st <= 50) {
            bool sep = false;
            for (size_t i = st; i < ed; i++) {
                if (is_space(line[i])) {
                    sep = true;
                    break;
                }
            }
            if (!sep) reread();
        }
        return true;
    }
    /// Reads one whitespace-delimited word and appends it to ref.
    template <class T, std::enable_if_t<std::is_same<T, std::string>::value, int> = 0>
    bool read_single(T& ref) {
        if (!succ()) return false;
        while (true) {
            size_t sz = 0;
            while (st + sz < ed && !is_space(line[st + sz])) sz++;
            ref.append(line + st, sz);
            st += sz;
            if (!sz || st != ed) break;
            reread();
        }
        return true;
    }
    /// Reads a decimal integer with an optional leading '-'.
    template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
    bool read_single(T& ref) {
        if (!succ()) return false;
        bool neg = false;
        if (line[st] == '-') {
            neg = true;
            st++;
        }
        ref = T(0);
        while (is_digit(line[st])) {
            ref = 10 * ref + (line[st++] & 0xf);
        }
        if (neg) ref = -ref;
        return true;
    }
    /// Fills an already-sized vector element by element.
    template <class T> bool read_single(std::vector<T>& ref) {
        for (auto& d : ref) {
            if (!read_single(d)) return false;
        }
        return true;
    }
    void read() {}
    /// Reads every argument in order; asserts that input did not run out.
    template <class H, class... T> void read(H& h, T&... t) {
        bool f = read_single(h);
        assert(f);
        (void)f;
        read(t...);
    }
    int read_unsafe() { return 0; }
    /// Reads arguments in order and returns how many were read successfully.
    template <class H, class... T> int read_unsafe(H& h, T&... t) {
        bool f = read_single(h);
        if (!f) return 0;
        return 1 + read_unsafe(t...);
    }
    explicit Scanner(FILE* fp) : fd(fileno(fp)) {}
};

/// Buffered writer for space-separated tokens on top of a C FILE*.
///
/// Output accumulates in a 32 KiB buffer that is handed to fwrite() when it
/// fills up and when the Printer is destroyed.
struct Printer {
  public:
    template <bool F = false> void write() {}
    /// Writes the arguments separated by single spaces.
    template <bool F = false, class H, class... T>
    void write(const H& h, const T&... t) {
        if (F) write_single(' ');
        write_single(h);
        write<true>(t...);
    }
    /// Same as write(), followed by '\n'.
    template <class... T> void writeln(const T&... t) {
        write(t...);
        write_single('\n');
    }

    explicit Printer(FILE* _fp) : fp(_fp) {}
    ~Printer() { flush(); }
    // A copy would flush the same pending bytes a second time.
    Printer(const Printer&) = delete;
    Printer& operator=(const Printer&) = delete;

  private:
    static constexpr size_t SIZE = 1 << 15;
    FILE* fp;
    char line[SIZE], small[50];
    size_t pos = 0;
    void flush() {
        fwrite(line, 1, pos, fp);
        pos = 0;
    }
    void write_single(const char& val) {
        if (pos == SIZE) flush();
        line[pos++] = val;
    }
    /// Appends the decimal digits of a non-negative magnitude (at most 39
    /// digits, so 50 bytes of headroom always suffice).
    template <class U> void write_digits(U mag) {
        if (pos > SIZE - 50) flush();
        size_t len = 0;
        do {
            small[len++] = char('0' + int(mag % 10));
            mag /= 10;
        } while (mag);
        for (size_t i = 0; i < len; i++) {
            line[pos + i] = small[len - 1 - i];
        }
        pos += len;
    }
    /// Integers are negated in an unsigned type: -val overflows (UB) for the
    /// most negative value, while unsigned 0 - x wraps to the right magnitude.
    template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
    void write_single(T val) {
        using U = std::conditional_t<(sizeof(T) > sizeof(unsigned long long)),
                                     unsigned __int128, unsigned long long>;
        U mag = static_cast<U>(val);
        if (val < 0) {
            write_single('-');
            mag = U(0) - mag;
        }
        write_digits(mag);
    }
    void write_single(__int128 val) {
        unsigned __int128 mag = static_cast<unsigned __int128>(val);
        if (val < 0) {
            write_single('-');
            mag = 0 - mag;
        }
        write_digits(mag);
    }

    void write_single(const std::string& s) {
        for (char c : s) write_single(c);
    }
    void write_single(const char* s) {
        size_t len = strlen(s);
        for (size_t i = 0; i < len; i++) write_single(s[i]);
    }
    /// Writes the elements separated by single spaces.
    template <class T> void write_single(const std::vector<T>& val) {
        auto n = val.size();
        for (size_t i = 0; i < n; i++) {
            if (i) write_single(' ');
            write_single(val[i]);
        }
    }
};
/**
 * Multiset implemented as an AA tree: a binary search tree kept balanced
 * through per-node levels and the skew / split rotations.
 *
 * D is the element type and C the comparison functor (std::less<D> by
 * default). Equal elements may appear more than once.
 *
 * The value-based operations (insert, erase, lb, ub) are built on the
 * position-based ones (insert / remove at index k, at, split, merge). Every
 * tree with the same <D, C> shares the sentinel node `last` (level 0, size 0),
 * which stands in for an empty subtree.
 *
 * NOTE(review): nodes come from `new` and are never freed.
 */
template<class D, class C = std::less<D>>
struct AAMSet {
    struct Node;
    using NP = Node*;
    static Node last_d;
    static NP last;
    struct Node {
        NP l, r;
        int level, sz;
        D v;
        /// Sentinel node: no children, level 0, size 0.
        Node(): l(nullptr), r(nullptr), level(0), sz(0) {}
        /// Leaf holding vv.
        Node(D vv): l(last), r(last), level(1), sz(1) {
            v = vv;
        }
        /// Allocation goes through these wrappers so a memory pool can be plugged in.
        static NP make() {
            return new Node();
        }
        static NP make(D vv) {
            return new Node(vv);
        }

        /// Recomputes the subtree size from the children.
        inline void update() {
            sz = 1+l->sz+r->sz;
        }

        /// Lazy-propagation hook; a plain multiset has nothing to push.
        inline void push() {
        }
    } *n;

    static D at(NP n, int k) {
        if (k == n->l->sz) return n->v;
        n->push();
        if (k < n->l->sz) {
            return at(n->l, k);
        } else {
            return at(n->r, k - (n->l->sz+1));
        }
    }
    /// Returns the k-th smallest element (0-indexed); k must be in [0, size()).
    D at(int k) {
        return at(n, k);
    }
    static int lb(NP n, D x) {
        if (n == last) return 0;
        if (C()(n->v, x)) return n->l->sz + 1 + lb(n->r, x);
        return lb(n->l, x);
    }
    /// lower_bound as an index: the number of elements that compare less than v.
    int lb(D v) {
        return lb(n, v);
    }
    static int ub(NP n, D x) {
        if (n == last) return 0;
        if (C()(x, n->v)) return ub(n->l, x);
        return n->l->sz + 1 + ub(n->r, x);
    }
    /// upper_bound as an index: the number of elements that do not compare greater than v.
    int ub(D v) {
        return ub(n, v);
    }
    /// Value-based insertion into subtree n (the public insert(D) uses the
    /// positional insert instead).
    static NP insert(NP n, D x) {
        if (n == last) {
            return Node::make(x);
        }
        n->push();
        if (!C()(n->v, x)) {
            n->l = insert(n->l, x);
            n->update();
        } else {
            n->r = insert(n->r, x);
            n->update();
        }
        n = skew(n);
        n = pull(n);
        return n;
    }
    /// Inserts x in front of any elements equal to it.
    void insert(D x) {
        n = insert(n, lb(x), x);
    }
    /// Value-based removal of one element equal to x from subtree n.
    /// x must be present (the assert fires once the search reaches the sentinel).
    static NP erase(NP n, D x) {
        assert(n != last);
        n->push();
        if (!C()(n->v, x) && !C()(x, n->v)) {
            if (n->level == 1) {
                return n->r;
            }
            // Replace n by its in-order successor (the minimum of the right subtree).
            auto succ = at0_with_remove(n->r);
            NP nn = succ.first;
            nn->push();
            nn->l = n->l;
            nn->r = succ.second;
            nn->level = n->level;
            nn->update();
            return rightdown(nn);
        }
        if (C()(x, n->v)) {
            n->l = erase(n->l, x);
            n->update();
            return leftdown(n);
        } else {
            n->r = erase(n->r, x);
            n->update();
            return rightdown(n);
        }
    }
    /// Removes one element equal to x; does nothing if x is absent.
    void erase(D x) {
        int k = lb(x);
        // at(k) is the first element not less than x. It equals x only when x
        // is not less than it either. Without this check an absent x removed
        // its successor, or tripped the assert in remove() when x was greater
        // than every element.
        if (k < sz() && !C()(x, at(k))) {
            n = remove(n, k);
        }
    }

    /// Debug helper: prints the elements in order, space-separated, to std::cout.
    static void tp(NP n) {
        if (n == last) return;
        n->push();
        tp(n->l);
        std::cout << n->v << " ";
        tp(n->r);
    }
    void tp() {
        tp(n);
        std::printf("\n");
    }
    static void allpush(NP n) {
        if (n == last) return;
        n->push();
        allpush(n->l);
        allpush(n->r);
    }
    void allpush() {
        allpush(n);
        return;
    }

    /// Builds a tree whose in-order sequence is d[0, sz), using the middle
    /// element of each range as its root (d should already be sorted for the
    /// value-based operations to be meaningful).
    static NP built(int sz, D d[]) {
        if (!sz) return last;
        int md = (sz-1)/2;
        NP n = Node::make(d[md]);
        n->l = built(md, d);
        n->r = built(sz-(md+1), d+(md+1));
        n->level = n->l->level+1;
        n->update();
        return n;
    }
    AAMSet() : n(last) {}
    AAMSet(NP n) : n(n) {}
    // Building the whole tree at once is usually noticeably faster than sz inserts.
    AAMSet(int sz, D d[]) {
        n = built(sz, d);
    }


    // Basic operations
    int sz() {
        return n->sz;
    }
    int size() {
        return sz();
    }
    /// Appends every element of r after the elements of this tree.
    void merge(AAMSet r) {
        n = merge(n, r.n);
    }
    /// Keeps the first k elements here and returns the rest as a new tree.
    AAMSet split(int k) {
        auto y = split(n, k);
        n = y.first;
        return AAMSet(y.second);
    }
    /// Inserts x so that it becomes the element at index k.
    void insert(int k, D x) {
        n = insert(n, k, x);
    }
    /// Removes the element at index k.
    void remove(int k) {
        n = remove(n, k);
    }

    // skew / split are the basic AA-tree rebalancing operations; split is
    // called pull here so it does not clash with split(k) above.
    static NP skew(NP n) {
        if (n->level == n->l->level) {
            NP L = n->l;
            n->push(); L->push();
            n->l = L->r;
            L->r = n;
            n->update(); L->update();
            return L;
        }
        return n;
    }
    static NP pull(NP n) {
        if (n->level == n->r->level && n->r->level == n->r->r->level) {
            NP R = n->r;
            n->push(); R->push();
            n->r = R->l;
            R->l = n;
            R->level++;
            n->update(); R->update();
            return R;
        }
        return n;
    }

    /// Rebalances n after its left subtree lost a level.
    static NP leftdown(NP n) {
        assert(n->l->level < n->level);
        if (n->l->level == n->level-1) return n;
        n->level--;
        if (n->r->level == n->level) {
            n = pull(n);
        } else {
            n->r->level--;
            n->r = skew(n->r);
            n->r->r = skew(n->r->r);
            n = pull(n);
            n->r = pull(n->r);
        }
        return n;
    }
    /// Rebalances n after its right subtree lost a level.
    static NP rightdown(NP n) {
        assert(n->r->level <= n->level);
        if (n->r->level >= n->level-1) return n;
        n->level--;
        n = skew(n);
        n->r = skew(n->r);
        n = pull(n);
        return n;
    }
    /// Like leftdown, but the left subtree may have dropped by several levels.
    static NP superleftdown(NP n) {
        if (n->l->level == n->level-1) return n;
        if (n->level != n->r->level && n->r->level != n->r->r->level) {
            n->level--;
            return superleftdown(n);
        }
        n = leftdown(n);
        n->l = superleftdown(n->l);
        n = leftdown(n);
        return n;
    }
    /// Like rightdown, but the right subtree may have dropped by several levels.
    static NP superrightdown(NP n) {
        if (n->r->level >= n->level-1) return n;
        n = rightdown(n);
        n->r = superrightdown(n->r);
        n = rightdown(n);
        return n;
    }

    static NP insert(NP n, int k, D x) {
        if (n == last) {
            assert(k == 0);
            return Node::make(x);
        }
        n->push();
        if (k <= n->l->sz) {
            n->l = insert(n->l, k, x);
            n->update();
        } else {
            n->r = insert(n->r, k - (n->l->sz+1), x);
            n->update();
        }
        n = skew(n);
        n = pull(n);
        return n;
    }
    /// Returns {node holding the 0-th element, tree with that node removed}.
    static std::pair<NP, NP> at0_with_remove(NP n) {
        n->push();
        if (n->l == last) {
            return {n, n->r};
        }
        auto x = at0_with_remove(n->l);
        n->l = x.second;
        n->update();
        x.second = leftdown(n);
        return x;
    }
    static NP remove(NP n, int k) {
        assert(n != last);
        n->push();
        if (k == n->l->sz) {
            if (n->level == 1) {
                return n->r;
            }
            auto x = at0_with_remove(n->r);
            NP nn = x.first;
            nn->push();
            nn->l = n->l;
            nn->r = x.second;
            nn->level = n->level;
            nn->update();
            return rightdown(nn);
        }
        if (k < n->l->sz) {
            n->l = remove(n->l, k);
            n->update();
            return leftdown(n);
        } else {
            n->r = remove(n->r, k - (n->l->sz+1));
            n->update();
            return rightdown(n);
        }
    }
    static NP merge(NP l, NP r) {
        if (l == last) return r;
        if (r == last) return l;
        if (l->level == r->level) {
            auto x = at0_with_remove(r);
            NP n = x.first;
            n->push();
            n->r = x.second;
            n->l = l;
            n->level = l->level+1;
            n->update();
            return rightdown(n);
        }
        NP n;
        l->push(); r->push();
        if (l->level > r->level) {
            l->push();
            l->r = merge(l->r, r);
            l->update();
            n = l;
        } else {
            r->push();
            r->l = merge(l, r->l);
            r->update();
            n = r;
        }
        n = skew(n);
        n = pull(n);
        return n;
    }
    static std::pair<NP, NP> split(NP n, int k) {
        if (n == last) return {last, last};
        n->push();
        if (k <= n->l->sz) {
            auto y = split(n->l, k);
            n->l = y.second;
            n->update();
            n = superleftdown(n);
            y.second = n;
            return y;
        } else {
            auto y = split(n->r, k- (n->l->sz+1));
            n->r = y.first;
            n->update();
            n = superrightdown(n);
            y.first = n;
            return y;
        }
    }
};
template<class D, class C>
typename AAMSet<D, C>::Node AAMSet<D, C>::last_d = AAMSet<D, C>::Node();
template<class D, class C>
typename AAMSet<D, C>::NP AAMSet<D, C>::last = &AAMSet<D, C>::last_d;
// Process-wide fast reader/writer bound to the standard streams.
Scanner sc(stdin);
Printer pr(stdout);

/// Segment tree over positions [0, sz) (main() indexes it by DFS pre-order
/// number). Each tree node keeps two multisets holding every value added
/// anywhere in its range: st0 for flag f == 0 and st1 for any other flag.
struct Node {
    using NP = Node*;
    NP l = nullptr, r = nullptr;  // owned children; both null for a leaf
    int sz = 0;                   // number of positions covered
    /// Builds the tree for `_sz` positions. Ranges of size <= 1 are leaves;
    /// with the old `sz == 1` test a size-0 range recursed forever.
    explicit Node(int _sz) : sz(_sz) {
        if (sz <= 1) return;
        l = new Node(sz / 2);
        r = new Node(sz - sz / 2);
    }
    ~Node() {
        delete l;
        delete r;
    }
    // Children are owned, so a shallow copy would delete them twice.
    Node(const Node&) = delete;
    Node& operator=(const Node&) = delete;
    AAMSet<ll> st0, st1;
    /// Records value x at position k (relative to this node) in every node on
    /// the path down to the leaf: in st0 if f == 0, otherwise in st1.
    void add(int k, ll x, int f) {
        if (f == 0) st0.insert(x);
        else st1.insert(x);

        if (sz <= 1) return;

        if (k < sz / 2) {
            l->add(k, x, f);
        } else {
            r->add(k - sz / 2, x, f);
        }
    }
    /// Over positions [a, b) (relative to this node), returns
    /// (#st0 values <= up) - (#st1 values <= up).
    int get(int a, int b, ll up) {
        if (b <= 0 || sz <= a) return 0;
        if (a <= 0 && sz <= b) {
            return st0.ub(up) - st1.ub(up);
        }
        return l->get(a, b, up) + r->get(a - sz / 2, b - sz / 2, up);
    }
};

int main() {
    int n, q;
    sc.read(n, q);
    //assert(n <= 2000);

    struct E {
        int to;
        ll dist;
    };
    VV<E> g(n);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        ll w;
        sc.read(u, v, w);
        u--; v--;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    V<int> lord(n), rord(n);
    int ord = 0;
    VV<int> par(16, V<int>(n, -1));
    VV<ll> dist(16, V<ll>(n, TEN(18)));
    V<ll> height(n);
    auto dfs = [&](auto self, int p, int b) -> void {
        lord[p] = ord++;
        for (auto e: g[p]) {
            int d = e.to;
            if (d == b) continue;
            par[0][d] = p;
            dist[0][d] = e.dist;
            height[d] = height[p] + e.dist;
            self(self, d, p);
        }
        rord[p] = ord;
    };
    dfs(dfs, 0, -1);
    for (int ph = 1; ph < 16; ph++) {
        for (int i = 0; i < n; i++) {
            int pp = par[ph - 1][i];
            par[ph][i] = (pp == -1 ? -1 : par[ph - 1][pp]);
            dist[ph][i] = (pp == -1 ? TEN(18) : dist[ph - 1][i] + dist[ph - 1][pp]);
        }
    }

    auto tr = new Node(n);
    //V<AAMSet<ll>> naive0(n);
    //V<AAMSet<ll>> naive1(n);
    for (int i = 0; i < q; i++) {
        int ty, u;
        ll t, l;
        sc.read(ty, u, t, l);
        u--;
        t += height[u];
        if (ty == 0) {
            int v = u;
            for (int ph = 15; ph >= 0; ph--) {
                if (dist[ph][v] <= l) {
                    l -= dist[ph][v];
                    v = par[ph][v];
                }
            }
            v = par[0][v];

            //naive0[lord[u]].insert(t);
            //if (v != -1) naive1[lord[v]].insert(t);

            tr->add(lord[u], t, 0);
            if (v != -1) tr->add(lord[v], t, 1);

        } else {
            /*int ans = 0;
            for (int j = lord[u]; j < rord[u]; j++) {
                ans += naive0[j].ub(t);
                ans -= naive1[j].ub(t);
            }*/

            int ans = tr->get(lord[u], rord[u], t);
            pr.writeln(ans);
        }
    }
    return 0;
}
0