Generate all combinations of certain digits

Submitted by: @import:stackexchange-codereview

Problem

Related to another answer of mine on Project Euler 35, I found the need to calculate all combinations of certain digits, i.e. 1, 3, 7 and 9, below a given n. I looked around, and didn't find any really good premade solutions, but found some bits and pieces related to Cartesian products, and collated my findings into the following code:

import itertools

def all_digit_combinations(allowed_digits, maximum):
    """Yield all combinations of allowed_digits below maximum.

    For allowed_digits being a list of single digits, i.e. [1, 3, 7, 9],
    combine all variants of these digits until we pass the maximum value.
    """
    no_of_digits = 1

    while True:
        for digit_tuple in itertools.product(allowed_digits, repeat=no_of_digits):
            new_number = reduce(lambda rst, d: rst * 10 + d, digit_tuple)
            if new_number < maximum:
                yield new_number
            else:
                raise StopIteration

        no_of_digits += 1

if __name__ == '__main__':
    print ', '.join(map(str, all_digit_combinations([1, 3, 7, 9], 100)))
    print ', '.join(map(str, all_digit_combinations([3, 5], 1000)))


Which indeed prints the expected output of:

1, 3, 7, 9, 11, 13, 17, 19, 31, 33, 37, 39, 71, 73, 77, 79, 91, 93, 97, 99
3, 5, 33, 35, 53, 55, 333, 335, 353, 355, 533, 535, 553, 555
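The number of values on each output line can be checked against the size of the Cartesian product. This assumes the maximum is a power of ten, 10^L, and that the allowed digits are nonzero, so every k-digit tuple gives a distinct k-digit number. Under those assumptions a set D of allowed digits contributes |D|^k numbers of each length k:

```latex
N = \sum_{k=1}^{L} |D|^{k}, \qquad \text{maximum} = 10^{L}

D = \{1, 3, 7, 9\},\ L = 2: \quad N = 4 + 16 = 20
D = \{3, 5\},\ L = 3: \quad N = 2 + 4 + 8 = 14
```

Both counts match the 20 and 14 values printed above.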


I first tried using itertools.combinations_with_replacement() and other variations from itertools, but some of those versions failed to include numbers like 31 and 73. It could very well be that I simply didn't use the correct parameters, though.
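A minimal sketch of why the missing numbers are not a parameter problem. itertools.product yields every ordered tuple, while itertools.combinations_with_replacement only yields tuples whose elements keep their input order, so with sorted input (3, 1) and (7, 3) never appear. The helper to_number is a hypothetical name used only for this illustration; it applies the same reduce fold as the code above, e.g. (3, 5, 3) becomes (3 * 10 + 5) * 10 + 3 = 353.

```python
import itertools
from functools import reduce  # reduce is not a builtin on Python 3


def to_number(digits):
    """Fold a tuple of digits into an integer, e.g. (3, 1) -> 31 (hypothetical helper)."""
    return reduce(lambda acc, d: acc * 10 + d, digits)


allowed = [1, 3, 7, 9]

# Cartesian product: every ordered pair, 4 ** 2 == 16 numbers
from_product = [to_number(t) for t in itertools.product(allowed, repeat=2)]

# Combinations with replacement: pairs keep input order, so 31 and 73
# are never produced; only C(4 + 2 - 1, 2) == 10 numbers remain
from_cwr = [to_number(t) for t in itertools.combinations_with_replacement(allowed, 2)]

print(len(from_product), len(from_cwr))    # 16 10
print(31 in from_product, 31 in from_cwr)  # True False
```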

Can you review this code and suggest any optimisations or improvements?

Solution

I think the code you have is close to as good as it can get.

I would, however, use a more functional, iterator-based solution.

- The while True: loop with its manual no_of_digits += 1 can be replaced with a for loop over itertools.count(1). (Careful: that is an infinite iterator.)

- The name all_digit_combinations suggests to me that I get all of them, even 55555555. Instead, I would make all_digit_combinations an infinite generator and apply the maximum separately.

- else: raise StopIteration is rather ugly. Personally, I would use itertools.takewhile. (Here it is better than filter, because an infinite iterator is involved: filter would keep searching forever for more matches, whereas takewhile stops at the first value that fails the test.)

As an alternative to itertools, you could simply return from the generator. But because that ends the iteration implicitly, some may dislike it.
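For comparison, a sketch of that plain-return variant. It keeps the original structure and only swaps the ending. This matters on current Python too: since PEP 479 (the default behaviour from Python 3.7), a StopIteration raised inside a generator body is turned into a RuntimeError, whereas return ends the generator cleanly. Like the original, it assumes the allowed digits are nonzero and given in ascending order, so the numbers come out in increasing order and the first one at or above the maximum marks the end.

```python
import itertools
from functools import reduce  # reduce is not a builtin on Python 3


def all_digit_combinations(allowed_digits, maximum):
    """Yield numbers built from allowed_digits, shortest first, while below maximum."""
    for no_of_digits in itertools.count(1):
        for digit_tuple in itertools.product(allowed_digits, repeat=no_of_digits):
            new_number = reduce(lambda rst, d: rst * 10 + d, digit_tuple)
            if new_number >= maximum:
                # Returning ends the generator; no explicit StopIteration needed
                return
            yield new_number


print(', '.join(map(str, all_digit_combinations([3, 5], 1000))))
# 3, 5, 33, 35, 53, 55, 333, 335, 353, 355, 533, 535, 553, 555
```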

And so I would rewrite it as follows:

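import itertools
from functools import reduce  # needed on Python 3; reduce is a builtin on Python 2

def all_digit_combinations(allowed_digits):
    """Return an endless generator of numbers built from allowed_digits, shortest first."""
    return (
        reduce(lambda rst, d: rst * 10 + d, digit_tuple)
        for no_of_digits in itertools.count(1)
        for digit_tuple in itertools.product(allowed_digits, repeat=no_of_digits)
    )

def digit_combinations(allowed_digits, maximum):
    """Return the numbers from all_digit_combinations that are below maximum."""
    return itertools.takewhile(
        lambda x: x < maximum,
        all_digit_combinations(allowed_digits)
    )

if __name__ == '__main__':
    print(', '.join(map(str, digit_combinations([1, 3, 7, 9], 100))))
    print(', '.join(map(str, digit_combinations([3, 5], 1000))))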
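Because all_digit_combinations never ends on its own, it always has to be consumed by something that stops, such as takewhile with a maximum, or itertools.islice when a fixed number of values is wanted. A minimal sketch on Python 3, repeating the generator so the block runs standalone:

```python
import itertools
from functools import reduce  # reduce is not a builtin on Python 3


def all_digit_combinations(allowed_digits):
    # Infinite generator: all 1-digit numbers, then all 2-digit numbers, and so on
    return (
        reduce(lambda rst, d: rst * 10 + d, digit_tuple)
        for no_of_digits in itertools.count(1)
        for digit_tuple in itertools.product(allowed_digits, repeat=no_of_digits)
    )


# Take a fixed number of values instead of bounding them by a maximum
first_eight = list(itertools.islice(all_digit_combinations([3, 5]), 8))
print(first_eight)  # [3, 5, 33, 35, 53, 55, 333, 335]
```

Calling list() directly on all_digit_combinations([3, 5]) would never return, and the same applies to wrapping it in filter(lambda x: x < 100, ...): filter yields the matching values but then keeps searching for more.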


As for performance, your code is quite a bit faster than mine: in the runs below, the takewhile versions take roughly 20% longer.
I used the following code:

import itertools
import timeit
from functools import wraps

def all_digit_combinations_original(allowed_digits, maximum):
    no_of_digits = 1
    while True:
        for digit_tuple in itertools.product(allowed_digits, repeat=no_of_digits):
            new_number = reduce(lambda rst, d: rst * 10 + d, digit_tuple)
            if new_number < maximum:
                yield new_number
            else:
                raise StopIteration
        no_of_digits += 1

def all_digit_combinations_enhanced(allowed_digits, maximum):
    for no_of_digits in itertools.count(1):
        for digit_tuple in itertools.product(allowed_digits, repeat=no_of_digits):
            new_number = reduce(lambda rst, d: rst * 10 + d, digit_tuple)
            if new_number < maximum:
                yield new_number
            else:
                raise StopIteration

def all_digit_combinations_answer(allowed_digits):
    return (
        reduce(lambda rst, d: rst * 10 + d, digit_tuple)
        for no_of_digits in itertools.count(1)
        for digit_tuple in itertools.product(allowed_digits, repeat=no_of_digits)
    )

def digit_combinations_answer(allowed_digits, maximum):
    return itertools.takewhile(
        lambda x: x < maximum,
        all_digit_combinations_answer(allowed_digits)
    )

def digit_combinations_enhanced(allowed_digits, maximum):
    return itertools.takewhile(
        lambda x: x < maximum,
        (
            reduce(lambda rst, d: rst * 10 + d, digit_tuple)
            for no_of_digits in itertools.count(1)
            for digit_tuple in itertools.product(allowed_digits, repeat=no_of_digits)
        )
    )

def timer_wrapper(*args, **kwargs):
    def wrapper(fn):
        @wraps(fn)
        def run():
            list(fn(*args, **kwargs))
        return run
    return wrapper

def timeit_(fn):
    print fn.__name__, timeit.timeit(fn)

if __name__ == '__main__':
    timer = timer_wrapper([1, 3, 7, 9], 100)
    timeit_(timer(all_digit_combinations_original))
    timeit_(timer(digit_combinations_answer))
    timeit_(timer(all_digit_combinations_enhanced))
    timeit_(timer(digit_combinations_enhanced))

    timer = timer_wrapper([3, 5], 1000)
    timeit_(timer(all_digit_combinations_original))
    timeit_(timer(digit_combinations_answer))
    timeit_(timer(all_digit_combinations_enhanced))
    timeit_(timer(digit_combinations_enhanced))


I obtained the following benchmarks (total seconds for timeit's default of 1,000,000 calls per function, run under Python 2):

[1, 3, 7, 9] below 100:
all_digit_combinations_original 14.666793108
digit_combinations_answer 18.1881940365
all_digit_combinations_enhanced 15.228415966
digit_combinations_enhanced 17.9368011951

[3, 5] below 1000:
all_digit_combinations_original 14.4108960629
digit_combinations_answer 17.2161641121
all_digit_combinations_enhanced 14.8418960571
digit_combinations_enhanced 17.1169350147

Context

StackExchange Code Review Q#110957, answer score: 8
