I’m trying to solve the “Squirrel Cities” problem in the Link Cut Tree module, but I cannot pass all the test cases.
I thought the problem was straightforward. Basically, we can sort the durabilities from largest to smallest. For each query, add all edges whose durability is >= the required durability and find the minimum time required to build an MST, i.e. the largest edge time in a minimum spanning tree ordered by time. Then any query whose allowed time is >= this minimum time is YES; otherwise it is NO.
To find this minimum time, each time I add a qualifying edge (one that satisfies the durability requirement), say between vertices a and b:
- if a and b are not connected, obviously add this edge
- if a and b are already connected, check whether some edge on the path from a to b has a time greater than the new edge's time. If so, replace it with the new edge. This is done with the LCT by
- making a the root and accessing b, which forms the splay tree for the path a -> b
- finding the edge with the maximum time on this splay tree, very similar to how the number of nodes in a splay tree is calculated
This is how I calculate the maximum edge on the path.
Updated calc() function in the LCT structure:
// Recompute this splay node's aggregates (splay size, tree size, value sum and
// path maximum) from its children, after pushing any pending flip down to them.
// NOTE(review): the pathMax part below is what breaks the algorithm. In the
// splay tree of a path, c[0]/c[1] are the ROOTS of the parts of the path that
// come before/after this vertex, not its neighbours on the path. So
// eW[mp(vtx, c[i]->vtx)] may name a pair that is not a tree edge at all, while
// the real path edges (vtx, last vertex of c[0]'s part) and
// (vtx, first vertex of c[1]'s part) are never examined.
// NOTE(review): eW[...] is map::operator[], which inserts a 0 entry for every
// missing pair it is asked about, so the map also fills up with bogus keys.
// Suggested fix: give every road its own LCT node whose value is the road time,
// and take the plain maximum of node values over the splay subtree.
void calc() { // recalc vals
for(int i=0; i<2; i++) if (c[i]) c[i]->prop();
sz = 1+getSz(c[0])+getSz(c[1]);
sub = 1+getSub(c[0])+getSub(c[1])+vsub;
stsum = val + getStSum(c[0])+getStSum(c[1]); // update sum of splay tree value; same approach as sz
// find max edge
// NOTE(review): when the node has no children, pathMaxLoc keeps whatever stale value it had.
pathMax = 0;
if(c[0]) {
pathMaxLoc = mp(vtx, c[0]->vtx);
pathMax = eW[pathMaxLoc];
if(getPathMax(c[0]) > pathMax) {
pathMaxLoc = c[0]->pathMaxLoc;
pathMax = c[0]->pathMax;
}
}
if(c[1]) {
if(getPathMax(c[1]) > pathMax) {
pathMaxLoc = c[1]->pathMaxLoc;
pathMax = c[1]->pathMax;
}
if(eW[mp(vtx, c[1]->vtx)] > pathMax) {
pathMaxLoc = mp(vtx, c[1]->vtx);
pathMax = eW[pathMaxLoc];
}
}
}
And here is the logic that performs the replacement:
Replacing an edge on the cycle
// Try to insert road a-b with time nt (x.ff holds the same value in the full program).
if(connected(LCT[a], LCT[b])) {
/// remove heavier edge in the cycle
// Re-root at lca(a, b), then read the max edge on c->a and on c->b separately.
sn c = lca(LCT[a], LCT[b]);
c->makeRoot();
// NOTE(review): pathMax/pathMaxLoc come from calc(), which compares a vertex with its
// splay children instead of its path neighbours, so w1/p1 and w2/p2 can miss the true
// maximum edge of the cycle.
LCT[a]->access(); int w1 = LCT[a]->pathMax; pii p1 = LCT[a]->pathMaxLoc;
LCT[b]->access(); int w2 = LCT[b]->pathMax; pii p2 = LCT[b]->pathMaxLoc;
if(w1 > nt && w1 >= w2) {
// NOTE(review): multiset::erase(value) removes EVERY element equal to w1, not just one;
// usedT.erase(usedT.find(w1)) removes a single copy.
eW.erase(mp(p1.ff, p1.ss)); eW.erase(mp(p1.ss,p1.ff)); cut(LCT[p1.ff], LCT[p1.ss]); usedT.erase(w1);
eW[mp(a,b)] = nt; eW[mp(b,a)] = nt; link(LCT[a], LCT[b]); usedT.insert(nt);
maxT = *(usedT.rbegin());
vtx = a;
} else if(w2 > x.ff) {
// NOTE(review): same multiset::erase(value) problem as above for w2.
eW.erase(mp(p2.ff, p2.ss)); eW.erase(mp(p2.ss,p2.ff)); cut(LCT[p2.ff], LCT[p2.ss]); usedT.erase(w2);
eW[mp(a,b)] = nt; eW[mp(b,a)] = nt; link(LCT[a], LCT[b]); usedT.insert(x.ff);
maxT = *(usedT.rbegin());
vtx = a;
} else {
// a-b connection is useless
}
} else {
// Different trees: always link. isMST is set once a's tree holds all N vertices (sub counts tree nodes).
maxT = max(x.ff, maxT);
eW[mp(a,b)] = nt; eW[mp(b,a)] = nt; link(LCT[a], LCT[b]); usedT.insert(x.ff);
vtx = a;
if(getSub(LCT[vtx]->getRoot()) == N) isMST = true;
}
However, I was not able to pass all tests (WA on some cases). I assume my understanding of splay trees and/or LCTs is not completely correct. Could anybody help by pointing out what’s wrong here?
Here is the complete code:
Summary
/*
ID: USACO_template
LANG: C++
PROG: https://dmoj.ca/problem/wac4p7#
*/
#include <iostream> //cin , cout
#include <fstream> //fin, fout
#include <stdio.h> // scanf , pringf
#include <cstdio>
#include <algorithm> // sort , stuff
#include <stack> // stacks
#include <queue> // queues
#include <map>
#include <string>
#include <string.h>
#include <set>
using namespace std;
// ---- short type aliases -----------------------------------------------------
using pii  = std::pair<int, int>;                /// (int, int) pair, e.g. an edge's endpoints
using vi   = std::vector<int>;                   /// adjacency list without weights
using vii  = std::vector<pii>;                   /// adjacency list with weights
using vpip = std::vector<std::pair<int, pii>>;   /// weighted edge list: (weight, (u, v))
using ll   = long long;
// ---- contest shorthand macros ----------------------------------------------
#define mp make_pair
#define ff first
#define ss second
#define pb push_back
#define sz(x) (int)(x).size()
// ---- constants ---------------------------------------------------------------
const int MOD = 1000000007;            // alternative modulus: 998244353
const int MX = 200005;
const ll INF = 1000000000000000000LL;  // 1e18
const int MAXV = 100007;               // capacity for vertex-indexed arrays
const int MAXE = 300007;               // capacity for edge-indexed arrays
// ---- global state --------------------------------------------------------------
bool debug;                   // verbose tracing switch (zero-initialised: off)
int N, M, Q;                  // cities, roads, queries
std::map<pii, int> eW;        // (u, v) -> weight of the tree edge between u and v
/// Link Cut Tree
/**
* Description: Link-Cut Tree. Given a function $f(1\ldots N)\to 1\ldots N,$
* evaluates $f^b(a)$ for any $a,b.$ \texttt{sz} is for path queries;
* \texttt{sub}, \texttt{vsub} are for subtree queries. \texttt{x->access()}
* brings \texttt{x} to the top and propagates it; its left subtree will be
* the path from \texttt{x} to the root and its right subtree will be empty.
* Then \texttt{sub} will be the number of nodes in the connected component
* of \texttt{x} and \texttt{vsub} will be the number of nodes under \texttt{x}.
* Use \texttt{makeRoot} for arbitrary path queries.
* Time: O(\log N)
* Usage: FOR(i,1,N+1)LCT[i]=new snode(i,1); link(LCT[1],LCT[2],1);
* Source: Dhruv Rohatgi, Eric Zhang
* https://sites.google.com/site/kc97ble/container/splay-tree/splaytree-cpp-3
* https://codeforces.com/blog/entry/67637
* https://codeforces.com/blog/entry/80383
* Verification: (see README for links)
* ekzhang Balanced Tokens
* Dynamic Tree Test (Easy)
* https://probgate.org/viewproblem.php?pid=578 (The Applicant)
*/
#include <assert.h> /* assert */
typedef struct snode* sn;
struct snode { //////// VARIABLES
sn p, c[2]; // parent, children
sn extra; // extra cycle node for "The Applicant"
bool flip = 0; // subtree flipped or not
int vtx;
ll val, sz; /// value in node, # nodes in current splay tree
int sub, vsub = 0; /// # of nodes in connected tree including itself; vsub stores sum of virtual children
/// where the "virtual subtrees" refers to the subtrees except the one in the Splay.
/// https://codeforces.com/blog/entry/67637
ll stsum; /// sum of all val in the splay tree; IF call access(), then it is from original root to this node
int pathMax; /// max weight of an edge along the splay tree path
pii pathMaxLoc;
snode(int _vtx, int _val) : vtx(_vtx), val(_val) {
p = c[0] = c[1] = extra = NULL; calc(); }
friend int getSz(sn x) { return x?x->sz:0; }
friend int getSub(sn x) { return x?x->sub:0; }
friend ll getStSum(sn x) { return x?x->stsum:0; }
friend ll getPathMax(sn x) { return x?x->pathMax:0; }
void prop() { // lazy prop
if (!flip) return;
swap(c[0],c[1]); flip = 0;
for(int i=0; i<2; i++) if (c[i]) c[i]->flip ^= 1;
}
void calc() { // recalc vals
for(int i=0; i<2; i++) if (c[i]) c[i]->prop();
sz = 1+getSz(c[0])+getSz(c[1]);
sub = 1+getSub(c[0])+getSub(c[1])+vsub;
stsum = val + getStSum(c[0])+getStSum(c[1]); // update sum of splay tree value; same approach as sz
// find max edge
pathMax = 0;
if(c[0]) {
pathMaxLoc = mp(vtx, c[0]->vtx);
pathMax = eW[pathMaxLoc];
if(getPathMax(c[0]) > pathMax) {
pathMaxLoc = c[0]->pathMaxLoc;
pathMax = c[0]->pathMax;
}
}
if(c[1]) {
if(getPathMax(c[1]) > pathMax) {
pathMaxLoc = c[1]->pathMaxLoc;
pathMax = c[1]->pathMax;
}
if(eW[mp(vtx, c[1]->vtx)] > pathMax) {
pathMaxLoc = mp(vtx, c[1]->vtx);
pathMax = eW[pathMaxLoc];
}
}
}
//////// SPLAY TREE OPERATIONS
int dir() {
if (!p) return -2;
for(int i=0; i<2; i++) if (p->c[i] == this) return i;
return -1; // p is path-parent pointer
} // -> not in current splay tree
// test if root of current splay tree
bool isRoot() { return dir() < 0; }
friend void setLink(sn x, sn y, int d) { /// x is parent in the original tree
if (y) y->p = x;
if (d >= 0) x->c[d] = y; }
void rot() { // assume p and p->p propagated
assert(!isRoot()); int x = dir(); sn pa = p;
setLink(pa->p, this, pa->dir());
setLink(pa, c[x^1], x); setLink(this, pa, x^1);
pa->calc();
}
void splay() { // bring this node to the root of splay tree
while (!isRoot() && !p->isRoot()) {
p->p->prop(), p->prop(), prop();
dir() == p->dir() ? p->rot() : rot();
rot();
}
if (!isRoot()) p->prop(), prop(), rot();
prop(); calc();
}
sn fbo(int b) { // find by order
prop(); int z = getSz(c[0]); // of splay tree
if (b == z) { splay(); return this; }
return b < z ? c[0]->fbo(b) : c[1] -> fbo(b-z-1);
}
//////// BASE OPERATIONS
/// make this node the "access node", i.e. the path from original root to this node is one splay tree
/// bring this to top of splay tree (not impacting the original representation tree)
void access() {
for (sn v = this, pre = NULL; v; v = v->p) {
v->splay(); // now switch virtual children
if (pre) v->vsub -= pre->sub;
if (v->c[1]) v->vsub += v->c[1]->sub;
v->c[1] = pre; v->calc(); pre = v;
}
splay(); assert(!c[1]); // right subtree is empty
}
void makeRoot() { // of the splay tree
access(); flip ^= 1; access(); assert(!c[0] && !c[1]); }
//////// QUERIES
friend sn lca(sn x, sn y) {
if (x == y) return x;
x->access(), y->access(); if (!x->p) return NULL;
x->splay(); return x->p?:x; // y was below x in latter case
} // access at y did not affect x -> not connected
friend bool connected(sn x, sn y) { return lca(x,y); }
// # nodes above; distance to root in original tree
int distRoot() { access(); return getSz(c[0]); }
sn getRoot() { /// get root of LCT component in the original tree
access(); sn a = this;
while (a->c[0]) a = a->c[0], a->prop();
a->access(); return a;
}
sn getPar(int b) { // get b-th parent on path to root
access(); b = getSz(c[0])-b; assert(b >= 0);
return fbo(b);
} // can also get min, max on path to root, etc
//////// MODIFICATIONS
void setVal(int v) { access(); val = v; calc(); }
void addVal(int v) { access(); val += v; calc(); }
friend void link(sn x, sn y, bool force = 1) {
assert(!connected(x,y));
if (force) y->makeRoot(); /// make x par of y; x -> y
else { y->access(); assert(!y->c[0]); }
x->access(); setLink(y,x,0); y->calc();
}
friend void cut(sn y) { // cut y from its parent
y->access(); assert(y->c[0]);
y->c[0]->p = NULL; y->c[0] = NULL; y->calc(); }
friend void cut(sn x, sn y) { // if x, y adj in tree
x->makeRoot(); y->access();
assert(y->c[0] == x && !x->c[0] && !x->c[1]); cut(y); }
};
sn LCT[MAXV];
struct DT {
int durability , time;
int idx;
bool operator<(DT other) const {
if(durability != other.durability ) return durability > other.durability ;
return time < other.time;
}
};
vector<DT> eDT, qDT;
pii e[MAXE];
map<int, int> eDlist, qDlist;
map<int, vii> eTlist, qTlist;
int ans[MAXE];
map<pair<int, pii>, int> alleW;
int main() {
debug = false;
ios_base::sync_with_stdio(false); cin.tie(0);
cin >> N >> M >> Q;
for(int i=1; i<=N; i++) {
int a, b, d, t; cin >> a >> b >> d >> t;
e[i] = mp(a, b);
auto x = mp(d, mp(a,b));
if(alleW.count(x)==0) {
alleW[x] = t;
x = mp(d,mp(b,a)); alleW[x] = t;
} else {
alleW[x] = min(t, alleW[x]);
x = mp(d,mp(b,a)); alleW[x] = min(t, alleW[x]);
}
eDT.pb({d, t, i});
eDlist[d]=1;
}
for(int i=1; i<=Q;i++) {
int d, t; cin >> d >> t;
qDT.pb({d, t, i});
qDlist[d]=1;
}
//sort(eDT.begin(), eDT.end());
//sort(qDT.begin(), qDT.end());
for(auto x : eDT) eTlist[x.durability].pb(mp(x.time, x.idx));
for(auto x : qDT) qTlist[x.durability].pb(mp(x.time, x.idx));
for(int i=1; i<=N; i++) LCT[i] = new snode(i, 1);
int maxT = 0;
multiset<int> usedT; usedT.clear();
int vtx = 0;
bool isMST = false;
for(auto qit = qTlist.rbegin(); qit!=qTlist.rend(); qit++) {
int qD = qit->ff;
if(debug) cout << "work on q durability " << qD << endl;
/// collect candidates into a queue
set<pii> eToAdd; eToAdd.clear();
while(!eTlist.empty() && eTlist.rbegin()->ff >= qD) {
if(debug) cout << " for edge durability of " << eTlist.rbegin()->ff << endl;
for(auto x : eTlist.rbegin()->ss) {
eToAdd.insert(x);
if(debug) cout << " ... collect t=" << x.ff << " from edge # " << x.ss << endl;
}
eTlist.erase(eTlist.rbegin()->ff);
}
/// add all qualified edge to the MST; if already connected, remove the larger time piece in the circle.
for(auto x : eToAdd) {
int nt = x.ff;
if(isMST && nt > maxT) continue;
int a = e[x.ss].ff, b = e[x.ss].ss;
if(debug) cout << " can we use edge " << a << " - " << b << endl;
if(connected(LCT[a], LCT[b])) {
/// remove heavier edge in the cycle
sn c = lca(LCT[a], LCT[b]);
c->makeRoot();
LCT[a]->access(); int w1 = LCT[a]->pathMax; pii p1 = LCT[a]->pathMaxLoc;
LCT[b]->access(); int w2 = LCT[b]->pathMax; pii p2 = LCT[b]->pathMaxLoc;
if(w1 > nt && w1 >= w2) {
eW.erase(mp(p1.ff, p1.ss)); eW.erase(mp(p1.ss,p1.ff)); cut(LCT[p1.ff], LCT[p1.ss]); usedT.erase(w1);
eW[mp(a,b)] = nt; eW[mp(b,a)] = nt; link(LCT[a], LCT[b]); usedT.insert(nt);
maxT = *(usedT.rbegin());
vtx = a;
} else if(w2 > x.ff) {
eW.erase(mp(p2.ff, p2.ss)); eW.erase(mp(p2.ss,p2.ff)); cut(LCT[p2.ff], LCT[p2.ss]); usedT.erase(w2);
eW[mp(a,b)] = nt; eW[mp(b,a)] = nt; link(LCT[a], LCT[b]); usedT.insert(x.ff);
maxT = *(usedT.rbegin());
vtx = a;
} else {
// a-b connection is useless
}
} else {
maxT = max(x.ff, maxT);
eW[mp(a,b)] = nt; eW[mp(b,a)] = nt; link(LCT[a], LCT[b]); usedT.insert(x.ff);
vtx = a;
if(getSub(LCT[vtx]->getRoot()) == N) isMST = true;
}
}
/// now find min T needed to make a MST. Aswer q
if(vtx == 0 || getSub(LCT[vtx]->getRoot()) != N) {
// not fully connected
isMST = false;
for(auto x : qit->ss) {
int j = x.ss;
ans[j] = 0;
}
} else {
isMST = true;
for(auto x : qit->ss) {
int j = x.ss;
ans[j] = (maxT <= x.ff)? 1 : 0;
}
}
}
for(int i=1; i<=Q; i++) {
cout << (ans[i] ? "YES" : "NO" ) << endl;
}
if(debug) cout << endl << "EOL" << endl;
}