Help on USACO training 1.2 Arithmetic Progressions (ariprog) TLE

Unfortunately, the USACO training pages require one to be signed in, so here is a description of the problem.

Description of Arithmetic Progressions

An arithmetic progression is a sequence of the form a, a+b, a+2b, …, a+nb where n=0, 1, 2, 3, … . For this problem, a is a non-negative integer and b is a positive integer.

Write a program that finds all arithmetic progressions of length n in the set S of bisquares. The set of bisquares is defined as the set of all integers of the form p^2 + q^2 (where p and q are non-negative integers).

TIME LIMIT: 5 secs

PROGRAM NAME: ariprog

INPUT FORMAT

Line 1: N (3 <= N <= 25), the length of progressions for which to search

Line 2: M (1 <= M <= 250), an upper bound to limit the search to the bisquares with 0 <= p,q <= M.

SAMPLE INPUT (file ariprog.in)

5
7

OUTPUT FORMAT

If no sequence is found, a single line reading `NONE`. Otherwise, output one or more lines, each with two integers: the first element in a found sequence and the difference between consecutive elements in the same sequence. The lines should be ordered with smallest-difference sequences first and smallest starting number within those sequences first.

There will be no more than 10,000 sequences.

SAMPLE OUTPUT (file ariprog.out)

1 4
37 4
2 8
29 8
1 12
5 12
13 12
17 12
5 20
2 24

I’m using Python, and no matter what I try, I can only pass 7 of the 9 test cases.

My code finds all bisquares, then iterates through them and checks for possible arithmetic progressions. Finally, it outputs the arithmetic progressions.

Here is my code:

"""
ID: shawn.z2
LANG: PYTHON3
TASK: ariprog
"""
import sys

sys.stdin = open('ariprog.in', "r")
sys.stdout = open('ariprog.out', "w")

N = int(sys.stdin.readline())
M = int(sys.stdin.readline())
m2 = 2 * M * M
found_ariprogs = []
bisquares = set()
lookup_table = [False] * (m2 + 1)

# find bisquares
for i in range(M + 1):
    i2 = i ** 2
    for j in range(i, M + 1):
        bisquare = i2 + j ** 2
        bisquares.add(bisquare)
        lookup_table[bisquare] = True

bisquares = sorted(list(bisquares))


def is_ariprog(a, b):
    # checks each term in arithmetic progression
    for n in range(N):
        if not lookup_table[a + n * b]:
            return False
    return True


# iterate through bisquares and check for arithmetic progressions
l_bisquares = len(bisquares)
for i in range(l_bisquares):
    a = bisquares[i]
    # largest difference b for which the last term, a + (N - 1) * b,
    # still fits under the largest possible bisquare, 2 * M * M
    lim = (m2 - a) // (N - 1)
    # try every later bisquare as the second term of the progression
    for j in range(i + 1, l_bisquares):
        b = bisquares[j] - a
        if b > lim:
            break
        if is_ariprog(a, b):
            found_ariprogs.append((b, a))

# output
found_ariprogs.sort()
if len(found_ariprogs) == 0:
    sys.stdout.write("NONE\n")
    sys.exit(0)
for ariprog in found_ariprogs:
    sys.stdout.write(f"{ariprog[1]} {ariprog[0]}\n")

The specific test case I’m stuck on is the 8th test case, where the input is:
22
250
and the expected output is:
13421 2772

I always time out on this test case.
Can anyone help? Thanks!
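
For a sense of scale, the number of (a, b) candidates that the double loop above examines can be counted without running the full check. This is only a rough sketch: `count_candidates` is a made-up helper name, and it uses `bisect` to count, for each starting bisquare, how many later bisquares fall within the allowed difference. Each candidate then costs up to N table lookups in `is_ariprog`.

```python
from bisect import bisect_right


def count_candidates(n, m):
    """Return (number of bisquares, number of (a, b) pairs the double loop tests)."""
    limit = 2 * m * m  # largest possible bisquare
    bisquares = sorted({p * p + q * q
                        for p in range(m + 1)
                        for q in range(p, m + 1)})
    pairs = 0
    for i, a in enumerate(bisquares):
        # same bound as `lim` in the posted code
        max_b = (limit - a) // (n - 1)
        # bisquares strictly after a whose value is at most a + max_b
        pairs += bisect_right(bisquares, a + max_b) - (i + 1)
    return len(bisquares), pairs


if __name__ == "__main__":
    print(count_candidates(22, 250))
```

Printing the result for the failing input (22, 250) shows how many times `is_ariprog` gets called, which is the part worth speeding up.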

I think it just might be because Python is too slow.
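
Before switching languages, a few constant-factor changes might be worth trying. The sketch below keeps the same algorithm as the posted code but (1) stores the lookup table in a `bytearray`, (2) rejects a candidate early if its last term is not a bisquare, and (3) replaces the Python-level loop in `is_ariprog` with a stepped slice, so the remaining terms are checked in C. The names `find_progressions` and `is_bisquare` are my own, and I have not confirmed that this fits the 5-second limit on the judge, so treat it as something to experiment with rather than a known accepted solution.

```python
def find_progressions(n, m):
    """Return a sorted list of (b, a) pairs for progressions of length n."""
    limit = 2 * m * m  # largest possible bisquare
    is_bisquare = bytearray(limit + 1)
    for p in range(m + 1):
        p2 = p * p
        for q in range(p, m + 1):
            is_bisquare[p2 + q * q] = 1
    bisquares = [v for v in range(limit + 1) if is_bisquare[v]]

    span = n - 1  # number of steps from the first term to the last
    count = len(bisquares)
    found = []
    for i, a in enumerate(bisquares):
        max_b = (limit - a) // span
        for j in range(i + 1, count):
            b = bisquares[j] - a
            if b > max_b:
                break
            last = a + span * b
            # cheap rejection first: the last term must be a bisquare
            if not is_bisquare[last]:
                continue
            # stepped slice holds the n terms a, a+b, ..., last;
            # a 0 anywhere means some term is not a bisquare
            if 0 not in is_bisquare[a:last + 1:b]:
                found.append((b, a))
    found.sort()
    return found


if __name__ == "__main__":
    with open("ariprog.in") as fin:
        n = int(fin.readline())
        m = int(fin.readline())
    results = find_progressions(n, m)
    with open("ariprog.out", "w") as fout:
        if not results:
            fout.write("NONE\n")
        for b, a in results:
            fout.write(f"{a} {b}\n")
```

On the sample input (5, 7) this should produce the same ten lines as the sample output in the question. Whether the savings are enough for the (22, 250) case depends on the judge's machine.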

:frowning: Time to switch to Java, I guess.