patternpythonMinor
Python Find All Adjacent Subsets of Set of Coins Which Have a Tails Minority
Problem
Given a sequence of heads and tails I want to find how many significant subsequences are in this sequence where the number of heads is not less than the number of tails. I want to achieve this in \$O(N \log N)\$ time.
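To pin down what is being counted, here is a minimal brute-force sketch. It assumes the subsequences are contiguous runs and that a run qualifies when heads are at least as numerous as tails. `count_brute_force` is a hypothetical name, and the quadratic loop is meant only as a reference for checking faster versions.

```python
def count_brute_force(coins):
    """Count contiguous runs in which heads are not outnumbered by tails."""
    total = 0
    for start in range(len(coins)):
        balance = 0  # heads minus tails in coins[start:stop + 1]
        for stop in range(start, len(coins)):
            balance += 1 if coins[stop] == 'H' else -1
            if balance >= 0:
                total += 1
    return total


print(count_brute_force(['H', 'T', 'H', 'T', 'T', 'H']))  # 11
```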
Example input:

    [ 'H', 'T', 'H', 'T', 'T', 'H' ]

Example output:

    11

Explanation:

    {H} {H} {H}
    {H, T} {T, H} {H, T} {T, H}
    {H, T, H}
    {H, T, H, T}
    {H, T, T, H}
    {H, T, H, T, T, H}

I believe my current algorithm is \$O(N^2)\$. I solve the problem recursively, iterating with the list of coins sliced on either end.
Here is my current algorithm. Am I correct that this is \$O(N^2)\$ and is not \$O(N \log N)\$?

    def count_sequences( data ):
        print go(range(0,len(data)),data)
    seen = set()
    def go(rang,data):
        if tuple(rang) in seen: return 0
        seen.add(tuple(rang))
        h = 0
        summ = 0
        if len(rang)==0: return 0
        for i in rang:
            if data[i] == 'H': h += 1
            summ += go(rang[1:],data)
            summ += go(rang[:-1],data)
        if len(rang) == 1:
            if h ==1: return 1
            else: return 0
        if h > (len(rang)-1)/2 :
            return 1 + summ
        else: return summ

Solution
`go` has these costs:

                                                O(x)
    def go(rang,data):
        if tuple(rang) in seen: return 0        n
        seen.add(tuple(rang))                   n
        h = 0                                   1
        summ = 0                                1
        if len(rang)==0: return 0               1
        for i in rang:                          n (
            if data[i] == 'H': h += 1             1
            summ += go(rang[1:],data)             n + T(n-1)
            summ += go(rang[:-1],data)            n + T(n-1)
        if len(rang) == 1:                      ) 1
            if h ==1: return 1                  1
            else: return 0                      1
        if h > (len(rang)-1)/2 :                1
            return 1 + summ                     1
        else: return summ                       1

You call this once for each sub-tuple of `data`, and there are at least this many of those:
$$
1 + 2 + 3 + ... + (n-1) + n
$$
Because there are \$1\$ for length \$n\$, \$2\$ for length \$n-1\$, etc. There's one more for length 0, but that's uninteresting.
Each should be multiplied by its cost; if you consider only those that succeed (not those which are caught by `seen`), they cost \$\mathcal{O}(n^2)\$ each due to

    for i in rang:                          n (
        if data[i] == 'H': h += 1             1
        summ += go(rang[1:],data)             n + T(n-1)
        summ += go(rang[:-1],data)            n + T(n-1)
                                            )

The calls are not part of this; only the looping and the slicing. This gives a cost of
$$
1(n)^2 + 2(n-1)^2 + 3(n-2)^2 + ... + (n-1)(2)^2 + n(1)^2
$$
or
$$
\sum_{k=1}^n k (n-k+1)^2 = \frac{1}{12} n (n+1) (n^2+3 n+2)
$$
which is \$\mathcal{O}(n^4)\$. So caching does improve this a lot from \$\mathcal{O}(2^n n!)\$... but not nearly enough.
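As a sanity check on the closed form, the sum can be evaluated directly for small \$n\$. This is only a sketch; `total_cost` and `closed_form` are hypothetical helper names.

```python
def total_cost(n):
    """Sum of k * (n - k + 1)**2 for k = 1..n (k windows of length n - k + 1)."""
    return sum(k * (n - k + 1) ** 2 for k in range(1, n + 1))


def closed_form(n):
    """n (n + 1) (n^2 + 3n + 2) / 12, evaluated with integer arithmetic."""
    return n * (n + 1) * (n * n + 3 * n + 2) // 12


for n in range(1, 200):
    assert total_cost(n) == closed_form(n)

# For large n, doubling n should multiply the cost by close to 2**4 = 16.
print(closed_form(2000) / closed_form(1000))
```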
Note that we could prove that it's valid to ignore failed calls by moving the `if tuple(rang) in seen` check into the caller and seeing that the cost is the same.

Using range objects like Python 3's `range` would give costs of \$\mathcal{O}(n)\$ each, not \$\mathcal{O}(n^2)\$, since slicing would be \$\mathcal{O}(1)\$. However, the check `if tuple(rang) in seen` prevents this from mattering. You can see this by moving the check into the caller. If you changed this to `if rang in seen` using an \$\mathcal{O}(1)\$ hash, this would work and reduce the cost to \$\mathcal{O}(n^3)\$ overall. Math elided for brevity.

We can test this experimentally like:

    import time

    def times(iters):
        for n in range(iters):
            heads = ['H'] * (2**n)
            seen.clear()
            s = time.time()
            count_sequences(heads)
            print((time.time() - s) * 10**5)

    times(8)

Giving

    0.691413879395
    1.00135803223
    3.48091125488
    14.9011611938
    90.2891159058
    759.696960449
    7593.48869324
    87769.818306

We expect for \$\mathcal{O}(n^4)\$ that a doubling of size will give a factor-of-16 change; we instead see closer to a factor-of-12 change, or \$\mathcal{O}(n^{3.6})\$. This is probably partially because slices are fast relative to interpreter overhead. As \$n\$ increases, this will likely change.
When testing with Python 3 ranges, one gets

    1.0251998901367188
    1.1444091796875
    3.5762786865234375
    19.14501190185547
    123.23856353759766
    878.2625198364258
    6614.780426025391
    51359.89189147949
    414633.48865509033

which is almost perfectly \$\mathcal{O}(n^3)\$.
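The exponents quoted for the two runs can be estimated from the measurements themselves: since \$n\$ doubles between lines, the base-2 logarithm of each successive ratio approximates the exponent. A small sketch, with the timings copied from the listings above and `growth_exponents` as a hypothetical helper name:

```python
import math

# Timings copied from the two runs above (units of 1e-5 seconds).
list_run = [0.691413879395, 1.00135803223, 3.48091125488, 14.9011611938,
            90.2891159058, 759.696960449, 7593.48869324, 87769.818306]
range_run = [1.0251998901367188, 1.1444091796875, 3.5762786865234375,
             19.14501190185547, 123.23856353759766, 878.2625198364258,
             6614.780426025391, 51359.89189147949, 414633.48865509033]


def growth_exponents(timings):
    """log2 of each successive ratio; n doubles between measurements."""
    return [math.log2(b / a) for a, b in zip(timings, timings[1:])]


for label, timings in (("lists", list_run), ("ranges", range_run)):
    print(label, [round(e, 2) for e in growth_exponents(timings)])
```

The early entries are dominated by fixed overhead, so only the last few ratios in each run say much about the asymptotic behaviour.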
Your code should use spacing. Note how, when you posted this question, you used paragraphs with line breaks. The same thing helps with code. You should also stick to PEP 8.
Just changing that and using brackets on `print` gives

    def count_sequences(data):
        print(go(range(0, len(data)), data))

    seen = set()

    def go(range_, data):
        if tuple(range_) in seen:
            return 0
        seen.add(tuple(range_))

        h = 0
        summ = 0

        if not range_:
            return 0

        for i in range_:
            if data[i] == 'H':
                h += 1
            summ += go(range_[1:], data)
            summ += go(range_[:-1], data)

        if len(range_) == 1:
            if h == 1:
                return 1
            else:
                return 0

        if h > (len(range_) - 1) / 2:
            return 1 + summ
        else:
            return summ

I would use a better name than `rang`/`range_`; probably `indices`. However, it is best to use two integers: `start`/`stop`. `count_sequences` should return rather than print. Further, a little cleanup gives

    def count_sequences(data):
        return go(data, 0, len(data))

    seen = set()

    def go(data, start, stop):
        if (start, stop) in seen:
            return 0
        seen.add((start, stop))

        h = 0
        summ = 0

        if start == stop:
            return 0

        for i in xrange(start, stop):
            if data[i] == 'H':
                h += 1
            summ += go(data, start+1, stop)
            summ += go(data, start, stop-1)

        if stop - start == 1:
            return h

        if h > (stop - start - 1) / 2:
            summ += 1

        return summ

Then consider that

    summ += go(data, start+1, stop)
    summ += go(data, start, stop-1)

is a loop invariant, so move it out.
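A sketch of what hoisting those calls might look like, written for Python 3. Passing `seen` as an argument instead of using a module-level set, using `range` in place of `xrange`, and testing `2 * heads >= stop - start` instead of the integer-division comparison are my own choices here, not something the code above does.

```python
def count_sequences(data):
    """Count windows of data in which 'H' is not outnumbered by 'T'."""
    return go(data, 0, len(data), set())


def go(data, start, stop, seen):
    # Empty windows and windows already visited contribute nothing.
    if start == stop or (start, stop) in seen:
        return 0
    seen.add((start, stop))

    # Counting heads is the only O(n) work left in each call.
    heads = sum(1 for i in range(start, stop) if data[i] == 'H')

    # Hoisted out of the loop: each smaller window is counted on its first visit.
    summ = go(data, start + 1, stop, seen) + go(data, start, stop - 1, seen)

    if 2 * heads >= stop - start:  # heads >= tails in this window
        summ += 1
    return summ


print(count_sequences(['H', 'T', 'H', 'T', 'T', 'H']))  # 11
```

There are still about \$n^2/2\$ windows and each one rescans its coins, so this remains roughly cubic. The recursion depth also grows with `len(data)`, so long inputs would run into Python's default recursion limit.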
And
```
for i in xrange(start, sto
```
Context
StackExchange Code Review Q#77224, answer score: 6