Problem: 2018 January Platinum - Cow At Large
Code:
import java.io.*;
import java.util.*;
public class PlatinumAtLarge {
static int N;
static ArrayList<Integer>[] adj;
static int[] order;
static int[] sidx, eidx;
static int[] depth;
static int[] leafdist;
static int[] diff;
static int[] ans;
static Block[] blocks;
static int sqrt2N;
static int numBlocks;
public static void main(String[] args) throws IOException {
init_io();
processInput();
main();
out.close();
}
public static void main() {
sidx = new int[N]; eidx = new int[N];
order = new int[2*N]; depth = new int[N];
DFS(0, -1);
// Arrays.fill(sidx, -1);
// Arrays.fill(eidx, -1);
// for (int i = 0; i < 2*N; i++) {
// if (sidx[order[i]] == -1) {
// sidx[order[i]] = i;
// }
// eidx[order[i]] = i;
// }
diff = new int[N];
leafdist = new int[N];
// calculates leafdist
predfs(0, -1);
// accounts for leafdist from parent
predfs2(0, -1);
for (int i = 0; i < N; i++) {
diff[i] = depth[i]-leafdist[i];
}
// divide euler tour into blocks
sqrt2N = (int)Math.sqrt(2*N);
numBlocks = (2*N + sqrt2N - 1) / sqrt2N;
blocks = new Block[numBlocks];
for (int i = 0; i < numBlocks-1; i++) {
blocks[i] = new Block(i*sqrt2N, (i+1)*sqrt2N-1);
}
blocks[numBlocks-1] = new Block((numBlocks-1)*sqrt2N, 2*N-1);
ans = new int[N];
// out.println("Initial:");
// debug();
// calculate answer
dfs(0, -1);
for (int i = 0; i < N; i++) {
out.println(adj[i].size() == 1 ? 1 : ans[i]/2);
}
}
static int globalLazy = 0;
public static void dfs(int c, int p) {
if (p != -1) {
globalLazy++;
int s = sidx[c], e = eidx[c];
for (int i = 0; i < numBlocks; i++) {
if (blocks[i].L > e) break;
if (blocks[i].R < s) continue;
if (s <= blocks[i].L && blocks[i].R <= e) {
blocks[i].lazy -= 2;
}
else {
// must be partial coverage
blocks[i].update(s, e, -2);
}
}
}
// out.println("**************");
// out.printf("Pre: %d\n", c);
// debug();
int cans = 0;
for (int i = 0; i < numBlocks; i++) {
cans += blocks[i].query();
}
ans[c] = cans;
for (int to : adj[c]) {
if (to == p) continue;
dfs(to, c);
}
if (p != -1) {
// undo changes
globalLazy--;
int s = sidx[c], e = eidx[c];
for (int i = 0; i < numBlocks; i++) {
if (blocks[i].L > e) break;
if (blocks[i].R < s) continue;
if (s <= blocks[i].L && blocks[i].R <= e) {
blocks[i].lazy += 2;
}
else {
// must be partial coverage
blocks[i].update(s, e, 2);
}
}
}
// out.println("**************");
// out.printf("Post: %d\n", c);
// debug();
}
public static int predfs(int c, int p) {
int best = Integer.MAX_VALUE;
for (int to : adj[c]) {
if (to == p) continue;
best = Math.min(best, predfs(to, c)+1);
}
if (best == Integer.MAX_VALUE) best = 0;
leafdist[c] = best;
return best;
}
public static void predfs2(int c, int p) {
if (p != -1) {
leafdist[c] = Math.min(leafdist[c], leafdist[p]+1);
}
for (int to : adj[c]) {
if (to == p) continue;
predfs2(to, c);
}
}
static int cnt = 0;
public static void DFS(int c, int p) {
sidx[c] = cnt;
order[cnt++] = c;
for (int to : adj[c]) {
if (to != p) {
depth[to] = depth[c] + 1;
DFS(to, c);
}
}
eidx[c] = cnt;
order[cnt++] = c;
}
// public static void DFS() {
// depth = new int[N];
// order = new int[2*N];
// LinkedList<Integer> stack = new LinkedList<>();
// stack.push(0);
// int[] adjidx = new int[N];
// int[] par = new int[N];
// Arrays.fill(par, -1);
// int cnt = 0;
// while (stack.size() > 0) {
// int c = stack.peek();
// if (adjidx[c] == 0) {
// order[cnt++] = c;
// }
// if (adj[c].size() != adjidx[c]) {
// if (adj[c].get(adjidx[c]) == par[c])
// adjidx[c]++;
// if (adj[c].size() != adjidx[c]) {
// int next = adj[c].get(adjidx[c]);
// stack.push(next);
// par[next] = c;
// depth[next] = depth[c] + 1;
// adjidx[c]++;
// continue;
// }
// }
// order[cnt++] = c;
// stack.pop();
// }
// }
static StreamTokenizer in;
static PrintWriter out;
static BufferedReader br;
// I'm not sure why the official sol uses BIT
// I think prefix sum should suffice?
static class Block {
int lazy;
int L, R, len;
Node[] sortedarr; // sorted array of differences
int[] psum; // prefix sum
public Block(int lb, int rb) {
lazy = 0;
L = lb; R = rb; len = R-L+1;
sortedarr = new Node[R-L+1];
for (int i = L; i <= R; i++) {
sortedarr[i-L] = new Node(i, diff[order[i]], adj[order[i]].size());
}
Arrays.sort(sortedarr);
psum = new int[len];
psum[0] = sortedarr[0].val;
for (int i = 1; i < len; i++) {
psum[i] = sortedarr[i].val + psum[i-1];
}
}
public void update(int s, int e, int delta) {
for (int i = 0; i < len; i++) {
if (s <= sortedarr[i].idx && sortedarr[i].idx <= e) {
sortedarr[i].diff += delta;
}
}
Arrays.sort(sortedarr);
psum[0] = sortedarr[0].val;
for (int i = 1; i < len; i++) {
psum[i] = sortedarr[i].val + psum[i-1];
}
}
public int query() {
int min = -1; int max = len;
while(max-min > 1) {
int mid = (min + max)/2;
if (sortedarr[mid].diff+globalLazy+lazy >= 0) {
min = mid;
}
else {
max = mid;
}
}
if (min == -1) {
return 0;
}
else {
return psum[min];
}
}
}
public static void debug() {
for (int i = 0; i < numBlocks; i++) {
for (int j = 0; j < blocks[i].len; j++) {
out.printf("%d: %d\n", order[blocks[i].sortedarr[j].idx], blocks[i].sortedarr[j].diff+blocks[i].lazy);
}
}
}
public static void processInput() throws IOException {
N = nint();
adj = new ArrayList[N];
for (int i = 0; i < N; i++) adj[i] = new ArrayList<>();
for (int i = 0; i < N-1; i++) {
int a = nint()-1, b = nint()-1;
adj[a].add(b); adj[b].add(a);
}
}
static class Node implements Comparable<Node> {
int idx, diff, val;
public Node(int i, int d, int deg) {
idx = i; diff = d; val = 2-deg;
}
public int compareTo(Node n) {
return n.diff - diff;
}
public String toString() {
return String.format("(%d, %d, %d)", idx, diff, val);
}
}
static int nint() throws IOException {
in.nextToken();
return (int) in.nval;
}
static void init_io() throws IOException {
// br = new BufferedReader(new InputStreamReader(System.in));
// br = new BufferedReader(new FileReader("./resource/atlarge.in"));
br = new BufferedReader(new FileReader("atlarge.in"));
in = new StreamTokenizer(br);
// out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
// out = new PrintWriter(new BufferedWriter(new FileWriter("./resource/atlarge.out")));
out = new PrintWriter(new BufferedWriter(new FileWriter("atlarge.out")));
}
}
Currently, my code TLEs on test case 3 and after, even though the complexity is N sqrt N log N. Also, I downloaded the test cases and found out that my code gives WA on 4 test cases out of test cases 7-11.
Can someone help me figure out why my code is slow and why it gives wrong answers on the later test cases?
Also, I read the official solution for this problem and saw that they used a BIT inside each block. Why does it use BIT? Wouldn’t prefix sums also work?