Hello,
Could someone please help me understand how the BIT (Fenwick tree) solution to this problem works internally? I know it's a simple solution, but I don't fully understand the template code. I would really appreciate a brief walkthrough of the code.
Thanks in advance
// 64-bit integer shorthand (used for the element values read in main).
using ll = long long;
// Number of addressable 1-based positions in the global Fenwick tree.
constexpr ll MX = 2000000;
constexpr int bits(int x) { // assert(x >= 0); // make C++11 compatible until USACO updates ...
return x == 0 ? 0 : 31-__builtin_clz(x); } // floor(log2(x))
// Fenwick tree (binary indexed tree) with point updates and range-sum queries
// in any number of dimensions: BIT<T, N1, N2, ...> covers 1-based coordinates
// [1..N1] x [1..N2] x ... . Each dimension adds an O(log N) factor per call.
//
// Base case (no dimensions left): a single accumulated value. The recursive
// specialization below stores N+1 of these (or of lower-dimensional trees).
template <class T, int ...Ns> struct BIT {
    T val = 0;
    // Add v to the stored value.
    void upd(T v) { val += v; }
    // Return the stored value.
    T query() const { return val; }
};

// Recursive case: the first dimension has size N. Following the usual Fenwick
// layout, bit[i] aggregates first-dimension positions (i - lowbit(i), i], where
// lowbit(i) = i & -i; each element is itself a tree over the remaining dimensions.
template <class T, int N, int... Ns> struct BIT<T, N, Ns...> {
    BIT<T, Ns...> bit[N + 1];  // index 0 is unused; positions are 1..N

    // Point update: add the value (the last argument) at coordinates
    // (pos, args...). pos must be >= 1; positions above N are silently ignored.
    template <typename... Args> void upd(int pos, Args... args) {
        assert(pos > 0);
        for (; pos <= N; pos += pos & -pos) bit[pos].upd(args...);
    }

    // Prefix sum over first-dimension positions 1..r, with the remaining
    // dimensions restricted to the (l, r) ranges passed in args.
    // r is clamped to [0, N]: r <= 0 yields 0 and r > N sums the whole
    // dimension, instead of indexing bit[] out of bounds.
    template <typename... Args> T sum(int r, Args... args) const {
        if (r > N) r = N;
        T res = 0;
        for (; r > 0; r -= r & -r) res += bit[r].query(args...);
        return res;
    }

    // Sum over first-dimension positions l..r (inclusive); args holds the
    // (l, r) pairs of the remaining dimensions. An empty range (l > r) is 0.
    template <typename... Args> T query(int l, int r, Args... args) const {
        if (l > r) return T(0);
        return sum(r, args...) - sum(l - 1, args...);
    }
};
// Order-statistic search on a 1-D Fenwick tree: returns the smallest 1-based
// position k whose prefix sum 1..k is >= des. The descent is only valid when
// prefix sums are non-decreasing, i.e. every point value is non-negative.
// Asserts des > 0, and asserts that the answer lies within 1..N (this fails
// when des exceeds the total sum).
template<class T, int N> int get_kth(const BIT<T,N>& bit, T des) {
    assert(des > 0);
    // pos: largest position reached so far whose prefix sum is still < des.
    int pos = 0;
    int step = 1 << bits(N);
    while (step > 0) {
        const int next = pos + step;
        // bit.bit[next] covers (pos, next], so taking it keeps the prefix below des.
        if (next <= N && bit.bit[next].val < des) {
            pos = next;
            des -= bit.bit[pos].val;
        }
        step >>= 1;
    }
    assert(pos < N);
    return pos + 1;
}
BIT<int, MX> bit;
int main()
{
int n;
cin >> n;
vector<ll> a(n);
for(auto& it : a) cin >> it;
vector<int> p(n);
for(auto& it : p) cin >> it;
for (int i = 1; i<= n; i++){
bit.upd(i, 1);
}
for (auto i : p){
int index = get_kth(bit, i);
cout << a[index-1] << " ";
bit.upd(index, -1);
}
return 0;
}