
Project Euler problem 92 - square digit chains

Submitted by: @import:stackexchange-codereview

Problem

Link to problem.


A number chain is created by continuously adding the square of the
digits in a number to form a new number until it has been seen before.


For example,


44 -> 32 -> 13 -> 10 -> 1 -> 1


85 -> 89 -> 145 -> 42 -> 20 -> 4 -> 16 -> 37 -> 58 -> 89


Therefore any chain that arrives at 1 or 89 will become stuck in an
endless loop. What is most amazing is that EVERY starting number will
eventually arrive at 1 or 89.


How many starting numbers below ten million will arrive at 89?

This solution to Project Euler problem 92 takes about 80 seconds. How can I reduce the time to well under 60 seconds?

def squareDigits(num):
    total = 0
    for digit in str(num):
        total += int(digit) ** 2
    return total
pastChains = [None] * 10000001
pastChains[1], pastChains[89] = False, True
for num in range(2, 10000001):
    chain = [num]
    while pastChains[chain[-1]] is None:
        chain.append(squareDigits(chain[-1]))
    for term in chain:
        if pastChains[term] is not None:
            pastChains[term] = pastChains[chain[-1]]
print pastChains.count(True)

Solution


  1. Bug



There's a bug in your program. This line

if pastChains[term] is not None:


should be

if pastChains[term] is None:
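To see the effect of that one-word change, here is a minimal sketch that runs both conditions at a small limit. The helper count_89 and its buggy flag are hypothetical names introduced only for this illustration, and the limit of 1000 is an assumption chosen so that every chain term fits inside the list.

```python
def square_digits(num):
    """Return the sum of squares of the base-10 digits of num."""
    total = 0
    for digit in str(num):
        total += int(digit) ** 2
    return total

def count_89(limit, buggy):
    """Count starting numbers below limit recorded as arriving at 89.

    buggy=True uses the original 'is not None' test, buggy=False the fix.
    Hypothetical helper: limit must exceed every chain term (1000 is enough
    for three-digit starting numbers, whose digit-square sum is at most 243).
    """
    past = [None] * limit
    past[1], past[89] = False, True
    for num in range(2, limit):
        chain = [num]
        while past[chain[-1]] is None:
            chain.append(square_digits(chain[-1]))
        for term in chain:
            already_known = past[term] is not None
            if already_known == buggy:
                past[term] = past[chain[-1]]
    return past.count(True)

print(count_89(1000, buggy=True))   # only the pre-seeded entry for 89 is True
print(count_89(1000, buggy=False))  # every chain term gets recorded
```

With the original condition, the only term in each chain that passes the test is the last one, which is already known, so nothing new is ever cached and the final count stays at 1.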


  2. Piecewise improvement



In this section I'm going to show how you can speed up a program like this by making a series of small piecewise improvements. This doesn't often work for the Project Euler problems: normally you have to come up with new algorithmic ideas. But here you're pretty close to getting under the minute mark, and that makes piecewise improvement a plausible approach.

Note that I'm using Python 2.7 here, which is generally faster than Python 3.

2.1. Base case

I'm going to start by putting your code into a form which is more convenient for testing (and where we can run it for smaller limits than ten million, which is useful when making frequent measurements). I'll also improve the coding style while I'm about it (add docstrings; put local variables in lower_case_with_underscores; use the more meaningful values 1 and 89 instead of the Booleans False and True).

def square_digits(n):
    """Return the sum of squares of the base-10 digits of n."""
    total = 0
    for digit in str(n):
        total += int(digit) ** 2
    return total

def problem92a(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.

    """
    arrive = [None] * limit # Number eventually arrived at, or None if unknown.
    arrive[1], arrive[89] = 1, 89
    for n in range(2, limit):
        chain = [n]
        while arrive[chain[-1]] is None:
            chain.append(square_digits(chain[-1]))
        for term in chain:
            if arrive[term] is None:
                arrive[term] = arrive[chain[-1]]
    return arrive.count(89)

>>> from timeit import timeit
>>> timeit(lambda:problem92a(10**7), number=1)
185.13595986366272


That's a lot slower than your reported "80 seconds": clearly you have a much faster machine than I do!

2.2. Avoiding number/string conversion

Now, let's do some work on the square_digits function. As written, this function has to convert the integer n to a string and then convert each digit back to a number. We could avoid these conversions by working with numbers throughout:

def square_digits(n):
    """Return the sum of squares of the base-10 digits of n."""
    total = 0
    while n:
        total += (n % 10) ** 2
        n //= 10
    return total

>>> timeit(lambda:problem92a(10**7), number=1)
61.409788846969604


Nearly under the minute already!
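Since the arithmetic version replaces the string version outright, it is worth confirming that the two agree before relying on the new timings. The sketch below is one way to do that; the names square_digits_str and square_digits_arith are mine, used only so that both versions can coexist, and the range of 100000 is an arbitrary choice.

```python
def square_digits_str(n):
    """Sum of squared digits via string conversion (the original approach)."""
    total = 0
    for digit in str(n):
        total += int(digit) ** 2
    return total

def square_digits_arith(n):
    """Sum of squared digits using integer arithmetic only."""
    total = 0
    while n:
        total += (n % 10) ** 2  # square the lowest digit
        n //= 10                # then drop it
    return total

# Compare the two implementations over a range of non-negative integers.
agree = all(square_digits_str(n) == square_digits_arith(n)
            for n in range(100000))
print(agree)
```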

2.3. Second version

Here are some obvious minor improvements:

- Avoid the repeated lookup of chain[-1] by remembering the value in a local variable.

- Reduce the size of chain by one (because the last element in the chain is remembered in the local variable).

- The test arrive[term] is None is unnecessary: by this point in the code we know that only the last term in chain was found in arrive.

That yields the following code:

def problem92b(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.

    """
    arrive = [None] * limit # Number eventually arrived at, or None if unknown.
    arrive[1], arrive[89] = 1, 89
    for n in range(2, limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    return arrive.count(89)


When transforming code like this, it's always important to check that our transformations didn't break anything. Here's where the ability to run the program for small values of limit comes in handy:

>>> all(problem92a(i) == problem92b(i) for i in range(1000, 2000))
True


And this yields a further 12% speedup:

>>> timeit(lambda:problem92b(10**7), number=1)
53.771003007888794


2.4. Third version

The largest number under ten million is 9999999, whose sum of squares of digits is just 567. All other numbers in the range have even smaller sums of squares of digits. So for 568 and up, there is no need to follow the chain: we can just look up the answer directly. That suggests the following approach:

def problem92c(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.

    """
    sum_limit = len(str(limit - 1)) * 9 ** 2 + 1
    arrive = [None] * sum_limit
    arrive[1], arrive[89] = 1, 89
    for n in range(2, sum_limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    c = arrive.count(89)
    for n in range(sum_limit, limit):
        c += arrive[square_digits(n)] == 89
    return c


Again, we'd better check that we didn't break anything.
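A check in the same spirit as the earlier one might look like the sketch below, which repeats the definitions so that it runs on its own. Two assumptions are mine: the limits are sampled every 25 rather than tested exhaustively, to keep the run short, and they start at 1000 because problem92c as written appears to rely on sum_limit being no larger than limit (for smaller limits the initial arrive.count(89) would also cover numbers at or above limit).

```python
def square_digits(n):
    """Return the sum of squares of the base-10 digits of n."""
    total = 0
    while n:
        total += (n % 10) ** 2
        n //= 10
    return total

def problem92b(limit):
    """Reference version: cache the destination of every number below limit."""
    arrive = [None] * limit
    arrive[1], arrive[89] = 1, 89
    for n in range(2, limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    return arrive.count(89)

def problem92c(limit):
    """Faster version: cache destinations only for possible digit-square sums."""
    sum_limit = len(str(limit - 1)) * 9 ** 2 + 1
    arrive = [None] * sum_limit
    arrive[1], arrive[89] = 1, 89
    for n in range(2, sum_limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    c = arrive.count(89)
    for n in range(sum_limit, limit):
        c += arrive[square_digits(n)] == 89  # one lookup, no chain to follow
    return c

# Assumption: limits of 1000 and up, sampled every 25, so sum_limit <= limit.
consistent = all(problem92b(i) == problem92c(i) for i in range(1000, 2000, 25))
print(consistent)
```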

Context

StackExchange Code Review Q#31761, answer score: 4
