結果

問題 No.1002 Twotone
ユーザー uwiuwi
提出日時 2020-02-28 22:12:54
言語 Java21 (openjdk 21)
結果
AC  
実行時間 2,377 ms / 5,000 ms
コード長 9,318 bytes
コンパイル時間 4,383 ms
コンパイル使用メモリ 90,812 KB
実行使用メモリ 118,924 KB
最終ジャッジ日時 2024-10-13 17:49:36
合計ジャッジ時間 47,309 ms
ジャッジサーバーID (参考情報): judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 48 ms 37,008 KB
testcase_01 AC 48 ms 36,940 KB
testcase_02 AC 45 ms 36,796 KB
testcase_03 AC 1,618 ms 89,704 KB
testcase_04 AC 2,047 ms 95,392 KB
testcase_05 AC 1,809 ms 94,956 KB
testcase_06 AC 49 ms 36,684 KB
testcase_07 AC 1,382 ms 76,692 KB
testcase_08 AC 1,573 ms 88,832 KB
testcase_09 AC 1,665 ms 87,320 KB
testcase_10 AC 48 ms 37,232 KB
testcase_11 AC 1,790 ms 85,024 KB
testcase_12 AC 2,295 ms 91,728 KB
testcase_13 AC 2,300 ms 91,904 KB
testcase_14 AC 52 ms 36,648 KB
testcase_15 AC 1,594 ms 72,500 KB
testcase_16 AC 1,965 ms 87,724 KB
testcase_17 AC 1,912 ms 87,080 KB
testcase_18 AC 50 ms 37,108 KB
testcase_19 AC 2,246 ms 87,944 KB
testcase_20 AC 2,158 ms 91,288 KB
testcase_21 AC 2,244 ms 93,708 KB
testcase_22 AC 54 ms 36,816 KB
testcase_23 AC 1,645 ms 78,840 KB
testcase_24 AC 2,377 ms 93,356 KB
testcase_25 AC 2,061 ms 90,228 KB
testcase_26 AC 47 ms 36,808 KB
testcase_27 AC 616 ms 79,188 KB
testcase_28 AC 840 ms 90,316 KB
testcase_29 AC 707 ms 90,180 KB
testcase_30 AC 51 ms 36,820 KB
testcase_31 AC 676 ms 89,072 KB
testcase_32 AC 825 ms 88,048 KB
testcase_33 AC 745 ms 88,980 KB
testcase_34 AC 1,528 ms 118,924 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

package contest200228;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.InputMismatchException;
import java.util.Map;

/**
 * Counts unordered vertex pairs of an edge-colored tree whose connecting path uses exactly
 * two distinct colors.
 *
 * <p>Input: {@code n K}, then {@code n - 1} lines {@code u v c} (1-based endpoints, edge color
 * {@code c}). Output: the number of such pairs.
 *
 * <p>Approach: centroid decomposition. For each centroid {@code sep}, the color set of the path
 * from {@code sep} to every vertex of its current component is packed into a {@code long}:
 * a single color {@code c} is stored as {@code c}, and two colors {@code a < b} are stored as
 * {@code a << 32 | b}. Vertices whose path already carries three or more colors are not
 * explored, because no path through them can qualify.
 *
 * <p>NOTE(review): the encoding assumes colors are positive ints. A color 0 would make
 * {@code 0 << 32 | b} indistinguishable from the single color {@code b}.
 */
public class E {
	InputStream is;
	PrintWriter out;
	String INPUT = "";

	/** Selects the low 32 bits of an encoded color set (the larger color of a pair). */
	static final long LOW_MASK = 0xFFFFFFFFL;
	/** Returned by {@link #extend} when a path would contain three or more colors. */
	static final long TOO_MANY_COLORS = -1L;

	/** Reads the tree from {@link #is}, runs the decomposition and prints the answer to {@link #out}. */
	void solve()
	{
		int n = ni(), K = ni();
		int[] from = new int[n - 1];
		int[] to = new int[n - 1];
		int[] cs = new int[n-1];
		for (int i = 0; i < n - 1; i++) {
			from[i] = ni() - 1;
			to[i] = ni() - 1;
			cs[i] = ni();
		}
		int[][][] g = packWU(n, from, to, cs);
		int[] cpar = buildCentroidTree(g);
		dfsTopCT(cpar, g);
		out.println(ans);
	}

	/**
	 * Builds undirected weighted adjacency lists.
	 *
	 * @return {@code g[v][i] = {neighbor, weight}} for every edge incident to {@code v}
	 */
	public static int[][][] packWU(int n, int[] from, int[] to, int[] w) {
		int[][][] g = new int[n][][];
		int[] p = new int[n];
		for (int f : from)
			p[f]++;
		for (int t : to)
			p[t]++;
		for (int i = 0; i < n; i++)
			g[i] = new int[p[i]][2];
		for (int i = 0; i < from.length; i++) {
			--p[from[i]];
			g[from[i]][p[from[i]]][0] = to[i];
			g[from[i]][p[from[i]]][1] = w[i];
			--p[to[i]];
			g[to[i]][p[to[i]]][0] = from[i];
			g[to[i]][p[to[i]]][1] = w[i];
		}
		return g;
	}

	/** Scratch buffers shared by every level of the centroid-tree traversal. */
	static class Context
	{
		boolean[] seps; // [v] = v has already been used as a centroid
		long[] wt; // [v] = encoded color set of the path from the current centroid to v
		int[] vs; // vertices of the neighbor subtree currently being explored, in visit order
		int[][] vss; // [neckind][vertices] vertex list of each neighbor subtree of the centroid
		int[][] ctch; // centroid-tree children
		int[][][] g; // weighted adjacency lists: g[v][i] = {neighbor, color}
		int[] stack; // iterative DFS stack of vertices
		int[] inds; // [depth] = next adjacency index to examine at that stack depth
	}

	/**
	 * Processes every centroid, starting from the centroid-tree root (the vertex whose parent
	 * is -1) and accumulating the answer into {@link #ans}.
	 */
	public void dfsTopCT(int[] cpar, int[][][] g) {
		int n = g.length;
		int ctroot = -1;
		for(int i = 0;i < n;i++)if(cpar[i] == -1)ctroot = i;

		Context cx = new Context();
		cx.seps = new boolean[n];
		cx.wt = new long[n];
		cx.vs = new int[n];
		cx.vss = new int[n][];
		cx.ctch = parentToChildren(cpar);
		cx.g = g;
		cx.stack = new int[n];
		cx.inds = new int[n];
		dfs(ctroot, cx);
	}

	/**
	 * Converts a parent array ({@code -1} for roots) into children lists.
	 *
	 * @return {@code g[v]} = children of {@code v}
	 */
	public static int[][] parentToChildren(int[] par)
	{
		int n = par.length;
		int[] ct = new int[n];
		for(int i = 0;i < n;i++){
			if(par[i] >= 0){
				ct[par[i]]++;
			}
		}
		int[][] g = new int[n][];
		for(int i = 0;i < n;i++){
			g[i] = new int[ct[i]];
		}
		for(int i = 0;i < n;i++){
			if(par[i] >= 0){
				g[par[i]][--ct[par[i]]] = i;
			}
		}

		return g;
	}

	/**
	 * Extends an encoded color set by one more edge color.
	 *
	 * @param w encoded set: a single color {@code c}, or {@code a << 32 | b} with {@code a < b}
	 * @param c color of the next edge (assumed positive)
	 * @return the encoded set of {@code w} plus {@code c}, or {@link #TOO_MANY_COLORS} when that
	 *     set would hold three or more colors
	 */
	static long extend(long w, int c)
	{
		if((w & LOW_MASK) == c || w >>> 32 == c) {
			return w; // c is already in the set
		}
		if(w >= 1L<<32) {
			return TOO_MANY_COLORS; // already two colors; c would be a third
		}
		return Math.min(w, c)<<32 | Math.max(w, c);
	}

	/**
	 * Handles centroid {@code sep}.
	 *
	 * <p>It marks {@code sep} as used. It then explores each unused neighbor subtree
	 * iteratively, recording the color set of the path from {@code sep}. Next it counts the
	 * qualifying pairs whose path passes through {@code sep}. Finally it recurses into the
	 * centroid-tree children.
	 */
	private void dfs(int sep, Context cx)
	{
		cx.seps[sep] = true;
		int neckp = 0;
		for(int[] neck : cx.g[sep]){
			if(cx.seps[neck[0]])continue;

			int sp = 0;
			cx.inds[sp] = 0;
			cx.wt[neck[0]] = neck[1];
			int vsp = 0;
			cx.stack[sp++] = neck[0];
			while(sp > 0){
				int cur = cx.stack[sp-1];
				if(cx.inds[sp-1] == 0){
					cx.vs[vsp++] = cur; // first visit of cur
				}
				if(cx.inds[sp-1] == cx.g[cur].length){
					sp--;
					continue;
				}
				int[] e = cx.g[cur][cx.inds[sp-1]++];
				int next = e[0];
				// Skip vertices outside the current component and the DFS parent (stack[sp-2]).
				if(cx.seps[next] || (sp >= 2 && next == cx.stack[sp-2]))continue;
				long w = extend(cx.wt[cur], e[1]);
				// A path with 3+ colors can only grow; nothing at or below `next` can qualify.
				if(w == TOO_MANY_COLORS)continue;
				cx.wt[next] = w;
				cx.stack[sp] = next;
				cx.inds[sp] = 0;
				sp++;
			}
			cx.vss[neckp++] = Arrays.copyOf(cx.vs, vsp);
		}

		process(sep, cx, Arrays.copyOf(cx.vss, neckp));

		for(int child : cx.ctch[sep])dfs(child, cx);
	}

	/** Number of qualifying pairs found so far. */
	long ans = 0;

	/**
	 * Adds to {@link #ans} every qualifying pair whose path passes through {@code sep}.
	 *
	 * <p>Each vertex is matched against the vertices of earlier neighbor subtrees, so every
	 * pair is counted once. Pairs with {@code sep} itself as an endpoint are added at the end.
	 *
	 * @param vss per-neighbor-subtree vertex lists; every listed vertex has a valid (at most
	 *     two-color) set in {@code cx.wt}
	 */
	private void process(int sep, Context cx, int[][] vss)
	{
		Map<Long, Integer> all = new HashMap<>(); // encoded set -> count in earlier subtrees
		Map<Long, Integer> sides = new HashMap<>(); // color -> count of two-color sets containing it
		long one = 0; // number of single-color sets in earlier subtrees
		for(int[] vs : vss) {
			for(int v : vs) {
				ans += match(all, cx.wt[v], one, sides);
			}
			for(int v : vs) {
				long w = cx.wt[v];
				all.merge(w, 1, Integer::sum);
				if(w >= 1L<<32) {
					sides.merge(w >>> 32, 1, Integer::sum);
					sides.merge(w & LOW_MASK, 1, Integer::sum);
				}else {
					one++;
				}
			}
		}

		// Pairs (sep, v): v qualifies exactly when its own path has two colors.
		for(int cnt : all.values()) {
			ans += cnt;
		}
		ans -= one;
	}

	/**
	 * Counts earlier-recorded paths that combine with a path of color set {@code w} into
	 * exactly two colors.
	 *
	 * @param all encoded set -> count
	 * @param w encoded set of the current path (never {@link #TOO_MANY_COLORS}, never 0)
	 * @param one number of recorded single-color paths
	 * @param sides color -> number of recorded two-color paths containing it
	 */
	long match(Map<Long, Integer> all, long w, long one, Map<Long, Integer> sides)
	{
		assert w != TOO_MANY_COLORS;
		assert w != 0;
		if(w >= 1L<<32) {
			// Two colors {a, b}: the partner path must be {a, b}, {a} or {b}.
			long same = all.getOrDefault(w, 0);
			long onlyA = all.getOrDefault(w >>> 32, 0);
			long onlyB = all.getOrDefault(w & LOW_MASK, 0);
			return same + onlyA + onlyB;
		}
		// One color {c}: the partner path must be a single color other than c,
		// or a two-color set containing c.
		return one - all.getOrDefault(w, 0) + sides.getOrDefault(w, 0);
	}

	/**
	 * Builds the centroid decomposition of the tree {@code g}.
	 *
	 * @return {@code ctpar[v]} = parent of {@code v} in the centroid tree, or -1 for its root
	 */
	public static int[] buildCentroidTree(int[][][] g) {
		int n = g.length;
		int[] ctpar = new int[n];
		Arrays.fill(ctpar, -1);
		buildCentroidTree(g, 0, new boolean[n], new int[n], new int[n], new int[n], ctpar);
		return ctpar;
	}

	/**
	 * Finds the centroid of the component containing {@code root}, skipping vertices already
	 * marked in {@code sed}. It marks that centroid and recurses into the remaining pieces.
	 *
	 * @return the centroid of this component
	 */
	private static int buildCentroidTree(int[][][] g, int root, boolean[] sed, int[] par, int[] ord, int[] des, int[] ctpar)
	{
		// parent and level-order
		ord[0] = root;
		par[root] = -1;
		int r = 1;
		for(int p = 0;p < r;p++) {
			int cur = ord[p];
			for(int[] nex : g[cur]){
				if(par[cur] != nex[0] && !sed[nex[0]]){
					ord[r++] = nex[0];
					par[nex[0]] = cur;
				}
			}
		}

		// Subtree sizes in reverse level order; pick a vertex with every side of size <= r/2.
		int sep = -1; // always exists
		outer:
		for(int i = r-1;i >= 0;i--){
			int cur = ord[i];
			des[cur] = 1;
			for(int[] e : g[cur]){
				if(par[cur] != e[0] && !sed[e[0]])des[cur] += des[e[0]];
			}
			if(r-des[cur] <= r/2){
				for(int[] e : g[cur]){
					if(par[cur] != e[0] && !sed[e[0]] && des[e[0]] >= r/2+1)continue outer;
				}
				sep = cur;
				break;
			}
		}

		sed[sep] = true;
		for(int[] e : g[sep]){
			if(!sed[e[0]])ctpar[buildCentroidTree(g, e[0], sed, par, ord, des, ctpar)] = sep;
		}
		return sep;
	}

	/**
	 * BFS from {@code root} over the unweighted adjacency lists of a tree. It is not used by
	 * {@link #solve}.
	 *
	 * @return {@code {par, order, depth}}: parent of each vertex (-1 for the root), BFS order,
	 *     and depth from {@code root}
	 */
	public static int[][] parents3(int[][] g, int root) {
		int n = g.length;
		int[] par = new int[n];
		Arrays.fill(par, -1);

		int[] depth = new int[n];
		depth[root] = 0;

		int[] q = new int[n];
		q[0] = root;
		for (int p = 0, r = 1; p < r; p++) {
			int cur = q[p];
			for (int nex : g[cur]) {
				if (par[cur] != nex) {
					q[r++] = nex;
					par[nex] = cur;
					depth[nex] = depth[cur] + 1;
				}
			}
		}
		return new int[][] { par, q, depth };
	}

	/** Builds undirected unweighted adjacency lists (not used by {@link #solve}). */
	static int[][] packU(int n, int[] from, int[] to) {
		int[][] g = new int[n][];
		int[] p = new int[n];
		for (int f : from)
			p[f]++;
		for (int t : to)
			p[t]++;
		for (int i = 0; i < n; i++)
			g[i] = new int[p[i]];
		for (int i = 0; i < from.length; i++) {
			g[from[i]][--p[from[i]]] = to[i];
			g[to[i]][--p[to[i]]] = from[i];
		}
		return g;
	}

	/**
	 * Reads input and runs {@link #solve}. Input comes from {@link #INPUT} when it is non-empty
	 * (and then the elapsed time is printed); otherwise it comes from stdin.
	 */
	void run() throws Exception
	{
		is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes(StandardCharsets.UTF_8));
		out = new PrintWriter(System.out);

		long s = System.currentTimeMillis();
		solve();
		out.flush();
		if(!INPUT.isEmpty())tr(System.currentTimeMillis()-s+"ms");
	}

	public static void main(String[] args) throws Exception { new E().run(); }

	private byte[] inbuf = new byte[1024];
	public int lenbuf = 0, ptrbuf = 0;

	/** Returns the next byte of {@link #is}, or -1 at end of input; throws if read past that. */
	private int readByte()
	{
		if(lenbuf == -1)throw new InputMismatchException();
		if(ptrbuf >= lenbuf){
			ptrbuf = 0;
			try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
			if(lenbuf <= 0)return -1;
		}
		return inbuf[ptrbuf++];
	}

	private boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
	private int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }

	private double nd() { return Double.parseDouble(ns()); }
	private char nc() { return (char)skip(); }

	private String ns()
	{
		int b = skip();
		StringBuilder sb = new StringBuilder();
		while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
			sb.appendCodePoint(b);
			b = readByte();
		}
		return sb.toString();
	}

	private char[] ns(int n)
	{
		char[] buf = new char[n];
		int b = skip(), p = 0;
		while(p < n && !(isSpaceChar(b))){
			buf[p++] = (char)b;
			b = readByte();
		}
		return n == p ? buf : Arrays.copyOf(buf, p);
	}

	private int[] na(int n)
	{
		int[] a = new int[n];
		for(int i = 0;i < n;i++)a[i] = ni();
		return a;
	}

	private long[] nal(int n)
	{
		long[] a = new long[n];
		for(int i = 0;i < n;i++)a[i] = nl();
		return a;
	}

	private char[][] nm(int n, int m) {
		char[][] map = new char[n][];
		for(int i = 0;i < n;i++)map[i] = ns(m);
		return map;
	}

	private int[][] nmi(int n, int m) {
		int[][] map = new int[n][];
		for(int i = 0;i < n;i++)map[i] = na(m);
		return map;
	}

	private int ni() { return (int)nl(); }

	/** Reads the next (optionally negative) decimal integer, skipping any non-digit prefix. */
	private long nl()
	{
		long num = 0;
		int b;
		boolean minus = false;
		while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
		if(b == '-'){
			minus = true;
			b = readByte();
		}

		while(true){
			if(b >= '0' && b <= '9'){
				num = num * 10 + (b - '0');
			}else{
				return minus ? -num : num;
			}
			b = readByte();
		}
	}

	private static void tr(Object... o) { System.out.println(Arrays.deepToString(o)); }
}
0