Project Euler problem 92 - square digit chains
Problem
Link to problem.
A number chain is created by continuously adding the square of the
digits in a number to form a new number until it has been seen before.
For example,
44 -> 32 -> 13 -> 10 -> 1 -> 1
85 -> 89 -> 145 -> 42 -> 20 -> 4 -> 16 -> 37 -> 58 -> 89
Therefore any chain that arrives at 1 or 89 will become stuck in an
endless loop. What is most amazing is that EVERY starting number will
eventually arrive at 1 or 89.
How many starting numbers below ten million will arrive at 89?
This solution to Project Euler problem 92 takes about 80 seconds. How can I reduce the time to well under 60 seconds?
```python
def squareDigits(num):
    total = 0
    for digit in str(num):
        total += int(digit) ** 2
    return total

pastChains = [None] * 10000001
pastChains[1], pastChains[89] = False, True
for num in range(2, 10000001):
    chain = [num]
    while pastChains[chain[-1]] is None:
        chain.append(squareDigits(chain[-1]))
    for term in chain:
        if pastChains[term] is not None:
            pastChains[term] = pastChains[chain[-1]]
print pastChains.count(True)
```
Solution
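1. Bug
There's a bug in your program. This line
```python
if pastChains[term] is not None:
```
should be
```python
if pastChains[term] is None:
```
As written, the condition is true only for the last term of the chain, whose result is already known, so no new results are ever recorded.
2. Piecewise improvement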
In this section I'm going to show how you can speed up a program like this by making a series of small piecewise improvements. This doesn't often work for the Project Euler problems: normally you have to come up with new algorithmic ideas. But here you're pretty close to getting under the minute mark, and that makes piecewise improvement a plausible approach.
Note that I'm using Python 2.7 here, which is generally faster than Python 3.
2.1. Base case
I'm going to start by putting your code into a form that is more convenient for testing (and where we can run it for limits smaller than ten million, which is convenient when making frequent measurements). I'll also improve the coding style while I'm about it: add docstrings, put local variables in `lower_case_with_underscores`, and use the more meaningful values 1 and 89 instead of the Booleans False and True.

```python
def square_digits(n):
    """Return the sum of squares of the base-10 digits of n."""
    total = 0
    for digit in str(n):
        total += int(digit) ** 2
    return total

def problem92a(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.
    """
    arrive = [None] * limit # Number eventually arrived at, or None if unknown.
    arrive[1], arrive[89] = 1, 89
    for n in range(2, limit):
        chain = [n]
        while arrive[chain[-1]] is None:
            chain.append(square_digits(chain[-1]))
        for term in chain:
            if arrive[term] is None:
                arrive[term] = arrive[chain[-1]]
    return arrive.count(89)
```

```python
>>> from timeit import timeit
>>> timeit(lambda:problem92a(10**7), number=1)
185.13595986366272
```
That's a lot slower than your reported "80 seconds": clearly you have a much faster machine than I do!
2.2. Avoiding number/string conversion
Now, let's do some work on the `square_digits` function. As written, this function has to convert the integer n to a string and then convert each digit back to a number. We could avoid these conversions by working with numbers throughout:

```python
def square_digits(n):
    """Return the sum of squares of the base-10 digits of n."""
    total = 0
    while n:
        total += (n % 10) ** 2
        n //= 10
    return total
```

```python
>>> timeit(lambda:problem92a(10**7), number=1)
61.409788846969604
```
Nearly under the minute already!
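As a quick sanity check on this change, the sketch below compares the string-based and arithmetic versions on a small range and times each with `timeit`. The names `square_digits_str` and `square_digits_arith`, and the sample size, are introduced here only for illustration. Timings vary by machine and Python version, so no figures are shown.

```python
from timeit import timeit


def square_digits_str(n):
    """Sum of squared base-10 digits of n, via string conversion."""
    total = 0
    for digit in str(n):
        total += int(digit) ** 2
    return total


def square_digits_arith(n):
    """Sum of squared base-10 digits of n, using only integer arithmetic."""
    total = 0
    while n:
        total += (n % 10) ** 2
        n //= 10
    return total


# Both versions should agree on every input in the tested range.
agree = all(square_digits_str(n) == square_digits_arith(n)
            for n in range(1, 100000))

# Time each version over the same small range (illustrative sample size).
sample = range(1, 20000)
t_str = timeit(lambda: [square_digits_str(n) for n in sample], number=1)
t_arith = timeit(lambda: [square_digits_arith(n) for n in sample], number=1)
```

Comparing `t_str` with `t_arith` on your own machine shows how much of the overall speedup comes from this function alone.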
2.3. Second version
Here are some obvious minor improvements:
- Avoid the repeated lookup of `chain[-1]` by remembering the value in a local variable.
- Reduce the size of `chain` by one (because the last element in the chain is remembered in the local variable).
- The test `arrive[term] is None` is unnecessary: by this point in the code we know that only the last term in `chain` was found in `arrive`.

That yields the following code:
```python
def problem92b(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.
    """
    arrive = [None] * limit # Number eventually arrived at, or None if unknown.
    arrive[1], arrive[89] = 1, 89
    for n in range(2, limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    return arrive.count(89)
```
When transforming code like this, it's always important to check that our transformations didn't break anything. Here's where the ability to run the program for small values of `limit` comes in handy:

```python
>>> all(problem92a(i) == problem92b(i) for i in range(1000, 2000))
True
```
And this yields a further 12% speedup:

```python
>>> timeit(lambda:problem92b(10**7), number=1)
53.771003007888794
```
2.4. Third version
The largest number under ten million is 9999999, whose sum of squares of digits is just 567. All other numbers in the range have even smaller sums of squares of digits. So for 568 and up, there is no need to follow the chain: we can just look up the answer directly. That suggests the following approach:
```python
def problem92c(limit):
    """Return the count of starting numbers below limit that eventually arrive
    at 89, as a result of iterating the sum-of-squares-of-digits.
    """
    sum_limit = len(str(limit - 1)) * 9 ** 2 + 1
    arrive = [None] * sum_limit
    arrive[1], arrive[89] = 1, 89
    for n in range(2, sum_limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    c = arrive.count(89)
    for n in range(sum_limit, limit):
        c += arrive[square_digits(n)] == 89
    return c
```
Again, we'd better check that we didn't break anything.
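The same kind of check used for the second version can be sketched for `problem92c`. The block below is self-contained: it repeats `square_digits` and `problem92c` from above and compares against a deliberately naive reference, `brute_force`. That helper is hypothetical and not part of the original answer; it simply follows every chain to its end.

```python
def square_digits(n):
    """Return the sum of squares of the base-10 digits of n."""
    total = 0
    while n:
        total += (n % 10) ** 2
        n //= 10
    return total


def brute_force(limit):
    """Naive reference: follow every chain until it reaches 1 or 89."""
    count = 0
    for n in range(1, limit):
        while n != 1 and n != 89:
            n = square_digits(n)
        count += n == 89
    return count


def problem92c(limit):
    """Count starting numbers below limit whose chain arrives at 89."""
    sum_limit = len(str(limit - 1)) * 9 ** 2 + 1
    arrive = [None] * sum_limit
    arrive[1], arrive[89] = 1, 89
    for n in range(2, sum_limit):
        chain = []
        while not arrive[n]:
            chain.append(n)
            n = square_digits(n)
        dest = arrive[n]
        for term in chain:
            arrive[term] = dest
    c = arrive.count(89)
    for n in range(sum_limit, limit):
        c += arrive[square_digits(n)] == 89
    return c


# problem92c counts the whole lookup table first, so it is only meaningful
# when limit >= sum_limit; every limit tested here satisfies that.
ok = all(problem92c(i) == brute_force(i) for i in range(1000, 1100))
```

Because `problem92c` starts its count from the whole lookup table, it overcounts when `limit` is smaller than `sum_limit`. Any check of this kind should therefore stay above that threshold, as the range used here does.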
Context
StackExchange Code Review Q#31761, answer score: 4