Python USACO Training 3.2 Sweet Butter

My Python code gets TLE on this problem. I use the heap-optimized version of Dijkstra’s algorithm, but it’s still too slow. I need to make it about 1.5-2x faster. Any tips?
Here’s the code:

"""
ID: solasky1
LANG: PYTHON3
TASK: butter
"""
import collections
import heapq
import math

with open('butter.in') as f:
    C, V, E = map(int, f.readline().strip().split())
    cows, weights = [], collections.defaultdict(dict)
    cow_dists = []

    for _ in range(C):
        cows.append(int(f.readline().strip()))

    for _ in range(E):
        start, end, cost = map(int, f.readline().strip().split())
        weights[start][end] = cost
        weights[end][start] = cost

# Dijkstra's algorithm
# Calculates the shortest distance from every cow to every pasture

for cow in range(C):
    heap = [(0, cows[cow])]
    dist = {}
    
    while heap:
        time, start = heapq.heappop(heap)
        if start not in dist:
            dist[start] = time
            for end in weights[start]:
                heapq.heappush(heap, (dist[start] + weights[start][end], end))

    cow_dists.append(dist)

# Finds the pasture that minimizes the total distance all cows have to travel
res = math.inf
for vertex in range(1, V + 1):
    s = 0
    for d in cow_dists:
        s += d[vertex]
    if s < res:
        res = s

with open('butter.out', 'w') as f:
    f.write(f'{res}\n')

Here’s the screenshot of the error.

Some ideas I’ve thought of are:
Some cows are in the same pasture, so I’m repeating calculations. However, I couldn’t think of an efficient way to exploit this, and it probably won’t change much since repeated pastures won’t come up that often.
Python’s heapq is a binary min-heap and pretty slow. Maybe some other structure would be better?
Maybe I shouldn’t keep the cow_dists list at all?
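The first and third ideas above can be combined. Here is a minimal sketch, not the poster's code: it counts how many cows share each pasture, runs Dijkstra once per distinct pasture, and adds the distances into a running total instead of storing cow_dists. The names `dijkstra` and `best_total` are hypothetical, pastures are assumed to be 0-indexed, and the graph is assumed to be connected (an unreachable pasture would leave a -1 in the totals).

```python
import heapq
from collections import Counter


def dijkstra(adj, source):
    """Shortest distances from source; adj[u] is a list of (v, cost) pairs."""
    dist = [-1] * len(adj)          # -1 marks "not finalized yet"
    heap = [(0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if dist[u] != -1:           # stale heap entry: u was already finalized
            continue
        dist[u] = d
        for v, cost in adj[u]:
            if dist[v] == -1:       # skip neighbours that are already finalized
                heapq.heappush(heap, (d + cost, v))
    return dist


def best_total(num_pastures, cows, edges):
    """cows: 0-indexed pasture of each cow; edges: (u, v, cost) triples."""
    adj = [[] for _ in range(num_pastures)]
    for u, v, cost in edges:
        adj[u].append((v, cost))
        adj[v].append((u, cost))

    totals = [0] * num_pastures
    # One Dijkstra run per distinct pasture, weighted by the number of cows there.
    for pasture, count in Counter(cows).items():
        dist = dijkstra(adj, pasture)
        for vertex, d in enumerate(dist):
            totals[vertex] += count * d
    return min(totals)
```

With this layout, memory stays at one distance list at a time, and the number of Dijkstra runs is bounded by the number of distinct occupied pastures rather than the number of cows.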

Python is really slow, so it might not even be possible (or it might be, but it just might require some intense low-level tinkering) to get AC on this problem with it. I suggest switching to another language like Java or C++.

        if start not in dist:

Check the time complexity of this.

dist is a dictionary, so the membership check is O(1) on average.
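To check this claim empirically, a rough sketch like the following times `key in d` for dictionaries of different sizes. The helper `time_membership` is a hypothetical name, and the numbers depend on the machine, so no expected output is shown. If the lookup is hash-based, the timings should stay roughly flat as the dictionary grows.

```python
import timeit


def time_membership(size, lookups=100_000):
    """Time `key in d` on a dict holding `size` integer keys."""
    d = {i: i for i in range(size)}
    missing = -1  # a key that is never present in d
    return timeit.timeit(lambda: missing in d, number=lookups)


if __name__ == "__main__":
    for size in (1_000, 100_000):
        print(size, time_membership(size))
```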

I was able to get some help from another forum, AoPS. Here are the improvements I added:

  1. Changed the dist dictionary to a list and 0-indexed everything. This made it ~10% faster.
  2. Checked whether end already had a value before heappushing it. This made it ~40% faster.
  3. Since different cows can be in the same pasture, I added a memo that remembers every pasture already processed and the distances Dijkstra’s algorithm returned for it. This made it ~35% faster.

Overall, it was able to pass!
I won’t make more threads like this. I’ll try to learn Java, but until then I guess I’ll just copy code from the internet if my algorithm gets TLE because of Python.

Here’s the code:

"""
ID: solasky1
LANG: PYTHON3
TASK: butter
"""
import collections
import heapq
import math

with open('butter.in') as f:
    C, V, E = map(int, f.readline().strip().split())
    cows, weights = [], collections.defaultdict(dict)
    cow_dists = []
    memo = {}

    for _ in range(C):
        cows.append(int(f.readline().strip()) - 1)

    for _ in range(E):
        start, end, cost = map(int, f.readline().split())
        weights[start - 1][end - 1] = cost
        weights[end - 1][start - 1] = cost

# Dijkstra's algorithm
# Calculates the shortest distance from every cow to every pasture
for cow in range(C):
    if cows[cow] in memo:
        cow_dists.append(memo[cows[cow]])
    else:
        heap = [(0, cows[cow])]
        dist = [-1] * V
        
        while heap:
            time, start = heapq.heappop(heap)
            if dist[start] == -1:
                dist[start] = time
                for end in weights[start]:
                    if dist[end] == -1:
                        heapq.heappush(heap, (dist[start] + weights[start][end], end))

        cow_dists.append(dist)
        memo[cows[cow]] = dist

# Finds the pasture that minimizes the total distance all cows have to travel
res = math.inf
for vertex in range(V):
    s = 0
    for d in cow_dists:
        s += d[vertex]
    if s < res:
        res = s

with open('butter.out', 'w') as f:
    f.write(f'{res}\n')