100 gunmen in a circle kill next person
Problem
I am very happy because I solved this problem with very little code:

```python
"""
100 people are standing in a circle with gun in their hands.
1 kills 2, 3 kills 4, 5 kills 6 and so on till we are
left with only one person. Who will be the last person alive?
Write code to implement this ##efficiently.##
"""
persons = list(range(1, 101))
while len(persons) > 1:
    for index, person in enumerate(persons):
        del persons[(index + 1) % len(persons)]
print(persons)
```

Solution
- Encapsulate
Writing code at the top level of a module makes it hard to test the code and hard to measure its performance. It's best to encapsulate code in a function. Accordingly, I'd write:
```python
def survivor(n):
    """Return the survivor of a circular firing squad of n people."""
    persons = list(range(1, n + 1))
    while len(persons) > 1:
        for index, _ in enumerate(persons):
            del persons[(index + 1) % len(persons)]
    return persons[0]
```

(The variable `person` is not used in the body of the `for` loop; it's conventional to write `_` for such variables, and that's what I've done here.)

- The meaning of efficiency
Is this code "efficient" as the question asks? Normally in computing we use "efficient" to mean algorithmic efficiency: that is, the rate at which the resources used by the program grow as a function of the input, usually expressed in big-O notation.
The question says, "Python is not efficient", but according to the usual view of efficiency, the programming language does not matter: efficiency is a property of the algorithm, not of the language it's implemented in.
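One way to see that efficiency belongs to the algorithm rather than the language is to count primitive operations instead of timing them. The following is a minimal sketch, not code from the question or answer: `survivor_with_moves` and `growth_exponent` are hypothetical names, and the cost model is an assumption, namely that deleting index `j` from a list of length `L` slides the `L - j - 1` later items down one slot (which is how an array-backed list such as CPython's behaves). If the cost grows like $c \cdot n^k$, then doubling $n$ multiplies it by $2^k$, so $\log_2$ of the doubling ratio estimates $k$.

```python
import math


def survivor_with_moves(n):
    """Simulate the del-based firing squad and count element moves.

    Cost model (an assumption): deleting index j from a list of length L
    shifts the L - j - 1 items after it one slot to the left.
    Returns (survivor, total_moves).
    """
    persons = list(range(1, n + 1))
    moves = 0
    while len(persons) > 1:
        for index, _ in enumerate(persons):
            j = (index + 1) % len(persons)  # position of the person being shot
            moves += len(persons) - j - 1   # items that slide down after the del
            del persons[j]
    return persons[0], moves


def growth_exponent(cost, n):
    """Estimate k in cost(n) ~ c * n**k from a single doubling of n."""
    return math.log2(cost(2 * n) / cost(n))


if __name__ == "__main__":
    def moves_only(n):
        return survivor_with_moves(n)[1]

    for n in (1024, 2048, 4096):
        print(n, moves_only(n), round(growth_exponent(moves_only, n), 3))
```

Because the count depends only on the sequence of list operations, it comes out the same on any machine and in any language with array-backed lists; for powers of two the estimated exponent lands just above 2.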
- It's accidentally quadratic
What's the runtime of the `survivor` function, expressed as a function of $n$? Well, looking at the time complexity page on the Python Wiki, we can see that the `del` operation on a list takes $O(n)$, where $n$ is the length of the list, because every item after the deleted one has to move down one place. This is executed once for each person who is killed, resulting in an overall runtime of $O(n^2)$. It's possible to check this experimentally:
```python
>>> from timeit import timeit
>>> t = 1
>>> for i in range(8, 17):
...     t, u = timeit(lambda: survivor(2**i), number=1), t
...     print('{:6d} {:.6f} {:.2f}'.format(2**i, t, t / u))
...
   256 0.000138 0.00
   512 0.000318 2.31
  1024 0.000560 1.76
  2048 0.001363 2.43
  4096 0.006631 4.87
  8192 0.030330 4.57
 16384 0.132857 4.38
 32768 0.534205 4.02
 65536 2.134860 4.00
```

You can see that for each doubling of $n$, the runtime increases by roughly four times, which is what we expect for an $O(n^2)$ algorithm, since $(2n)^2 = 4n^2$.

- Making it linear
How can we speed this up? Well, we could avoid the expensive `del` operation by making a list of the survivors, instead of deleting the deceased. Consider a single trip around the circular firing squad. If there is an even number of people remaining, then the people with indexes 0, 2, 4, and so on, survive. But if there is an odd number of people remaining, then the last survivor shoots the person with index 0, so the survivors are the people with indexes 2, 4, and so on. Putting this into code form:

```python
def survivor2(n):
    """Return the survivor of a circular firing squad of n people."""
    persons = list(range(1, n + 1))
    while len(persons) > 1:
        if len(persons) % 2 == 0:
            persons = persons[::2]
        else:
            persons = persons[2::2]
    return persons[0]
```

(You could shorten this, if you liked, using an expression like `persons[(len(persons) % 2) * 2::2]`, but I don't think the small reduction in code length is worth the loss of clarity.)

Let's check that this is correct, by comparing the results with the original implementation:
```python
>>> all(survivor(i) == survivor2(i) for i in range(1, 1000))
True
```

Notice how useful it is for testing that we have the code organized into functions.
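Since the functions are independent, it is also easy to cross-check them against a third simulation written in a different style. Here is a sketch using `collections.deque`; the name `survivor_deque` is hypothetical and not from the question. The invariant is that the person at the left end of the deque is always the next to shoot, so each step rotates the shooter to the back and pops the victim off the front. Both `rotate(-1)` and `popleft()` take constant time, so this is another $O(n)$ approach.

```python
from collections import deque


def survivor_deque(n):
    """Hypothetical cross-check: simulate the firing squad with a deque.

    Invariant: the person at the left end is the next one to shoot.
    """
    circle = deque(range(1, n + 1))
    while len(circle) > 1:
        circle.rotate(-1)  # the shooter steps to the back of the queue...
        circle.popleft()   # ...and the person now at the front is shot
    return circle[0]


if __name__ == "__main__":
    print(survivor_deque(100))
```

With `survivor2` from above in scope, `all(survivor2(i) == survivor_deque(i) for i in range(1, 1000))` should also be `True`.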
Now, what's the runtime of `survivor2`? Again, looking at the time complexity page on the Python Wiki, we can see that the "get slice" operation takes time $O(k)$, where $k$ is the number of items in the slice. In this case each slice is about half the length of `persons`, so the runtime is $$ O\left({n \over 2}\right) + O\left({n \over 4}\right) + O\left({n \over 8}\right) + \dots $$ which is $O(n)$, because the geometric series ${n \over 2} + {n \over 4} + {n \over 8} + \dots$ never exceeds $n$ (and building the initial list is $O(n)$ too). Again, we can check that experimentally:

```python
>>> t = 1
>>> for i in range(8, 25):
...     t, u = timeit(lambda: survivor2(2**i), number=1), t
...     print('{:8d} {:8.6f} {:.2f}'.format(2**i, t, t / u))
...
     256 0.000034 0.00
     512 0.000048 1.40
    1024 0.000087 1.79
    2048 0.000142 1.63
    4096 0.000300 2.12
    8192 0.000573 1.91
   16384 0.001227 2.14
   32768 0.002628 2.14
   65536 0.006003 2.28
  131072 0.017954 2.99
  262144 0.043873 2.44
  524288 0.094669 2.16
 1048576 0.180889 1.91
 2097152 0.364302 2.01
 4194304 0.743028 2.04
 8388608 1.497255 2.02
16777216 3.094121 2.07
```

Now, for each doubling of $n$, the runtime increases by roughly two times, which is what we expect for an $O(n)$ algorithm.

- Making it polylogarithmic
Can we do even better than this? Let's look at who the survivors actually are after each trip around the firing squad:
```python
from pprint import pprint

def survivors(n):
    """Print survivors after each round of circular firing squad with n people."""
    persons = list(range(1, n + 1))
    while len(persons) > 1:
        if len(persons) % 2 == 0:
            persons = persons[::2]
        else:
            persons = persons[2::2]
        pprint(persons, compact=True)

>>> survivors(100)
```
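This copy of the answer is cut off at this point, so what follows is only a sketch of where the observation can lead, not necessarily the author's code. Each slice `persons[::2]` or `persons[2::2]` of an arithmetic progression is again an arithmetic progression with its step doubled, so the whole list can be replaced by three numbers: the first person, the step between consecutive survivors, and the head count. The hypothetical `survivor3` below updates those numbers once per round, so it needs $O(\log n)$ rounds; each round does arithmetic on integers of $O(\log n)$ bits, which gives a polylogarithmic bound overall.

```python
def survivor3(n):
    """Sketch: return the survivor of a circular firing squad of n people.

    Represents the remaining persons as the arithmetic progression
    start, start + step, start + 2*step, ... with n terms.
    """
    start, step = 1, 1
    while n > 1:
        if n % 2 == 1:
            # odd round: like persons[2::2], the first term is dropped too
            start += 2 * step
        # either way every other person survives, so the step doubles
        step *= 2
        # n // 2 terms remain: n/2 for even n, (n - 1)/2 for odd n
        n //= 2
    return start


if __name__ == "__main__":
    print(survivor3(100))
```

Comparing it against `survivor2` over a range of inputs, in the same way as before, is a cheap way to check the reasoning.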
Context
StackExchange Code Review Q#86023, answer score: 75