
Python Find All Adjacent Subsets of Set of Coins Which Have a Tails Minority


Problem

Given a sequence of heads and tails, I want to count the contiguous subsequences in which the number of heads is not less than the number of tails. I want to achieve this in \$O(N \log N)\$ time.

Example input:
[ 'H', 'T', 'H', 'T', 'T', 'H' ]

Example output:

11


Explanation:

{H} {H} {H}
{H, T} {T, H} {H, T} {T, H}
{H, T, H} 
{H, T, H, T} 
{H, T, T, H} 
{H, T, H, T, T, H}
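The count in the example can be cross-checked with a direct enumeration. This is only a quadratic reference sketch for validating other implementations, not the \$O(N \log N)\$ solution being asked for; the name `count_brute_force` is made up for illustration.

```python
def count_brute_force(coins):
    """Count contiguous runs where heads are not outnumbered by tails."""
    count = 0
    for start in range(len(coins)):
        balance = 0  # heads minus tails within coins[start:stop + 1]
        for stop in range(start, len(coins)):
            balance += 1 if coins[stop] == 'H' else -1
            if balance >= 0:
                count += 1
    return count


print(count_brute_force(['H', 'T', 'H', 'T', 'T', 'H']))  # prints 11
```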


I believe my current algorithm is \$O(N^2)\$. I solve the problem recursively, iterating with the list of coins sliced on either end.

Here is my current algorithm. Am I correct that this is \$O(N^2)\$ and is not \$O(N \log N)\$?

def count_sequences( data ):
    print go(range(0,len(data)),data)
seen = set()
def go(rang,data):
    if tuple(rang) in seen: return 0
    seen.add(tuple(rang))
    h = 0
    summ = 0
    if len(rang)==0: return 0
    for i in rang:
        if data[i] == 'H': h += 1
        summ += go(rang[1:],data)
        summ += go(rang[:-1],data)
    if len(rang) == 1: 
        if h ==1: return 1
        else: return 0
    if h > (len(rang)-1)/2 : 
        return 1 + summ
    else: return summ
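For reference, the \$O(N \log N)\$ target is reachable with a standard technique that does not come from this post: map H to +1 and T to -1, build prefix sums, and count the pairs \$i < j\$ whose prefix sums satisfy \$P_i \le P_j\$ using a merge sort. The sketch below assumes Python 3, and the names `count_heads_not_minority` and `sort_and_count` are hypothetical.

```python
def count_heads_not_minority(coins):
    # prefix[k] is (heads - tails) over the first k coins.
    prefix = [0]
    for coin in coins:
        prefix.append(prefix[-1] + (1 if coin == 'H' else -1))

    def sort_and_count(values):
        """Return (sorted values, number of pairs i < j with values[i] <= values[j])."""
        if len(values) <= 1:
            return values, 0
        mid = len(values) // 2
        left, left_count = sort_and_count(values[:mid])
        right, right_count = sort_and_count(values[mid:])
        count = left_count + right_count

        merged = []
        i = 0
        for r in right:
            # Move over every left element that is <= r; each one pairs with r.
            while i < len(left) and left[i] <= r:
                merged.append(left[i])
                i += 1
            count += i
            merged.append(r)
        merged.extend(left[i:])
        return merged, count

    # coins[i:j] qualifies exactly when prefix[j] - prefix[i] >= 0.
    return sort_and_count(prefix)[1]


print(count_heads_not_minority(['H', 'T', 'H', 'T', 'T', 'H']))  # prints 11
```

Each merge level does linear work and there are \$O(\log N)\$ levels, which is where the bound comes from.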

Solution

go has these costs:

                                       O(x)
def go(rang,data):
    if tuple(rang) in seen: return 0   n
    seen.add(tuple(rang))              n
    h = 0                              1
    summ = 0                           1
    if len(rang)==0: return 0          1
    for i in rang:                     n (
        if data[i] == 'H': h += 1         1
        summ += go(rang[1:],data)         n + T(n-1)
        summ += go(rang[:-1],data)        n + T(n-1)
    if len(rang) == 1:                 ) 1
        if h ==1: return 1             1
        else: return 0                 1
    if h > (len(rang)-1)/2 :           1
        return 1 + summ                1
    else: return summ                  1


You call this once for each sub-tuple of data, and there are at least this many of those:

$$
1 + 2 + 3 + ... + (n-1) + n
$$

Because there are \$1\$ for length \$n\$, \$2\$ for length \$n-1\$, etc. There's one more for length 0, but that's uninteresting.
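The counting argument can be checked by enumerating the contiguous index ranges directly. This is a small sketch; `sub_ranges` and `counts_by_length` are illustrative names, not part of the code under review.

```python
def sub_ranges(n):
    """Yield every non-empty (start, stop) slice of a length-n sequence."""
    for start in range(n):
        for stop in range(start + 1, n + 1):
            yield start, stop


def counts_by_length(n):
    """Map each slice length to the number of slices of that length."""
    counts = {}
    for start, stop in sub_ranges(n):
        counts[stop - start] = counts.get(stop - start, 0) + 1
    return counts


n = 6
counts = counts_by_length(n)
print(counts[n], counts[n - 1], counts[1])  # 1 slice of length n, 2 of length n-1, n of length 1
print(sum(counts.values()) == n * (n + 1) // 2)
```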

Each should be multiplied by its cost; if you consider only those that succeed (not those which are caught by seen), they cost \$\mathcal{O}(n^2)\$ each due to

    for i in rang:                     n (
        if data[i] == 'H': h += 1         1
        summ += go(rang[1:],data)         n + T(n-1)
        summ += go(rang[:-1],data)        n + T(n-1)
                                       )


The calls are not part of this; only the looping and the slicing. This gives a cost of

$$
1(n)^2 + 2(n-1)^2 + 3(n-2)^2 + ... + (n-1)(2)^2 + n(1)^2
$$

or

$$
\sum_{k=1}^n k (n-k+1)^2 = \frac{1}{12} n (n+1) (n^2+3 n+2)
$$

which is \$\mathcal{O}(n^4)\$. So caching does improve this a lot from \$\mathcal{O}(2^n n!)\$... but not nearly enough.
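The closed form can be sanity-checked numerically. The sketch below compares the explicit sum with the formula and shows the ratio for a doubled input approaching 16, as expected for quartic growth; the function names are illustrative.

```python
def slicing_cost(n):
    """Explicit sum: k sub-tuples of length n - k + 1, each costing length squared."""
    return sum(k * (n - k + 1) ** 2 for k in range(1, n + 1))


def closed_form(n):
    """n (n + 1) (n^2 + 3n + 2) / 12, which is always an integer."""
    return n * (n + 1) * (n * n + 3 * n + 2) // 12


print(all(slicing_cost(n) == closed_form(n) for n in range(1, 100)))
print(closed_form(2000) / closed_form(1000))  # close to 16
```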

Note that we could prove that it's valid to ignore failed calls by moving the if tuple(rang) in seen check into the caller and seeing that the cost is the same.

Using range objects like Python 3's range would give costs of \$\mathcal{O}(n)\$ each, not \$\mathcal{O}(n^2)\$, since slicing would be \$\mathcal{O}(1)\$. However, the check if tuple(rang) in seen prevents this from mattering. You can see this by moving the check into the caller. If you changed this to if rang in seen using an \$\mathcal{O}(1)\$ hash, this would work and reduce the cost to \$\mathcal{O}(n^3)\$ overall. Math elided for brevity.
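A short illustration of the Python 3 behaviour this relies on: slicing a `range` yields another `range` without copying elements, and `range` objects are hashable, so they can go into `seen` without a `tuple(...)` conversion. The variable names here are illustrative only.

```python
full = range(0, 8)

tail = full[1:]    # still a range object, no elements copied
head = full[:-1]

print(tail, head)  # range(1, 8) range(0, 7)

seen = set()
seen.add(tail)     # ranges are hashable, so no tuple(...) conversion is needed
print(range(1, 8) in seen)
```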

We can test this experimentally like:

import time

def times(iters):
    for n in range(iters):
        heads = ['H'] * (2**n)
        seen.clear()

        s = time.time()
        count_sequences(heads)
        print((time.time() - s) * 10**5)

times(8)


Giving

0.691413879395
1.00135803223
3.48091125488
14.9011611938
90.2891159058
759.696960449
7593.48869324
87769.818306


We expect for \$\mathcal{O}(n^4)\$ that a doubling of size will give a factor-of-16 change; we instead see closer to a factor-of-12 change, or \$\mathcal{O}(n^{3.6})\$. This is probably partially because slices are fast relative to interpreter overhead. As \$n\$ increases, this will likely change.

When testing with Python 3 ranges, one gets

1.0251998901367188
1.1444091796875
3.5762786865234375
19.14501190185547
123.23856353759766
878.2625198364258
6614.780426025391
51359.89189147949
414633.48865509033


which is almost perfectly \$\mathcal{O}(n^3)\$.
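The exponents quoted for both runs can be estimated from the timings listed above: since the input doubles at each step, the base-2 logarithm of the ratio of consecutive timings approximates the exponent. This sketch assumes Python 3 (for `math.log2`); `growth_exponents` is an illustrative name.

```python
import math

# Timings copied from the two runs above.
list_timings = [0.691413879395, 1.00135803223, 3.48091125488, 14.9011611938,
                90.2891159058, 759.696960449, 7593.48869324, 87769.818306]
range_timings = [1.0251998901367188, 1.1444091796875, 3.5762786865234375,
                 19.14501190185547, 123.23856353759766, 878.2625198364258,
                 6614.780426025391, 51359.89189147949, 414633.48865509033]


def growth_exponents(timings):
    """log2 of each consecutive ratio; the input size doubles between entries."""
    return [math.log2(later / earlier)
            for earlier, later in zip(timings, timings[1:])]


print(growth_exponents(list_timings)[-1])   # roughly 3.5
print(growth_exponents(range_timings)[-1])  # roughly 3.0
```

The early entries are dominated by fixed overhead, so only the last few ratios say much about asymptotic growth.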

Your code should use spacing. Note how when you posted this question you used paragraphs with line breaks. The same thing helps with code. You should also stick to PEP 8.

Just changing that and using brackets on print gives

def count_sequences(data):
    print(go(range(0, len(data)), data))

seen = set()
def go(range_, data):
    if tuple(range_) in seen:
        return 0

    seen.add(tuple(range_))

    h = 0
    summ = 0
    if not range_:
        return 0

    for i in range_:
        if data[i] == 'H':
            h += 1

        summ += go(range_[1:], data)
        summ += go(range_[:-1], data)

    if len(range_) == 1: 
        if h == 1:
            return 1
        else:
            return 0

    if h > (len(range_) - 1) / 2: 
        return 1 + summ

    else:
        return summ


I would use a better name than rang/range_; probably indices. However, it is best to use two integers: start/stop. count_sequences should return rather than print. Further, a little cleanup gives

def count_sequences(data):
    return go(data, 0, len(data))

seen = set()
def go(data, start, stop):
    if (start, stop) in seen:
        return 0

    seen.add((start, stop))

    h = 0
    summ = 0
    if start == stop:
        return 0

    for i in xrange(start, stop):
        if data[i] == 'H':
            h += 1

        summ += go(data, start+1, stop)
        summ += go(data, start, stop-1)

    if stop - start == 1: 
        return h

    if h > (stop - start - 1) / 2: 
        summ += 1

    return summ


Then consider that

summ += go(data, start+1, stop)
summ += go(data, start, stop-1)


is a loop invariant, so move it out.
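A minimal sketch of what hoisting those two calls might look like, written for Python 3. Passing `seen` as a parameter instead of using the module-level set, and comparing `2 * heads` against the length instead of dividing, are choices made here for the sketch and are not taken from the answer.

```python
def count_sequences(data):
    return go(data, 0, len(data), set())


def go(data, start, stop, seen):
    if start == stop or (start, stop) in seen:
        return 0
    seen.add((start, stop))

    # The recursive calls no longer depend on the loop, so they run once.
    total = go(data, start + 1, stop, seen)
    total += go(data, start, stop - 1, seen)

    # Only the head count still needs a pass over data[start:stop].
    heads = sum(1 for i in range(start, stop) if data[i] == 'H')
    if 2 * heads >= stop - start:
        total += 1

    return total


print(count_sequences(['H', 'T', 'H', 'T', 'T', 'H']))  # prints 11
```

This still recounts heads for every sub-range, so it is an incremental cleanup rather than the final complexity target.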

And

for i in xrange(start, sto


Context

StackExchange Code Review Q#77224, answer score: 6
