TLE and WA on USACO 2018 January Platinum Cow at Large

Problem: 2018 January Platinum - Cow At Large

Code:

import java.io.*;
import java.util.*;

/**
 * Square-root decomposition over an Euler tour for USACO 2018 January Platinum
 * "Cow at Large".
 *
 * <p>Every node occupies two positions in the Euler tour (entry and exit). Each
 * position carries {@code diff = depth - leafdist} relative to the current root
 * and a fixed weight {@code val = 2 - degree}. For a given root the code sums
 * {@code val} over all positions whose effective diff is {@code >= 0} and halves
 * the total, because every node is counted twice. Nodes of degree 1 are printed
 * as 1 directly.
 *
 * <p>Re-rooting from a parent to its child adds 1 to every depth (tracked in
 * {@link #globalLazy}) and subtracts 2 inside the child's Euler range, which is
 * a range update over the blocks.
 *
 * <p>Input is read from {@code atlarge.in} and output is written to
 * {@code atlarge.out}.
 */
public class PlatinumAtLarge {
	/**
	 * Requested stack size for the solver thread. The traversals recurse once per
	 * tree level, so a path-shaped tree needs far more stack than the JVM default.
	 * The JVM treats this value as a hint.
	 */
	private static final long SOLVER_STACK_BYTES = 1L << 26;

	static int N;
	static ArrayList<Integer>[] adj;
	/** Euler tour: node stored at each of the 2N tour positions. */
	static int[] order;
	/** First (entry) and last (exit) tour position of each node. */
	static int[] sidx, eidx;
	/** Depth of each node when the tree is rooted at node 0. */
	static int[] depth;
	/** Distance from each node to its nearest degree-1 node. */
	static int[] leafdist;
	/** depth - leafdist for the initial rooting at node 0. */
	static int[] diff;
	/** Raw (doubled) block sum recorded for each root. */
	static int[] ans;
	static Block[] blocks;
	/** Block length, floor(sqrt(2N)). */
	static int sqrt2N;
	static int numBlocks;
	/** Offset added to every position's diff; equals the depth of the current root. */
	static int globalLazy = 0;
	/** Next free position in {@link #order} while the Euler tour is being built. */
	static int cnt = 0;
	static StreamTokenizer in;
	static PrintWriter out;
	static BufferedReader br;

	/**
	 * Program entry point: opens the files, reads the tree, solves it and writes
	 * one answer per node.
	 *
	 * @param args unused
	 * @throws IOException if the input or output file cannot be opened or read
	 */
	public static void main(String[] args) throws IOException {
		init_io();
		try {
			processInput();
			runSolverOnLargeStack();
		} finally {
			// Close on every path so buffered output is flushed and handles are released.
			out.close();
			br.close();
		}
	}

	/**
	 * Runs {@link #main()} on a dedicated thread with an enlarged stack and
	 * rethrows anything the solver threw, so failures are not silently lost.
	 */
	private static void runSolverOnLargeStack() {
		final Throwable[] failure = new Throwable[1];
		Thread solver = new Thread(null, new Runnable() {
			@Override
			public void run() {
				try {
					PlatinumAtLarge.main();
				} catch (Throwable t) {
					failure[0] = t;
				}
			}
		}, "atlarge-solver", SOLVER_STACK_BYTES);
		solver.start();
		try {
			solver.join();
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			throw new IllegalStateException("interrupted while waiting for the solver", e);
		}
		if (failure[0] != null) {
			throw new IllegalStateException("solver failed", failure[0]);
		}
	}

	/**
	 * Solves the instance currently held in {@link #N} and {@link #adj} and prints
	 * one line per node to {@link #out}.
	 */
	public static void main() {
		// Reset traversal state so the solver can be run more than once per JVM.
		cnt = 0;
		globalLazy = 0;

		sidx = new int[N];
		eidx = new int[N];
		order = new int[2*N];
		depth = new int[N];
		DFS(0, -1);

		diff = new int[N];
		leafdist = new int[N];
		// nearest leaf inside each subtree
		predfs(0, -1);
		// nearest leaf reachable through the parent
		predfs2(0, -1);
		for (int i = 0; i < N; i++) {
			diff[i] = depth[i]-leafdist[i];
		}

		// divide the Euler tour into blocks of length sqrt2N (the last one may be shorter)
		sqrt2N = (int)Math.sqrt(2*N);
		numBlocks = (2*N + sqrt2N - 1) / sqrt2N;
		blocks = new Block[numBlocks];
		for (int i = 0; i < numBlocks-1; i++) {
			blocks[i] = new Block(i*sqrt2N, (i+1)*sqrt2N-1);
		}
		blocks[numBlocks-1] = new Block((numBlocks-1)*sqrt2N, 2*N-1);

		ans = new int[N];
		dfs(0, -1);
		for (int i = 0; i < N; i++) {
			// every node was counted at both of its tour positions, hence the halving
			out.println(adj[i].size() == 1 ? 1 : ans[i]/2);
		}
	}

	/**
	 * Re-roots the tree at {@code c}, records the block sum for that root, recurses
	 * into the children and finally restores the state of the parent's rooting.
	 *
	 * @param c node that becomes the root
	 * @param p parent of {@code c} in the traversal, or -1 for the initial root
	 */
	public static void dfs(int c, int p) {
		if (p != -1) {
			// Every depth grows by 1, except inside c's subtree where it shrinks by 1.
			globalLazy++;
			shiftSubtree(c, -2);
		}
		int cans = 0;
		for (int i = 0; i < numBlocks; i++) {
			cans += blocks[i].query();
		}
		ans[c] = cans;
		for (int to : adj[c]) {
			if (to == p) continue;
			dfs(to, c);
		}
		if (p != -1) {
			// undo changes
			globalLazy--;
			shiftSubtree(c, 2);
		}
	}

	/**
	 * Adds {@code delta} to the diff of every tour position in the Euler range of
	 * {@code c}. Fully covered blocks only get their lazy tag adjusted; partially
	 * covered blocks are updated element by element.
	 */
	private static void shiftSubtree(int c, int delta) {
		int s = sidx[c], e = eidx[c];
		// Block i starts at i*sqrt2N, so the touched blocks can be indexed directly.
		for (int i = s / sqrt2N; i <= e / sqrt2N; i++) {
			Block b = blocks[i];
			if (s <= b.L && b.R <= e) {
				b.lazy += delta;
			}
			else {
				b.update(s, e, delta);
			}
		}
	}

	/**
	 * Computes, for the subtree of {@code c}, the distance to the nearest leaf and
	 * stores it in {@link #leafdist}.
	 *
	 * @param c current node
	 * @param p parent of {@code c}, or -1 for the root
	 * @return the value stored in {@code leafdist[c]}
	 */
	public static int predfs(int c, int p) {
		int best = Integer.MAX_VALUE;
		for (int to : adj[c]) {
			if (to == p) continue;
			best = Math.min(best, predfs(to, c)+1);
		}
		// A degree-1 node is a leaf even when it is the DFS root, where it still
		// has one child; without this check a leaf root got a positive distance.
		if (best == Integer.MAX_VALUE || adj[c].size() == 1) best = 0;
		leafdist[c] = best;
		return best;
	}

	/**
	 * Second pass: lowers {@code leafdist} where the nearest leaf is reached
	 * through the parent instead of through the node's own subtree.
	 *
	 * @param c current node
	 * @param p parent of {@code c}, or -1 for the root
	 */
	public static void predfs2(int c, int p) {
		if (p != -1) {
			leafdist[c] = Math.min(leafdist[c], leafdist[p]+1);
		}
		for (int to : adj[c]) {
			if (to == p) continue;
			predfs2(to, c);
		}
	}

	/**
	 * Builds the Euler tour: records {@code c} in {@link #order} on entry and on
	 * exit, stores both positions, and fills {@link #depth} for the children.
	 *
	 * @param c current node
	 * @param p parent of {@code c}, or -1 for the root
	 */
	public static void DFS(int c, int p) {
		sidx[c] = cnt;
		order[cnt++] = c;
		for (int to : adj[c]) {
			if (to != p) {
				depth[to] = depth[c] + 1;
				DFS(to, c);
			}
		}
		eidx[c] = cnt;
		order[cnt++] = c;
	}

	/**
	 * One contiguous slice {@code [L, R]} of the Euler tour. Its positions are kept
	 * sorted by decreasing diff together with prefix sums of their weights, so the
	 * weight of all positions whose effective diff is non-negative is a single
	 * binary search. The prefix sums are rebuilt after every re-sort, which is why
	 * a plain array is enough here.
	 */
	static class Block {
		/** Pending offset shared by every position of this block. */
		int lazy;
		int L, R, len;
		Node[] sortedarr; // positions sorted by decreasing diff
		int[] psum; // psum[i] = sum of val over sortedarr[0..i]

		/**
		 * @param lb first tour position of the block (inclusive)
		 * @param rb last tour position of the block (inclusive)
		 */
		public Block(int lb, int rb) {
			lazy = 0;
			L = lb; R = rb; len = R-L+1;
			sortedarr = new Node[len];
			for (int i = L; i <= R; i++) {
				sortedarr[i-L] = new Node(i, diff[order[i]], adj[order[i]].size());
			}
			psum = new int[len];
			resort();
		}

		/**
		 * Adds {@code delta} to every position of this block lying in {@code [s, e]}
		 * and restores the sorted order and the prefix sums.
		 */
		public void update(int s, int e, int delta) {
			for (Node node : sortedarr) {
				if (s <= node.idx && node.idx <= e) {
					node.diff += delta;
				}
			}
			resort();
		}

		/** Sorts the positions by decreasing diff and rebuilds the prefix sums. */
		private void resort() {
			Arrays.sort(sortedarr);
			int running = 0;
			for (int i = 0; i < len; i++) {
				running += sortedarr[i].val;
				psum[i] = running;
			}
		}

		/**
		 * @return the sum of {@code val} over the positions of this block whose
		 *         diff plus {@code lazy} plus {@code globalLazy} is {@code >= 0}
		 */
		public int query() {
			// The array is sorted by decreasing diff and the offsets are uniform,
			// so the qualifying positions form a prefix; find its length.
			int lo = 0, hi = len;
			while (lo < hi) {
				int mid = (lo + hi) >>> 1;
				if (sortedarr[mid].diff+globalLazy+lazy >= 0) {
					lo = mid + 1;
				}
				else {
					hi = mid;
				}
			}
			return lo == 0 ? 0 : psum[lo-1];
		}
	}

	/**
	 * Debug helper: prints every tour position as "node: effective diff", using
	 * the same effective value (diff + block lazy + globalLazy) that
	 * {@link Block#query()} tests.
	 */
	public static void debug() {
		for (int i = 0; i < numBlocks; i++) {
			for (int j = 0; j < blocks[i].len; j++) {
				Node node = blocks[i].sortedarr[j];
				out.printf("%d: %d\n", order[node.idx], node.diff+blocks[i].lazy+globalLazy);
			}
		}
	}

	/**
	 * Reads N followed by N-1 edges with 1-based endpoints and builds the
	 * 0-based adjacency lists.
	 *
	 * @throws IOException if reading the input fails
	 */
	public static void processInput() throws IOException {
		N = nint();
		// Generic array creation is not expressible in Java; the array only ever
		// holds ArrayList<Integer> instances, so the conversion is safe.
		@SuppressWarnings("unchecked")
		ArrayList<Integer>[] lists = new ArrayList[N];
		adj = lists;
		for (int i = 0; i < N; i++) adj[i] = new ArrayList<>();
		for (int i = 0; i < N-1; i++) {
			int a = nint()-1, b = nint()-1;
			adj[a].add(b); adj[b].add(a);
		}
	}

	/** One Euler-tour position: its index, its current diff and its fixed weight. */
	static class Node implements Comparable<Node> {
		int idx, diff, val;

		/**
		 * @param i tour position
		 * @param d initial diff of the node at this position
		 * @param deg degree of that node; the weight is {@code 2 - deg}
		 */
		public Node(int i, int d, int deg) {
			idx = i; diff = d; val = 2-deg;
		}

		/** Orders nodes by decreasing diff. */
		@Override
		public int compareTo(Node n) {
			// Integer.compare avoids the overflow a plain subtraction can produce.
			return Integer.compare(n.diff, diff);
		}

		@Override
		public String toString() {
			return String.format("(%d, %d, %d)", idx, diff, val);
		}
	}

	/**
	 * Reads the next token from the input and returns it as an int.
	 *
	 * @throws IOException if reading the input fails
	 */
	static int nint() throws IOException {
		in.nextToken();
		return (int) in.nval;
	}

	/**
	 * Opens {@code atlarge.in} for reading and {@code atlarge.out} for writing.
	 *
	 * @throws IOException if either file cannot be opened
	 */
	static void init_io() throws IOException {
		br = new BufferedReader(new FileReader("atlarge.in"));
		in = new StreamTokenizer(br);
		out = new PrintWriter(new BufferedWriter(new FileWriter("atlarge.out")));
	}
}

Currently, my code gets TLE on test case 3 and every later one, even though its complexity is O(N sqrt N log N). I also downloaded the test cases and found that my code gives WA on 4 of test cases 7-11.

Can someone help me figure out why my code is slow and why it gives wrong answers on the later test cases?

Also, I read the official solution for this problem and saw that they used a BIT inside each block. Why does it use BIT? Wouldn’t prefix sums also work?

Sorting after each update makes the algorithm O(sqrt N log N) per update instead of just O(sqrt N + log N) per update with a BIT. It’s hard to say exactly where the bug might be, but I would watch out for boundary errors (does the last block go past n?) and off-by-one errors.

Wait, isn’t it also O(sqrt N log N) per update with a BIT? Because the solution constructs the BIT by doing an update for each value between x and y (sqrt N values at max), and the update function is O(log N).

Also, I think O(log N sqrt N) per update should be fast enough because there are at most 2 blocks being updated each round, and there are N rounds, so it’s O(N sqrt N log N) overall.

Okay, I’ll check for boundary errors.

Oh, hm, you’re right, didn’t think about it carefully enough. Maybe sorting classes in Java is a bit slow? Theoretically with some work you can perform the update in O(sqrtN) instead of O(sqrtN logN) by merging instead of resorting the whole array.

Okay, I managed to fix the wrong answer problem (it turns out I forgot to check whether the root node is a leaf).

However, I’m still getting TLE on test case 3 and every later one, even after replacing the full re-sort with a merge step to lower the update complexity to O(sqrt N).

Are there any other optimizations that I can make?

I think it’s theoretically possible to remove a logn factor from the query() function by updating the answer per block by moving a pointer? But it’s non-trivial, maybe only the centroid decomp solution is viable in Java?

Can you elaborate on how to use a pointer to remove log N from query?

Not sure if it works, but maybe you could have pointers for where the binary search should end up and advance them.

I think the binary search is slower than a BIT because it requires branching (each branch misprediction is 10+ cycles). Same with sorting.

I would say there’s not much more to be learned from this particular approach, I would recommend learning the centroid decomp (which is a more general technique for tree problems) and then moving on.

Ah okay.

I managed to get it to pass by removing objects, doing a ton of optimizations, and increasing block size.

Thanks for your help!