結果

問題 No.399 動的な領主
ユーザー ゆうきゆうき
提出日時 2023-11-26 12:59:39
言語 Java21
(openjdk 21)
結果
AC  
実行時間 1,203 ms / 2,000 ms
コード長 14,531 bytes
コンパイル時間 4,074 ms
コンパイル使用メモリ 96,816 KB
実行使用メモリ 87,540 KB
最終ジャッジ日時 2024-09-26 11:40:16
合計ジャッジ時間 16,851 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 58 ms
37,180 KB
testcase_01 AC 59 ms
36,952 KB
testcase_02 AC 59 ms
37,068 KB
testcase_03 AC 61 ms
37,268 KB
testcase_04 AC 102 ms
38,808 KB
testcase_05 AC 235 ms
46,752 KB
testcase_06 AC 1,203 ms
87,540 KB
testcase_07 AC 1,153 ms
87,524 KB
testcase_08 AC 1,120 ms
86,996 KB
testcase_09 AC 1,074 ms
87,456 KB
testcase_10 AC 106 ms
38,840 KB
testcase_11 AC 216 ms
46,160 KB
testcase_12 AC 923 ms
86,212 KB
testcase_13 AC 885 ms
85,916 KB
testcase_14 AC 656 ms
84,976 KB
testcase_15 AC 662 ms
85,132 KB
testcase_16 AC 677 ms
84,656 KB
testcase_17 AC 1,097 ms
87,348 KB
testcase_18 AC 1,137 ms
87,440 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import static java.lang.Math.*;
import static java.util.Arrays.*;

import java.io.*;
import java.lang.reflect.Array;
import java.math.BigInteger;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.*;
import java.util.stream.IntStream;

class Solver{
  /** Start time in epoch milliseconds; set at construction and by {@link #reset()}. */
  long st = System.currentTimeMillis();

  /** Returns the milliseconds elapsed since {@link #st}. */
  long elapsed(){
    long now = System.currentTimeMillis();
    return now - st;
  }

  /** Restarts the elapsed-time clock. */
  void reset(){
    st = System.currentTimeMillis();
  }

  static final int infI = (1 << 30) - 1;
  static final long infL = 1L << 60;
  // Alternative modulus kept for reference: (int) 1e9 + 7
  static final long mod = 998244353;
  static final String yes = "Yes";
  static final String no = "No";

  // NOTE(review): ThreadLocalRandom.current() is cached in a static field; that instance is
  // only meant for the thread that obtained it. Not used by this solution.
  static Random rd = ThreadLocalRandom.current();
  MyReader in = new MyReader(System.in);
  MyWriter out = new MyWriter(System.out);
  /** Diagnostic writer on stderr: prints "null" for null objects and flushes after every line. */
  MyWriter log = new MyWriter(System.err){
    @Override
    void println(Object obj){
      super.println(obj != null ? obj : "null");
    }

    @Override
    protected void ln(){
      super.ln();
      flush();
    }
  };

  /**
   * Reads a tree with n nodes (n-1 edges, 1-based endpoints) and Q path queries. Each query
   * adds 1 to every node on the u-v path, LCA included. The answer is the sum over all nodes
   * of c*(c+1)/2, where c is the node's final count.
   */
  Object solve(){
    final int n = in.it();
    // Range add / point read: aggregation and composition are both plain addition.
    DualSegmentTree<Long, Long> counter = new DualSegmentTree<>(n){
      @Override
      protected Long e(){
        return 0L;
      }

      @Override
      protected Long agg(Long a,Long b){
        return a + b;
      }

      @Override
      protected Long comp(Long older,Long newer){
        return older + newer;
      }

      @Override
      protected Long map(Long value,Long add,int from,int to){
        return value + add;
      }
    };

    var tree = new HLD<Long, Long, Long>(n, counter);
    for (int k = 0; k + 1 < n; k++) {
      int a = in.idx();
      int b = in.idx();
      tree.addEdge(a, b);
    }
    tree.makeTree(0);

    for (int q = in.it(); q > 0; q--) {
      int a = in.idx();
      int b = in.idx();
      tree.updPath(a, b, true, 1L);
    }

    long total = 0;
    for (int node = 0; node < n; node++) {
      long c = tree.getNode(node);
      total += c * (c + 1) / 2;
    }
    return total;
  }

}

/**
 * Segment tree for range updates and point reads ("dual" segment tree).
 *
 * <p>Only leaf values are meaningful: pending operators are stored on internal nodes and pushed
 * toward the leaves. Before every update, the pending operators on the boundary paths are pushed
 * down, so older operators reach the children before the new one is placed.
 */
abstract class DualSegmentTree<V, F> extends Seg<V, F>{

  DualSegmentTree(int n){
    super(n);
  }

  /** Reads position {@code i} after flushing every pending operator above its leaf. */
  @Override
  protected V get(int i){
    final int next = i + 1;
    down(i, next);
    return super.get(i);
  }

  /** Applies {@code f} to the half-open range [l, r). */
  @Override
  protected void upd(int l,int r,F f){
    down(l, r);
    super.upd(l, r, f);
  }

  /** Applies {@code f} to the single position {@code i}. */
  @Override
  protected void upd(int i,F f){
    final int next = i + 1;
    upd(i, next, f);
  }

  /** Internal node values are never read here, so a node's own operator is not applied to it. */
  @Override
  protected void rangeMap(int i){
    // intentionally empty
  }

  @Override
  protected abstract F comp(F f0,F f1);

  @Override
  protected abstract V map(V v,F f,int l,int r);

}

/**
 * Segment tree with lazy propagation: range updates and range aggregate queries.
 * Subclasses supply {@code e}, {@code agg}, {@code map} and {@code comp}.
 */
abstract class LazySegmentTree<V, F> extends Seg<V, F>{

  LazySegmentTree(int n){
    super(n);
  }

  /** Returns the aggregate of the single position {@code i}. */
  @Override
  protected V get(int i){
    final int next = i + 1;
    return get(i, next);
  }

  /** Returns the aggregate over [l, r), after flushing pending operators on the boundary paths. */
  @Override
  protected V get(int l,int r){
    down(l, r);
    return super.get(l, r);
  }

  /** Applies {@code f} to the single position {@code i}. */
  @Override
  protected void upd(int i,F f){
    final int next = i + 1;
    upd(i, next, f);
  }

  /**
   * Applies {@code f} to [l, r). Boundary lazies are pushed down first, then the covering
   * nodes are tagged, and finally the affected ancestors are recomputed.
   */
  @Override
  protected void upd(int l,int r,F f){
    down(l, r);
    super.upd(l, r, f);
    up(l, r);
  }

  @Override
  protected abstract V agg(V v0,V v1);

  @Override
  protected abstract V map(V v,F f,int l,int r);

  @Override
  protected abstract F comp(F f0,F f1);
}

/**
 * Heavy-light decomposition of an undirected tree, backed by one or two segment trees.
 *
 * <p>After {@link #makeTree(int)}, node {@code u} owns position {@code u.l} in the segment
 * trees and its subtree occupies the half-open range {@code [u.l, u.r)}. The nodes of one heavy
 * path occupy consecutive positions, head first.
 *
 * <p>{@code lft} and {@code rht} may be the same tree. Updates are applied to both.
 * {@link #getPath} reads segments on the {@code ui} side from {@code lft} and segments on the
 * {@code vi} side from {@code rht}. NOTE(review): for a non-commutative {@code agg},
 * {@code lft} presumably holds reversed aggregates; confirm with callers.
 *
 * @param <L> edge label type
 * @param <V> segment-tree value type
 * @param <F> segment-tree operator type
 */
class HLD<L, V, F> extends Graph<L>{

  Seg<V, F> lft;
  private Seg<V, F> rht;

  /** Creates an HLD over {@code n} nodes that uses one segment tree for both directions. */
  HLD(int n,Seg<V, F> seg){ this(n,seg,seg); }

  HLD(int n,Seg<V, F> lft,Seg<V, F> rht){
    super(n,false);
    this.lft = lft;
    this.rht = rht;
  }

  /** Applies {@code f} at node {@code u}'s position. */
  public void updNode(int u,F f){ upd(nds[u],f); }

  /** Applies {@code f} at edge {@code e}'s position, which is the position of its child endpoint. */
  public void updEdge(int e,F f){ upd(es.get(e),f); }

  /** Applies {@code f} at position {@code nd.l} in both trees. */
  public void upd(Nd nd,F f){
    lft.upd(nd.l,f);
    if (lft != rht)
      rht.upd(nd.l,f);
  }

  /**
   * Applies {@code f} to every node on the path between {@code ui} and {@code vi}. When
   * {@code incLca} is false the lowest common ancestor is skipped. Because each edge is stored
   * at its child endpoint, that turns this into an update of the path's edges.
   */
  public void updPath(int ui,int vi,boolean incLca,F f){
    Node u = nds[ui],v = nds[vi];
    while (true) {
      // Keep v as the endpoint with the larger position; we always climb from v.
      if (u.l > v.l) {
        var t = u;
        u = v;
        v = t;
      }

      var h = v.hp;
      if (h.l <= u.l) {
        // u lies on v's heavy path, so u is the LCA.
        int from = u.l +(incLca ? 0 : 1);
        // With u == v and the LCA excluded the range is empty; don't hand it to the tree.
        if (from <= v.l)
          updRange(from,v.l +1,f);
        return;
      }

      updRange(h.l,v.l +1,f);
      v = h.p;
    }
  }

  /** Applies {@code f} to positions [l, r) in both trees. */
  private void updRange(int l,int r,F f){
    lft.upd(l,r,f);
    if (lft != rht)
      rht.upd(l,r,f);
  }

  /**
   * Aggregates the values on the path in order from {@code ui} to {@code vi}. {@code vl}
   * collects the {@code ui}-side segments and {@code vr} the {@code vi}-side segments. When
   * {@code incLca} is false the LCA is left out.
   */
  public V getPath(int ui,int vi,boolean incLca){
    Node u = nds[ui],v = nds[vi];
    V vl = lft.e(),vr = lft.e();
    // tog is true while v is the endpoint that started at ui.
    boolean tog = false;
    while (true) {
      if (u.l > v.l) {
        var t = u;
        u = v;
        v = t;
        tog ^= true;
      }

      var h = v.hp;
      if (h.l <= u.l) {
        if (tog)
          vl = rht.agg(vl,lft.get(u.l +(incLca ? 0 : 1),v.l +1));
        else
          vr = rht.agg(rht.get(u.l +(incLca ? 0 : 1),v.l +1),vr);
        return rht.agg(vl,vr);
      }

      if (tog)
        vl = rht.agg(vl,lft.get(h.l,v.l +1));
      else
        vr = rht.agg(rht.get(h.l,v.l +1),vr);
      v = h.p;
    }
  }

  /** Aggregate over {@code ui}'s subtree; {@code ui} itself is excluded when {@code incLca} is false. */
  public V getSub(int ui,boolean incLca){ return lft.get(nds[ui].l +(incLca ? 0 : 1),nds[ui].r); }

  /** Value stored at node {@code ui}. */
  public V getNode(int ui){ return lft.get(nds[ui].l); }

  /** Value stored at node {@code u}. */
  public V get(Node u){ return lft.get(u.l); }

  /** Returns the id of the lowest common ancestor of {@code ui} and {@code vi}. */
  int lca(int ui,int vi){
    Node u = nds[ui],v = nds[vi];
    while (true) {
      if (u.l > v.l) {
        var t = u;
        u = v;
        v = t;
      }

      var h = v.hp;
      if (h.l <= u.l)
        return u.id;

      v = h.p;
    }
  }

  /**
   * Roots the tree at {@code si} and computes the decomposition.
   *
   * <p>Pass 1 is an iterative DFS in which each node is pushed twice: the first pop is its
   * pre-order visit and the second pop its post-order visit. It sets parents {@code p} (the
   * root is its own parent), replaces each edge in {@code es} with its parent-to-child copy, and
   * accumulates subtree sizes in {@code r}.
   *
   * <p>Pass 2 moves the edge to the largest child to the front of every adjacency list.
   *
   * <p>Pass 3 is a pre-order DFS that visits the heavy child first. It assigns positions
   * {@code l}, turns {@code r} into {@code l + size}, and sets heavy-path heads {@code hp}.
   *
   * <p>Finally, every edge copies its child endpoint's {@code [l, r)}.
   *
   * <p>NOTE(review): assumes the graph is a tree and that this is called exactly once. A
   * second call would see non-zero {@code r} and non-null {@code hp} values.
   */
  public void makeTree(int si){
    Deque<Node> stk = new ArrayDeque<>();
    var s = nds[si];
    s.p = s;
    stk.push(s);
    stk.push(s);
    while (!stk.isEmpty()) {
      var u = stk.pop();
      if (u.r < 1) {
        // First visit: r doubles as a "visited" marker and the subtree-size accumulator.
        u.r = 1;
        for (var e:go(u.id)) {
          if (e.v == u.p)
            continue;
          es.set(e.id,e);
          e.v.p = u;
          stk.push(e.v);
          stk.push(e.v);
        }
      } else if (u != s)
        u.p.r += u.r;
    }

    for (var u:nds) {
      var go = go(u.id);
      // Swap when slot 0 holds the parent (its size exceeds u's), or when go[i] is a child larger than go[0].
      for (int i = 1;i < go.size();i++)
        if (u.r < go.get(0).v.r || go.get(0).v.r < go.get(i).v.r && go.get(i).v.r < u.r)
          Collections.swap(go,0,i);
    }

    int hid = 0;
    stk.push(s);
    while (!stk.isEmpty()) {
      var u = stk.pop();
      u.r += u.l = hid;
      if (u.hp == null)
        u.hp = u;
      hid++;
      var go = go(u.id);
      // Push in reverse so index 0 (the heavy child) is popped next and gets position l + 1.
      for (int i = go.size();i-- > 0;) {
        var v = go.get(i).v;
        if (v == u.p)
          continue;
        if (i == 0)
          v.hp = u.hp;
        stk.push(v);
      }
    }

    for (var e:es) {
      e.l = e.v.l;
      e.r = e.v.r;
    }
  }

}

/**
 * Array-backed bottom-up segment tree shared by the point, dual and lazy variants.
 *
 * <p>Position {@code i} is stored in leaf {@code i + n}. Internal node {@code k} (1 <= k < n)
 * has children {@code 2k} and {@code 2k+1}. {@code l[k]} and {@code r[k]} are copied from the
 * leftmost and rightmost leaves below {@code k} and are passed to {@code map} as the node's
 * range. Pending operators live in {@code lazy[k]} for internal nodes only; {@code null}
 * means none.
 *
 * <p>NOTE: the constructor calls the overridable {@link #init(int)}, {@link #e()} and
 * {@link #agg}, so overrides must not rely on subclass fields.
 */
abstract class Seg<V, F> {
  protected int n,log;
  private V[] val;
  private F[] lazy;
  private int[] l,r;

  /**
   * Builds a tree over {@code n} positions. Leaf {@code i} gets {@link #init(int)} and each
   * internal node gets {@code agg} of its children.
   */
  @SuppressWarnings("unchecked") // V[]/F[] are backed by Object[]; both arrays are private and never exposed
  Seg(int n){
    this.n = n;
    // Smallest log with (1 << log) > n. Leaf indices are < 2n, so shifting by log reaches the top.
    while (1 <<log <= n)
      log++;
    val = (V[]) new Object[n <<1];
    lazy = (F[]) new Object[n];
    l = new int[n <<1];
    r = new int[n <<1];

    for (int i = -1;++i < n;l[i +n] = i,r[i +n] = i +1)
      val[i +n] = init(i);
    for (int i = n;--i > 0;l[i] = l[i <<1],r[i] = r[i <<1 |1])
      merge(i);
  }

  /** Identity element of {@link #agg}. */
  protected abstract V e();

  /** Initial value of position {@code i}; defaults to {@link #e()}. */
  protected V init(int i){ return e(); }

  /** Combines two adjacent aggregates. The default returns null; aggregating subclasses override it. */
  protected V agg(V v0,V v1){ return null; }

  /** Applies operator {@code f} to value {@code v} covering positions [l, r). The default returns null. */
  protected V map(V v,F f,int l,int r){ return null; }

  /** Applies internal node {@code i}'s own pending operator to its stored value. */
  protected void rangeMap(int i){ val[i] = map(val[i],lazy[i],l[i],r[i]); }

  /** Composes the older pending operator {@code f0} with the newer {@code f1}. The default returns null. */
  protected F comp(F f0,F f1){ return null; }

  /**
   * Materialises node {@code i} and returns its value. If {@code i} is internal and has a
   * pending operator, the operator is applied to the node, handed to both children, and cleared.
   */
  protected V eval(int i){
    if (i < n && lazy[i] != null) {
      rangeMap(i);
      prop(i <<1,lazy[i]);
      prop(i <<1 |1,lazy[i]);
      lazy[i] = null;
    }
    return val[i];
  }

  /** Recomputes internal node {@code i} from its evaluated children. */
  private void merge(int i){ val[i] = agg(eval(i <<1),eval(i <<1 |1)); }

  /** Queues {@code f} on node {@code i}: internal nodes compose it into {@code lazy}, leaves apply it directly. */
  protected void prop(int i,F f){
    if (i < n)
      lazy[i] = lazy[i] == null ? f : comp(lazy[i],f);
    else
      val[i] = map(val[i],f,l[i],r[i]);
  }

  /** Recomputes the ancestors above the boundaries of an updated range [l, r). */
  protected void up(int l,int r){
    l = oddPart(l +n);
    r = oddPart(r +n);
    while (1 < l)
      merge(l >>= 1);
    while (1 < r)
      merge(r >>= 1);
  }

  /** Pushes pending operators down, from the top, along the boundary paths of [l, r). */
  protected void down(int l,int r){
    l = oddPart(l +n);
    r = oddPart(r +n);
    for (int i = log;i > 0;i--) {
      if (l >>i > 0)
        eval(l >>i);
      if (r >>i > 0)
        eval(r >>i);
    }
  }

  /** Returns {@code i} with its trailing zero bits shifted out. Requires {@code i > 0}. */
  private int oddPart(int i){ return i >> Integer.numberOfTrailingZeros(i); }

  /** Applies {@code f} to position {@code i}'s leaf. */
  protected void upd(int i,F f){ prop(i +n,f); }

  /**
   * Queues {@code f} on the canonical nodes covering [l, r). An empty range ({@code l >= r})
   * is a no-op. A do-while here would run once even then and touch out-of-range leaves.
   */
  protected void upd(int l,int r,F f){
    l += n;
    r += n;
    while (l < r) {
      if ((l &1) == 1)
        prop(l++,f);
      if ((r &1) == 1)
        prop(--r,f);
      l >>= 1;
      r >>= 1;
    }
  }

  /** Raw leaf value at position {@code i}; does not flush pending operators. */
  protected V get(int i){ return val[i +n]; }

  /** Aggregate over [l, r), evaluating each covering node on the way. */
  protected V get(int l,int r){
    l += n;
    r += n;
    V vl = e(),vr = e();
    while (l < r) {
      vl = (l &1) == 0 ? vl : agg(vl,eval(l++));
      vr = (r &1) == 0 ? vr : agg(eval(--r),vr);
      l >>= 1;
      r >>= 1;
    }

    return agg(vl,vr);
  }

}

/**
 * Point-update / range-query segment tree. Each update applies {@code map} to one leaf and
 * then recomputes that leaf's ancestors.
 */
abstract class SegmentTree<V, F> extends Seg<V, F>{

  SegmentTree(int n){
    super(n);
  }

  /** Applies {@code f} to position {@code i}, then refreshes the aggregates above it. */
  @Override
  protected void upd(int i,F f){
    super.upd(i, f);
    final int next = i + 1;
    up(i, next);
  }

  @Override
  protected abstract V agg(V v0,V v1);

  @Override
  protected abstract V map(V v,F f,int l,int r);

}

/** Edge {@code u -> v} carrying an optional label {@code val}; {@code id} is its index in the graph. */
class Edge<L> extends Nd{
  Node u, v;
  L val;

  Edge(int id,Node u,Node v,L val){
    super(id);
    this.u = u;
    this.v = v;
    this.val = val;
  }

  /** Renders the edge as {@code (u,v)} using node ids. */
  @Override
  public String toString(){
    return new StringBuilder()
        .append('(')
        .append(u.id)
        .append(',')
        .append(v.id)
        .append(')')
        .toString();
  }
}

/** Tree/graph vertex. */
class Node extends Nd{
  /** Parent in the rooted tree. */
  Node p;
  /** Head of the heavy path containing this node. */
  Node hp;

  Node(int id){
    super(id);
  }

  /** Renders the node as its id. */
  @Override
  public String toString(){
    return Integer.toString(id);
  }
}

/** Common base for nodes and edges: an id plus an {@code l}/{@code r} pair used by tree decompositions. */
class Nd{
  int id;
  int l, r;

  Nd(int id){
    this.id = id;
  }
}

/**
 * Adjacency-list graph over nodes {@code 0..n-1}.
 *
 * <p>Each added edge gets an id equal to its insertion index in {@code es}. In a directed
 * graph, {@code go(u)} holds outgoing edges and {@code back(v)} holds reversed copies of
 * incoming ones. In an undirected graph the two views are the same lists, so each edge appears
 * at both endpoints; the reversed copy shares the original's id.
 *
 * @param <L> edge label type
 */
class Graph<L> {
  public int n;
  List<Edge<L>> es;
  Node[] nds;
  private List<List<Edge<L>>> go,back;

  /**
   * @param n   number of nodes
   * @param dir true for a directed graph; false makes {@code back} an alias of {@code go}
   */
  public Graph(int n,boolean dir){
    this.n = n;
    nds = new Node[n];
    go = new ArrayList<>(n);
    back = dir ? new ArrayList<>(n) : go;
    for (int i = 0;i < n;i++) {
      nds[i] = new Node(i);
      go.add(new ArrayList<>());
      // When undirected, back == go; adding here too would append a second batch of n
      // never-used lists to the shared outer list.
      if (dir)
        back.add(new ArrayList<>());
    }
    es = new ArrayList<>();
  }

  /** Adds an unlabeled edge {@code u -> v}. */
  public void addEdge(int u,int v){ addEdge(u,v,null); }

  /** Adds edge {@code u -> v} with label {@code l}, and a reversed copy with the same id to {@code back(v)}. */
  public void addEdge(int u,int v,L l){
    Edge<L> e = new Edge<>(es.size(),nds[u],nds[v],l);
    es.add(e);
    go.get(u).add(e);
    back.get(v).add(new Edge<>(e.id,e.v,e.u,e.val));
  }

  /** Edges leaving {@code ui}; for undirected graphs, all incident edges oriented away from {@code ui}. */
  public List<Edge<L>> go(int ui){ return go.get(ui); }

  /** Reversed copies of edges entering {@code ui}; same list as {@link #go(int)} when undirected. */
  public List<Edge<L>> back(int ui){ return back.get(ui); }
}

/** Array builders: each evaluates a generator at indices 0..N-1, in increasing order. */
class Util{
  /** Returns a char array whose element {@code i} is {@code (char) f(i)}. */
  static char[] arrC(int N,IntUnaryOperator f){
    var out = new char[N];
    int k = 0;
    while (k < N) {
      out[k] = (char) f.applyAsInt(k);
      k++;
    }
    return out;
  }

  /** Returns an int array whose element {@code i} is {@code f(i)}. */
  static int[] arrI(int N,IntUnaryOperator f){
    var out = new int[N];
    for (int k = 0; k < out.length; k++)
      out[k] = f.applyAsInt(k);
    return out;
  }

  /** Returns a long array whose element {@code i} is {@code f(i)}. */
  static long[] arrL(int N,IntToLongFunction f){
    var out = new long[N];
    for (int k = 0; k < out.length; k++)
      out[k] = f.applyAsLong(k);
    return out;
  }

  /** Returns a double array whose element {@code i} is {@code f(i)}. */
  static double[] arrD(int N,IntToDoubleFunction f){
    var out = new double[N];
    for (int k = 0; k < out.length; k++)
      out[k] = f.applyAsDouble(k);
    return out;
  }

  /** Fills {@code arr} in place with {@code f(i)} and returns it. */
  static <T> T[] arr(T[] arr,IntFunction<T> f){
    for (int k = 0; k < arr.length; k++)
      arr[k] = f.apply(k);
    return arr;
  }

}

/**
 * Fast buffered reader for whitespace-separated ASCII tokens.
 *
 * <p>A token is a maximal run of printable ASCII bytes (codes 33..126). Every other byte,
 * including spaces and line breaks, separates tokens.
 */
class MyReader{
  private byte[] buf = new byte[1 <<16];
  private int ptr = 0;
  private int tail = 0;
  // Set once the underlying stream has reported end of input.
  private boolean eof = false;
  private InputStream in;

  MyReader(InputStream in){ this.in = in; }

  /**
   * Returns the next byte as an unsigned value in 0..255, or -1 at end of input.
   *
   * @throws UncheckedIOException if the underlying stream fails
   */
  private int read(){
    while (ptr == tail) {
      if (eof)
        return -1;
      try {
        tail = in.read(buf);
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      }
      ptr = 0;
      if (tail < 0) {
        // InputStream.read returns -1 at end of stream; remember it instead of replaying stale bytes.
        tail = 0;
        eof = true;
      }
    }
    return buf[ptr++] &0xFF;
  }

  private boolean isPrintable(int c){ return 32 < c && c < 127; }

  /**
   * Skips separator bytes and returns the first byte of the next token.
   *
   * @throws NoSuchElementException if the input ends before another token starts
   */
  private int nextPrintable(){
    int c = read();
    while (!isPrintable(c)) {
      if (c < 0)
        throw new NoSuchElementException("unexpected end of input");
      c = read();
    }
    return c;
  }

  /** Next token as an int; throws ArithmeticException if it does not fit. */
  int it(){ return toIntExact(lg()); }

  int[] it(int N){ return Util.arrI(N,i -> it()); }

  int[][] it(int H,int W){ return Util.arr(new int[H][],i -> it(W)); }

  /** Next token as an int, converted from 1-based to 0-based. */
  int idx(){ return it() -1; }

  int[] idx(int N){ return Util.arrI(N,i -> idx()); }

  int[][] idx(int H,int W){ return Util.arr(new int[H][],i -> idx(W)); }

  /**
   * Next token parsed as a decimal long with an optional leading '-'.
   * NOTE(review): digits are not validated and overflow is not detected.
   */
  long lg(){
    int c = nextPrintable();
    boolean negative = c == '-';
    long n = negative ? 0 : c -'0';
    while (isPrintable(c = read()))
      n = 10 *n +c -'0';
    return negative ? -n : n;
  }

  long[] lg(int N){ return Util.arrL(N,i -> lg()); }

  long[][] lg(int H,int W){ return Util.arr(new long[H][],i -> lg(W)); }

  double dbl(){ return Double.parseDouble(str()); }

  double[] dbl(int N){ return Util.arrD(N,i -> dbl()); }

  double[][] dbl(int H,int W){ return Util.arr(new double[H][],i -> dbl(W)); }

  char[] ch(){ return str().toCharArray(); }

  char[][] ch(int H){ return Util.arr(new char[H][],i -> ch()); }

  /** Reads the rest of the current line, excluding the '\n'; stops early at end of input. */
  String line(){
    StringBuilder sb = new StringBuilder();

    for (int c;(c = read()) != '\n' && c >= 0;)
      sb.append((char) c);
    return sb.toString();
  }

  /** Next token as a String. */
  String str(){
    StringBuilder sb = new StringBuilder();
    sb.append((char) nextPrintable());

    for (int c;isPrintable(c = read());)
      sb.append((char) c);
    return sb.toString();
  }

  String[] str(int N){ return Util.arr(new String[N],i -> str()); }

}

/**
 * Buffered byte writer for fast output. Text is encoded as UTF-8, and integers are formatted
 * without building intermediate strings. Call {@link #flush()} before exiting.
 */
class MyWriter{
  OutputStream out;
  byte[] buf = new byte[1 <<16];
  byte[] ibuf = new byte[20];
  int tail = 0;

  MyWriter(OutputStream out){ this.out = out; }

  /**
   * Hands the buffered bytes to the underlying stream and flushes that stream, so the output
   * also reaches streams that buffer internally. I/O errors are reported on stderr (best effort).
   */
  void flush(){
    try {
      out.write(buf,0,tail);
      tail = 0;
      out.flush();
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  /** Writes a line terminator. */
  protected void ln(){ write((byte) '\n'); }

  private void write(byte b){
    buf[tail++] = b;
    if (tail == buf.length)
      flush();
  }

  private void write(byte[] b,int off,int len){
    for (int i = off;i < off +len;i++)
      write(b[i]);
  }

  /**
   * Writes {@code n} in decimal. Digits are generated from the non-positive value, so
   * {@link Long#MIN_VALUE}, which has no positive counterpart, is formatted correctly.
   */
  private void write(long n){
    if (n < 0)
      write((byte) '-');
    else
      n = -n;

    int i = ibuf.length;
    do {
      ibuf[--i] = (byte) ('0' -n %10);
      n /= 10;
    } while (n < 0);

    write(ibuf,i,ibuf.length -i);
  }

  /** Writes {@code s} encoded as UTF-8. */
  private void writeText(String s){
    byte[] bytes = s.getBytes(java.nio.charset.StandardCharsets.UTF_8);
    write(bytes,0,bytes.length);
  }

  /**
   * Writes one value without a line break. Booleans are written as Solver.yes/Solver.no and
   * arrays as space-separated elements. Null is written as "null", and anything else via
   * toString().
   */
  private void print(Object obj){
    if (obj == null)
      writeText("null");
    else if (obj instanceof Boolean flag)
      print(flag ? Solver.yes : Solver.no);
    else if (obj instanceof Integer iv)
      write(iv.longValue());
    else if (obj instanceof Long lv)
      write(lv.longValue());
    else if (obj instanceof char[] cs)
      writeText(new String(cs));
    else if (obj.getClass().isArray()) {
      int len = Array.getLength(obj);
      for (int i = 0;i < len;i++) {
        print(Array.get(obj,i));
        if (i +1 < len)
          write((byte) ' ');
      }
    } else
      writeText(obj.toString());
  }

  /**
   * Prints {@code obj} followed by a newline. Each element of a collection, and each row of an
   * array of arrays, goes on its own line. A null argument prints nothing.
   */
  void println(Object obj){
    if (obj == null)
      return;

    if (obj instanceof Collection<?> co)
      for (Object e:co)
        println(e);
    else if (isNestedArray(obj)) {
      int len = Array.getLength(obj);
      for (int i = 0;i < len;i++)
        println(Array.get(obj,i));
    } else {
      print(obj);
      ln();
    }
  }

  /** True when {@code obj} is a non-empty array whose first element is itself a (non-null) array. */
  private static boolean isNestedArray(Object obj){
    if (!obj.getClass().isArray() || Array.getLength(obj) == 0)
      return false;
    Object first = Array.get(obj,0);
    return first != null && first.getClass().isArray();
  }

  /** Prints all arguments space-separated on one line. */
  void printlns(Object... o){
    print(o);
    ln();
  }
}

class Main{
  /** Runs the solver, prints its answer to stdout, and logs the elapsed milliseconds to stderr. */
  public static void main(String[] args) throws Exception{
    var runner = new Solver(){
      public void exe(){
        Object answer = solve();
        out.println(answer);
        out.flush();
        log.println(elapsed());
      }
    };
    runner.exe();
  }
}
0