Iterative Collatz with memoization
Problem
I'm trying to write efficient code for calculating the Collatz chain length of each number.
For example,

13 -> 40 -> 20 -> 10 -> 5 -> 16 -> 8 -> 4 -> 2 -> 1

It took 9 iterations to get 13 down to 1, and the chain contains 10 numbers, so

collatz(13) = 10

Here is my implementation:

def collatz(n, d={1:1}):
    if not n in d:
        d[n] = collatz(n * 3 + 1 if n & 1 else n/2) + 1
    return d[n]

It's quite efficient, but I'd like to know whether the same efficiency (or better) can be achieved by an iterative implementation.
Solution
Using the Python call stack to remember your state is very convenient and often results in shorter code, but it has a disadvantage:
>>> collatz(5**38)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "cr24195.py", line 3, in collatz
    d[n] = collatz(n * 3 + 1 if n & 1 else n/2) + 1
  [... many lines omitted ...]
  File "cr24195.py", line 3, in collatz
    d[n] = collatz(n * 3 + 1 if n & 1 else n/2) + 1
RuntimeError: maximum recursion depth exceeded

You need to remember all the numbers you've visited along the way to your cache hit somehow. In your recursive implementation you remember them on the call stack. In an iterative implementation one would have to remember this stack of numbers in a list, perhaps like this:

def collatz2(n, d={1: 1}):
    """Return one more than the number of steps that it takes to reach 1
    starting from n by following the Collatz procedure.
    """
    stack = []
    while n not in d:
        stack.append(n)
        n = n * 3 + 1 if n & 1 else n // 2
    c = d[n]
    while stack:
        c += 1
        d[stack.pop()] = c
    return c

This is a bit verbose, and it's slower than your recursive version:
>>> from timeit import timeit
>>> timeit(lambda:map(collatz, xrange(1, 10**6)), number=1)
2.7360708713531494
>>> timeit(lambda:map(collatz2, xrange(1, 10**6)), number=1)
3.7696099281311035

But it's more robust:

>>> collatz2(5**38)
1002
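
The sessions above use Python 2 (`xrange`, an eager `map`, and `n/2` as integer division). For anyone repeating the comparison on Python 3, here is a minimal sketch under a few assumptions of my own: the names `collatz_recursive`, `collatz_iterative` and `time_version` are hypothetical, the cache is passed explicitly instead of living in a mutable default argument so each run starts fresh, and the recursive version uses `//` so the values stay integers. Timings will differ from the figures quoted above.

```python
import sys
from timeit import timeit


def collatz_recursive(n, cache):
    """Recursive version from the question, ported to Python 3 (// instead of /)."""
    if n not in cache:
        cache[n] = collatz_recursive(n * 3 + 1 if n & 1 else n // 2, cache) + 1
    return cache[n]


def collatz_iterative(n, cache):
    """Iterative version from the answer, with the cache passed explicitly."""
    stack = []
    # Walk forward until we reach a number whose chain length is already known.
    while n not in cache:
        stack.append(n)
        n = n * 3 + 1 if n & 1 else n // 2
    length = cache[n]
    # Unwind the list of visited numbers, filling in the cache on the way back.
    while stack:
        length += 1
        cache[stack.pop()] = length
    return length


def time_version(func, limit):
    """Time func over 1..limit-1, starting from a fresh cache (hypothetical helper)."""
    cache = {1: 1}
    return timeit(lambda: [func(n, cache) for n in range(1, limit)], number=1)


if __name__ == "__main__":
    # A smaller range than in the answer, to keep the run short.
    print("recursive:", time_version(collatz_recursive, 10**5))
    print("iterative:", time_version(collatz_iterative, 10**5))

    # The long chain from the answer: the recursive version is expected to
    # exceed the default recursion limit, the iterative one is not affected.
    try:
        print(collatz_recursive(5**38, {1: 1}))
    except RecursionError:
        print("recursive version exceeded the recursion limit of",
              sys.getrecursionlimit())
    print(collatz_iterative(5**38, {1: 1}))
```

Passing the cache in as an argument costs a little convenience, but it makes the two versions easier to compare fairly, since neither one benefits from entries left behind by an earlier call.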
Context
StackExchange Code Review Q#24195, answer score: 5