My code for CSES Multiplication Table is exceeding the time limit, even though my solution seems to be exactly the same as the official one (which runs in $O(N \log N)$). Could someone please look it over to see what’s wrong?
#include <iostream>
#include <bits/stdc++.h>
#include <array>
#include <fstream>
#include <string>
#include <algorithm>
#include <cmath>
#include <sstream>
using namespace std;
using ll=long long;
ll n; // side length of the n x n multiplication table; assigned in main

// Checks whether x is >= the median of the n x n multiplication table.
//
// x: candidate value.
// Returns true when at least n*n/2+1 table entries are <= x.
// Cost: one pass over the n rows per call.
bool ge(ll x){
    ll T=0; // T = how many numbers in the grid are <= x
    // Row i holds i, 2i, ..., n*i, so it contributes min(n, x/i) entries <= x.
    for (ll i=1; i<=n; i++) T+=min(n, x/i);
    ll M=n*n/2+1; // 1-based rank of the median when n*n is odd
    return T>=M;
}

// Binary-searches for the first x in [lo, hi] where ge(x) is true.
//
// Relies on ge being monotone (false ... false true ... true) over the range.
// Returns that first x, or hi + 1 when ge is false on the whole range.
ll firstTrue(ll lo, ll hi) {
    hi++; // search the half-open range [lo, hi); hi itself means "not found"
    while (lo < hi) {
        // mid must be 64-bit: hi can be as large as n*n, which does not fit
        // in int. A truncated midpoint can land outside [lo, hi) and keep
        // the loop from ever converging.
        ll mid = lo + (hi - lo) / 2;
        if (ge(mid)) {
            hi = mid;       // mid qualifies; the answer is mid or to its left
        } else {
            lo = mid + 1;   // mid is too small; the answer is to its right
        }
    }
    return lo;
}
int main() {
cin>>n;
cout<<firstTrue(1, n*n);
}
I tried in my IDE (Replit), but the answer just doesn’t show up there. What I’m particularly confused about is that if I just take out my functions and put them into the main loop, my code runs on time; however, when I do use functions, my code runs out of time.
So for instance, this runs out of time:
#include <bits/stdc++.h>
#include <array>
#include <fstream>
#include <string>
#include <algorithm>
#include <cmath>
#include <sstream>
using namespace std;
using ll = long long;
ll n; // side length of the multiplication table; assigned in main

// Tells whether x has reached the median of the n x n multiplication table,
// i.e. whether at least n*n/2+1 entries are <= x.
// Each call walks all n rows once, so it costs O(n).
bool ge(ll x) {
    ll atMost = 0; // running count of table entries that are <= x
    ll row = 1;
    while (row <= n) {
        // Row `row` holds row, 2*row, ..., n*row.
        atMost += min(n, x / row);
        ++row;
    }
    const ll medianRank = n * n / 2 + 1;
    return atMost >= medianRank;
}
int main() {
cin>>n;
ll lo=1, hi=n*n;
while (lo < hi) {
int mid = lo + (hi - lo) / 2;
if (ge(mid)) {
hi = mid;
} else {
lo = mid + 1;
}
}
cout<<lo;
}
But this runs on time, even though I just put the function into the main loop:
#include <iostream>
#include <bits/stdc++.h>
#include <array>
#include <fstream>
#include <string>
#include <algorithm>
#include <cmath>
#include <sstream>
using namespace std;
using ll = long long;
ll n; // side length of the multiplication table; read in main

// Reads n and prints the smallest value x such that at least (n*n+1)/2
// entries of the n x n multiplication table are <= x.
int main() {
    cin >> n;

    // Rank the answer has to reach; n is fixed after the read above.
    const ll targetRank = (n * n + 1) / 2;

    // Counts the table entries that are <= value, one row at a time.
    auto countAtMost = [](ll value) {
        ll total = 0;
        ll row = 1;
        while (row <= n) {
            total += min(n, value / row);
            ++row;
        }
        return total;
    };

    // Binary search over the half-open range [low, high).
    ll low = 1;
    ll high = n * n + 1;
    while (low < high) {
        const ll middle = low + (high - low) / 2;
        if (countAtMost(middle) < targetRank) {
            low = middle + 1;
        } else {
            high = middle;
        }
    }
    cout << low;
}