Sumdiv from Kilonova

Problem Info

Sumdiv

Question

My code gets TLE (Time Limit Exceeded) only on test cases 9, 13, and 14 (out of 15), and I can’t figure out why.

What I’ve Tried

I’ve tried everything from stress-testing to comparing my code with the usaco.guide editorial. I’m using the same algorithms, so I don’t know what’s wrong. I’m pretty sure I’m applying `% MOD` too many times. Switching from Euler’s theorem to the extended Euclidean algorithm for computing modular inverses didn’t help.
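For reference, here is a minimal sketch of the two modular-inverse methods mentioned above. It assumes the prime modulus 1'000'000'007 from the submission. Because the modulus is prime, Euler’s theorem reduces to Fermat’s little theorem, so n^(MOD−2) mod MOD is the inverse of n whenever MOD does not divide n. The extended Euclidean version should return the same value. The function names `pow_mod`, `inverse_fermat`, and `inverse_euclid` are illustrative and do not come from the editorial.

```cpp
#include <cassert>

const long long MOD = 1'000'000'007;  // prime modulus from the submission

// base^power (mod MOD) by binary exponentiation.
long long pow_mod(long long base, long long power) {
    base %= MOD;
    if (base < 0) base += MOD;
    long long result = 1;
    while (power > 0) {
        if (power & 1) result = result * base % MOD;
        base = base * base % MOD;
        power >>= 1;
    }
    return result;
}

// Euler's theorem with phi(MOD) = MOD - 1 (i.e. Fermat's little theorem):
// n^(MOD - 2) * n ≡ 1 (mod MOD) when MOD does not divide n.
long long inverse_fermat(long long n) {
    return pow_mod(n, MOD - 2);
}

// Iterative extended Euclidean algorithm on (n mod MOD, MOD).
// Only the coefficient of n is tracked, since that is all the inverse needs.
long long inverse_euclid(long long n) {
    long long old_r = (n % MOD + MOD) % MOD, r = MOD;
    long long old_s = 1, s = 0;
    while (r != 0) {
        long long q = old_r / r;
        long long next_r = old_r - q * r;
        old_r = r;
        r = next_r;
        long long next_s = old_s - q * s;
        old_s = s;
        s = next_s;
    }
    // Now old_r == gcd(n, MOD) (1 if MOD does not divide n) and old_s * n ≡ old_r.
    return (old_s % MOD + MOD) % MOD;
}
```

Both functions take O(log MOD) steps, roughly 30 iterations here. Swapping one for the other is therefore not expected to change the running time noticeably.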

My Work

Here’s my submission.

```cpp
#include <bits/stdc++.h>
using namespace std;

#define MOD 1'000'000'007

// Binary exponentiation: returns base^power (mod MOD).
long long pow_modulo(long long base, long long power) {
    base %= MOD;

    // NOTE: this returns 1 even when power >= 1, although 0^power ≡ 0 in that case.
    if (base == 0) {
        return 1;
    }

    long long result = 1;
    long long multiplier = base;

    for (int bit = 0; bit <= log2(power); bit++) {
        if (power & (1LL << bit)) {
            result *= multiplier;
            result %= MOD;
        }

        multiplier *= multiplier;
        multiplier %= MOD;
    }

    return result;
}

// Returns (x, y) such that a * x + b * y = gcd(a, b).
pair<long long, long long> extended_euclidean_algorithm(long long a, long long b) {
    if (b == 0) {
        return make_pair(1, 0);
    }

    auto [x, y] = extended_euclidean_algorithm(b, a % b);
    return make_pair(y, x - a / b * y);
}

// Modular inverse of n; only meaningful when gcd(n, MOD) == 1.
long long inverse(long long n) {
    auto [x, y] = extended_euclidean_algorithm(n, MOD);
    return (x % MOD + MOD) % MOD;
}

// Sum of the divisors of base^exponent (mod MOD), built from the prime factorization of base.
long long solve(long long base, long long exponent) {
    long long total = 1;
    // NOTE: `prime` is an int, so `prime * prime` below overflows once prime exceeds 46340.
    int prime = 2;

    while (prime * prime <= base) {
        long long power = 0;

        while (base % prime == 0) {
            base /= prime;
            power++;
        }

        if (power == 0) {
            prime++;
            continue;
        }

        // Multiply by 1 + p + ... + p^(power * exponent) = (p^(power * exponent + 1) - 1) / (p - 1).
        total *= (pow_modulo(pow_modulo(prime, power), exponent) * prime - 1) % MOD;
        total %= MOD;

        total *= inverse(prime - 1);
        total %= MOD;

        prime++;
    }

    // Any leftover base > 1 is a prime factor that appears exactly once.
    if (base > 1) {
        // NOTE: the geometric-series formula breaks down when base ≡ 1 (mod MOD);
        // when base ≡ 0 (mod MOD), the sum 1 + base + ... + base^exponent is ≡ 1.
        if (base % MOD == 0) {
            total *= (exponent + 1) % MOD;
            total %= MOD;
        }

        else {
            total *= pow_modulo(base, exponent + 1) - 1;
            total %= MOD;

            total *= inverse(base - 1);
            total %= MOD;
        }
    }

    return total;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    freopen("sumdiv.in", "r", stdin);
    freopen("sumdiv.out", "w", stdout);

    long long base, exponent;
    cin >> base >> exponent;

    long long result = solve(base, exponent);
    cout << result << "\n";

    return 0;
}
```
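One possible explanation for TLE on only a few tests is offered here as a hypothesis, since the constraints aren’t quoted in this post. In `solve`, `prime` is declared as `int`, so `prime * prime` is computed in 32-bit arithmetic. Once `prime` passes 46340, that product overflows, which is undefined behavior. Suppose what remains of `base` is still above roughly 2.1 × 10^9 at that point, for example when the input is a large prime. Then the overflowed product can keep `prime * prime <= base` true, and the loop runs far past √base. Separately, the last branch special-cases `base % MOD == 0`, but the formula (base^(exponent+1) − 1)/(base − 1) only breaks down when base ≡ 1 (mod MOD).

The sketch below keeps the same O(√base) trial division as the submission. It changes three things: it uses a `long long` divisor, it uses an overflow-free loop bound, and it handles both special cases. The names `pow_mod`, `geometric_sum`, and `sum_of_divisors_of_power` are illustrative.

```cpp
#include <cassert>

const long long MOD = 1'000'000'007;  // same prime modulus as the submission

// base^power (mod MOD) by binary exponentiation; gives 0 for base ≡ 0 and power >= 1.
long long pow_mod(long long base, long long power) {
    base %= MOD;
    long long result = 1;
    while (power > 0) {
        if (power & 1) result = result * base % MOD;
        base = base * base % MOD;
        power >>= 1;
    }
    return result;
}

// 1 + r + r^2 + ... + r^(e*b) (mod MOD) for r >= 1, without forming e*b directly.
long long geometric_sum(long long r, long long e, long long b) {
    r %= MOD;
    if (r == 1) {
        // Every term is 1, so the sum is the number of terms, e*b + 1.
        return ((e % MOD) * (b % MOD) + 1) % MOD;
    }
    // r^(e*b + 1), computed as (r^e)^b * r.
    long long top = pow_mod(pow_mod(r, e), b) * r % MOD;
    long long numerator = (top - 1 + MOD) % MOD;
    // r - 1 is nonzero mod MOD here, so Fermat's little theorem gives its inverse.
    long long denominator_inverse = pow_mod(r - 1 + MOD, MOD - 2);
    return numerator * denominator_inverse % MOD;
}

// Sum of the divisors of a^b (mod MOD), assuming a >= 1 and b >= 0.
long long sum_of_divisors_of_power(long long a, long long b) {
    long long total = 1;
    // `p` is a long long, and `p <= a / p` is equivalent to p * p <= a
    // without ever computing p * p, so the bound cannot overflow.
    for (long long p = 2; p <= a / p; ++p) {
        if (a % p != 0) continue;
        long long e = 0;
        while (a % p == 0) {
            a /= p;
            ++e;
        }
        total = total * geometric_sum(p, e, b) % MOD;
    }
    // Whatever is left above 1 is a single prime factor with exponent 1.
    if (a > 1) total = total * geometric_sum(a, 1, b) % MOD;
    return total;
}
```

For example, `sum_of_divisors_of_power(2, 3)` gives 1 + 2 + 4 + 8 = 15. Likewise, `sum_of_divisors_of_power(2147483647, 1)` (input 2^31 − 1, which is prime) stops after about 46 000 loop iterations and returns 147483634, which is 2^31 mod 1'000'000'007.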