
Determining whether a loop iterated over at least one element in a generator function

Submitted by: @import:stackexchange-codereview

Problem

I have a case in which I need to determine whether a for loop yielded at least one element:

@property
def source_customers(self):
    """Yields source customers"""
    for config in RealEstatesCustomization.select().where(
            RealEstatesCustomization.customer == self.cid):
        yield config.source_customer
    # Yield self.customer iff database
    # query returned no results
    try:
        config
    except UnboundLocalError:
        yield self.customer
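The first version works because a for loop's target variable stays bound after the loop finishes, while reading a local that was never assigned raises UnboundLocalError (a subclass of NameError). Here is a minimal, database-free sketch of just that mechanism; last_item_or_none is a hypothetical helper, not part of the question's code:

```python
def last_item_or_none(iterable):
    """Return the last element of iterable, or None if it is empty."""
    for item in iterable:
        pass
    try:
        # A for-loop variable stays bound after the loop ends...
        return item
    except UnboundLocalError:
        # ...but if the loop body never ran, `item` was never assigned.
        return None


print(last_item_or_none([1, 2, 3]))  # 3
print(last_item_or_none([]))         # None
```

Assigning the loop variable anywhere earlier in the function (for example, initialising it to a default) would silently defeat this check, which is worth keeping in mind when comparing it with an explicit flag.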


Is relying on UnboundLocalError like this a Pythonic way to go, or should I use a flag instead, like so:

@property
def source_customers(self):
    """Yields source customers"""
    yielded = False
    for config in RealEstatesCustomization.select().where(
            RealEstatesCustomization.customer == self.cid):
        yield config.source_customer
        yielded = True
    # Yield self.customer iff database
    # query returned no results
    if not yielded:
        yield self.customer
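To compare the two versions without the database, here is a minimal sketch that replaces the query with a plain iterable rows and self.customer with a default argument; both function names are hypothetical:

```python
def via_unbound_check(rows, default):
    """First version: probe the loop variable once the loop is done."""
    for row in rows:
        yield row
    try:
        row  # raises UnboundLocalError if the loop body never ran
    except UnboundLocalError:
        yield default


def via_flag(rows, default):
    """Second version: record iteration in an explicit flag."""
    yielded = False
    for row in rows:
        yield row
        yielded = True
    if not yielded:
        yield default


for data in ([], [1], ["a", "b"]):
    print(list(via_unbound_check(data, "dflt")), list(via_flag(iter(data), "dflt")))
# ['dflt'] ['dflt']
# [1] [1]
# ['a', 'b'] ['a', 'b']
```

Both generators give the same output for empty and non-empty inputs, including one-shot iterators, since each consumes its input exactly once.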


Which is preferable?

Solution

Separate Your Concerns

At the moment, the function does two things: it finds matching customers and checks whether any were found. Let's instead move the fallback into a separate helper that yields a default when its input is empty:

from itertools import tee

def iter_or_default(iterable, dflt):
    orig, test = tee(iter(iterable))
    try:
        next(test)
    except StopIteration:
        # iterable was empty
        yield dflt
    else:
        # next() succeeded, so the iterable is non-empty;
        # orig still starts at the first element
        yield from orig
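The tee is not strictly needed: the helper only has to look ahead by one element. As a hedged alternative (not the answer's code), the sketch below pulls the first element with next() and a private sentinel, then puts it back with itertools.chain; iter_or_default_chain and _MISSING are hypothetical names:

```python
from itertools import chain

_MISSING = object()  # private sentinel: never equal to a real element


def iter_or_default_chain(iterable, dflt):
    """Yield the items of iterable, or dflt alone if it is empty."""
    it = iter(iterable)
    first = next(it, _MISSING)  # look ahead by exactly one element
    if first is _MISSING:
        # the iterable was empty
        yield dflt
    else:
        # put the consumed element back in front of the remainder
        yield from chain([first], it)


print(list(iter_or_default_chain([], "default")))                  # ['default']
print(list(iter_or_default_chain([None, 0], "default")))           # [None, 0]
print(list(iter_or_default_chain((n * n for n in range(3)), -1)))  # [0, 1, 4]
```

Because the sentinel is a fresh object(), a source whose first element is None or 0 is still treated as non-empty. The trade-off against tee is mostly one of taste: this version holds on to at most one element itself and needs no second iterator.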


Now we just use iter_or_default:

@property
def source_customers(self):
    """Yields source customers, or the instance customer if none found"""

    source_customers = (
        c.source_customer
        for c in RealEstatesCustomization.select().where(
            RealEstatesCustomization.customer == self.cid))
    yield from iter_or_default(source_customers, self.customer)
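Since RealEstatesCustomization is a database model that isn't shown here, this sketch swaps the query for an in-memory list to show the property's behaviour end to end. Config, CONFIGS and Account are hypothetical stand-ins, and iter_or_default is repeated so the block runs on its own:

```python
from dataclasses import dataclass
from itertools import tee


def iter_or_default(iterable, dflt):
    # Same helper as in the answer: peek with one tee copy, replay the other.
    orig, test = tee(iter(iterable))
    try:
        next(test)
    except StopIteration:
        yield dflt
    else:
        yield from orig


@dataclass
class Config:
    """Hypothetical stand-in for a RealEstatesCustomization row."""
    customer: int
    source_customer: str


# Hypothetical in-memory "table" standing in for the database query.
CONFIGS = [Config(1, "alpha"), Config(1, "beta"), Config(2, "gamma")]


class Account:
    """Hypothetical owner of the property; the original class isn't shown."""

    def __init__(self, cid, customer):
        self.cid = cid
        self.customer = customer

    @property
    def source_customers(self):
        """Yields source customers, or the instance customer if none found"""
        source_customers = (
            c.source_customer for c in CONFIGS if c.customer == self.cid
        )
        yield from iter_or_default(source_customers, self.customer)


print(list(Account(1, "self-1").source_customers))  # ['alpha', 'beta']
print(list(Account(3, "self-3").source_customers))  # ['self-3']
```

Because the property's getter is a generator function, every access to source_customers returns a fresh generator, so the property can be iterated more than once.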


Context

StackExchange Code Review Q#113149, answer score: 4
