Finding all k-subset partitions

Submitted by: @import:stackexchange-codereview

Problem

The following code generates all k-subsets of a given array. A k-subset of a set X is a partition of all the elements of X into k non-empty subsets.

Thus, for {1,2,3,4}, one 3-subset is {{1,2},{3},{4}}.

I'm looking for improvements to the algorithm or code. Specifically, is there a better way than using copy.deepcopy? Is there some itertools magic that does this already?

import copy
arr = [1,2,3,4]

def t(k,accum,index):
    print accum,k
    if index == len(arr):
        if(k==0):
            return accum;
        else:
            return [];

    element = arr[index];
    result = []

    for set_i in range(len(accum)):
        if k>0:
            clone_new = copy.deepcopy(accum);
            clone_new[set_i].append([element]);
            result.extend( t(k-1,clone_new,index+1) );

        for elem_i in range(len(accum[set_i])):
            clone_new = copy.deepcopy(accum);
            clone_new[set_i][elem_i].append(element)
            result.extend( t(k,clone_new,index+1) );

    return result

print t(3,[[]],0);

Solution

A very efficient algorithm (Algorithm U) for finding all set partitions with a given number of blocks is described by Knuth in The Art of Computer Programming, Volume 4, Fascicle 3B. Your algorithm, although simple to express, is essentially a brute-force tree search, which is not efficient.

Since Knuth's algorithm is not very concise, its implementation is lengthy as well. Note that the implementation below moves an item among the blocks one at a time and need not maintain an accumulator containing all partial results. For this reason, no copying is required.

def algorithm_u(ns, m):
    def visit(n, a):
        ps = [[] for i in xrange(m)]
        for j in xrange(n):
            ps[a[j + 1]].append(ns[j])
        return ps

    def f(mu, nu, sigma, n, a):
        if mu == 2:
            yield visit(n, a)
        else:
            for v in f(mu - 1, nu - 1, (mu + sigma) % 2, n, a):
                yield v
        if nu == mu + 1:
            a[mu] = mu - 1
            yield visit(n, a)
            while a[nu] > 0:
                a[nu] = a[nu] - 1
                yield visit(n, a)
        elif nu > mu + 1:
            if (mu + sigma) % 2 == 1:
                a[nu - 1] = mu - 1
            else:
                a[mu] = mu - 1
            if (a[nu] + sigma) % 2 == 1:
                for v in b(mu, nu - 1, 0, n, a):
                    yield v
            else:
                for v in f(mu, nu - 1, 0, n, a):
                    yield v
            while a[nu] > 0:
                a[nu] = a[nu] - 1
                if (a[nu] + sigma) % 2 == 1:
                    for v in b(mu, nu - 1, 0, n, a):
                        yield v
                else:
                    for v in f(mu, nu - 1, 0, n, a):
                        yield v

    def b(mu, nu, sigma, n, a):
        if nu == mu + 1:
            while a[nu] < mu - 1:
                yield visit(n, a)
                a[nu] = a[nu] + 1
            yield visit(n, a)
            a[mu] = 0
        elif nu > mu + 1:
            if (a[nu] + sigma) % 2 == 1:
                for v in f(mu, nu - 1, 0, n, a):
                    yield v
            else:
                for v in b(mu, nu - 1, 0, n, a):
                    yield v
            while a[nu] < mu - 1:
                a[nu] = a[nu] + 1
                if (a[nu] + sigma) % 2 == 1:
                    for v in f(mu, nu - 1, 0, n, a):
                        yield v
                else:
                    for v in b(mu, nu - 1, 0, n, a):
                        yield v
            if (mu + sigma) % 2 == 1:
                a[nu - 1] = 0
            else:
                a[mu] = 0
        if mu == 2:
            yield visit(n, a)
        else:
            for v in b(mu - 1, nu - 1, (mu + sigma) % 2, n, a):
                yield v

    n = len(ns)
    a = [0] * (n + 1)
    for j in xrange(1, m + 1):
        a[n - m + j] = j - 1
    return f(m, n, 0, n, a)
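
To make the role of the array a concrete: visit reads a[1..n] as the block index of each element (a[0] is unused), and the algorithm changes entries of a in place instead of copying partial results. Below is a minimal standalone sketch of that decoding, written for Python 3 rather than the Python 2 of the listing above. The name decode_blocks is hypothetical and stands in for the nested visit helper.

```python
def decode_blocks(items, a, m):
    """Group items into m blocks; a[j + 1] is the block index of items[j].

    decode_blocks is a hypothetical standalone version of the nested
    visit helper; a[0] is unused, mirroring the 1-based indexing above.
    """
    blocks = [[] for _ in range(m)]
    for j, item in enumerate(items):
        blocks[a[j + 1]].append(item)
    return blocks


# Initial state built by algorithm_u for n = 4, m = 3:
# a = [0] * 5, then a[n - m + j] = j - 1 for j = 1..3 gives [0, 0, 0, 1, 2]
print(decode_blocks([1, 2, 3, 4], [0, 0, 0, 1, 2], 3))  # [[1, 2], [3], [4]]
```

Each later partition differs from the previous one only by a small change to a, which is why no deep copy is needed.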


Examples:

def pretty_print(parts):
    print '; '.join('|'.join(''.join(str(e) for e in loe) for loe in part) for part in parts)

>>> pretty_print(algorithm_u([1, 2, 3, 4], 3))
12|3|4; 1|23|4; 13|2|4; 1|2|34; 1|24|3; 14|2|3

>>> pretty_print(algorithm_u([1, 2, 3, 4, 5], 3))
123|4|5; 12|34|5; 1|234|5; 13|24|5; 134|2|5; 14|23|5; 124|3|5; 12|3|45; 1|23|45; 13|2|45; 1|2|345; 1|24|35; 14|2|35; 14|25|3; 1|245|3; 1|25|34; 13|25|4; 1|235|4; 12|35|4; 125|3|4; 15|23|4; 135|2|4; 15|2|34; 15|24|3; 145|2|3
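
As a sanity check on the two listings above: the number of partitions of an n-element set into k non-empty blocks is the Stirling number of the second kind, S(n, k), which satisfies S(n, k) = k * S(n - 1, k) + S(n - 1, k - 1). The outputs should therefore contain S(4, 3) = 6 and S(5, 3) = 25 partitions respectively. A minimal Python 3 sketch of that count follows; stirling2 is an illustrative name and not part of the original answer.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def stirling2(n, k):
    """Number of ways to split n distinct items into k non-empty blocks."""
    if n == k:
        return 1  # includes S(0, 0) = 1
    if k == 0 or k > n:
        return 0
    # The last item either joins one of the k existing blocks
    # or forms a new block on its own.
    return k * stirling2(n - 1, k) + stirling2(n - 1, k - 1)


print(stirling2(4, 3))  # 6
print(stirling2(5, 3))  # 25
```

Comparing the length of the generated sequence against stirling2(n, k) is a cheap way to catch a missing or duplicated partition.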


Timing results:

$ python -m timeit "import test" "test.t(3, [[]], 0, [1, 2, 3, 4])"
100 loops, best of 3: 2.09 msec per loop

$ python -m timeit "import test" "test.t(3, [[]], 0, [1, 2, 3, 4, 5])"
100 loops, best of 3: 7.88 msec per loop

$ python -m timeit "import test" "test.t(3, [[]], 0, [1, 2, 3, 4, 5, 6])"
10 loops, best of 3: 23.6 msec per loop

$ python -m timeit "import test" "test.algorithm_u([1, 2, 3, 4], 3)"
10000 loops, best of 3: 26.1 usec per loop

$ python -m timeit "import test" "test.algorithm_u([1, 2, 3, 4, 5, 6, 7, 8], 3)"
10000 loops, best of 3: 28.1 usec per loop

$ python -m timeit "import test" "test.algorithm_u([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], 3)"
10000 loops, best of 3: 29.4 usec per loop


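Notice that t runs much slower than algorithm_u for the same input. Furthermore, the running time of t roughly triples with each extra element (2.09, 7.88, 23.6 msec), whereas the algorithm_u timings barely change when the input size is doubled and quadrupled. Bear in mind that algorithm_u returns a generator and these commands never iterate over it, so its figures measure setup only; the cost of actually producing the partitions grows with their number.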
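
Since algorithm_u is a generator function, a fair timing has to drain its result, for example with list(...). The following is a minimal Python 3 sketch of the difference. It uses a hypothetical stand-in generator, squares, rather than the code above, so the numbers it prints say nothing about algorithm_u itself.

```python
import timeit


def squares(limit):
    """Hypothetical stand-in for any generator function."""
    for i in range(limit):
        yield i * i


# Calling a generator function only builds the generator object.
create_only = timeit.timeit(lambda: squares(10000), number=200)

# list(...) drains the generator, so the loop body actually runs.
consume_all = timeit.timeit(lambda: list(squares(10000)), number=200)

print("create only: %.6f s" % create_only)
print("consume all: %.6f s" % consume_all)
```

Assuming the same test module as in the timings above, the command-line equivalent would be to time "list(test.algorithm_u([1, 2, 3, 4], 3))" instead of the bare call.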

Code Snippets

def algorithm_u(ns, m):
    def visit(n, a):
        ps = [[] for i in xrange(m)]
        for j in xrange(n):
            ps[a[j + 1]].append(ns[j])
        return ps

    def f(mu, nu, sigma, n, a):
        if mu == 2:
            yield visit(n, a)
        else:
            for v in f(mu - 1, nu - 1, (mu + sigma) % 2, n, a):
                yield v
        if nu == mu + 1:
            a[mu] = mu - 1
            yield visit(n, a)
            while a[nu] > 0:
                a[nu] = a[nu] - 1
                yield visit(n, a)
        elif nu > mu + 1:
            if (mu + sigma) % 2 == 1:
                a[nu - 1] = mu - 1
            else:
                a[mu] = mu - 1
            if (a[nu] + sigma) % 2 == 1:
                for v in b(mu, nu - 1, 0, n, a):
                    yield v
            else:
                for v in f(mu, nu - 1, 0, n, a):
                    yield v
            while a[nu] > 0:
                a[nu] = a[nu] - 1
                if (a[nu] + sigma) % 2 == 1:
                    for v in b(mu, nu - 1, 0, n, a):
                        yield v
                else:
                    for v in f(mu, nu - 1, 0, n, a):
                        yield v

    def b(mu, nu, sigma, n, a):
        if nu == mu + 1:
            while a[nu] < mu - 1:
                yield visit(n, a)
                a[nu] = a[nu] + 1
            yield visit(n, a)
            a[mu] = 0
        elif nu > mu + 1:
            if (a[nu] + sigma) % 2 == 1:
                for v in f(mu, nu - 1, 0, n, a):
                    yield v
            else:
                for v in b(mu, nu - 1, 0, n, a):
                    yield v
            while a[nu] < mu - 1:
                a[nu] = a[nu] + 1
                if (a[nu] + sigma) % 2 == 1:
                    for v in f(mu, nu - 1, 0, n, a):
                        yield v
                else:
                    for v in b(mu, nu - 1, 0, n, a):
                        yield v
            if (mu + sigma) % 2 == 1:
                a[nu - 1] = 0
            else:
                a[mu] = 0
        if mu == 2:
            yield visit(n, a)
        else:
            for v in b(mu - 1, nu - 1, (mu + sigma) % 2, n, a):
                yield v

    n = len(ns)
    a = [0] * (n + 1)
    for j in xrange(1, m + 1):
        a[n - m + j] = j - 1
    return f(m, n, 0, n, a)

Context

StackExchange Code Review Q#1526, answer score: 27
