HiveBrain v1.2.0
pattern · python · Moderate

Find co-occurrence of elements

Submitted by: @import:stackexchange-codereview

Problem

My input consists of a list of lists (not necessarily of the same length), such as

[
     ['a', 'b', 'c', 'd', 'b'],
     ['a', 'd', 'e', 'b'],
     ['f', 'a', 'g', 'e', 'b', 'h']
]


and I would like a matrix / dictionary containing the co-occurrences. The keys of the dictionary are the elements of the input lists, and each value is a Counter of the elements that appear in the same rows. For the example above, the result would be

{
    'a': Counter({'b': 4, 'c': 1, 'd': 2, 'e': 2, 'f': 1, 'g': 1, 'h': 1}),
    ...
}


Here is my code:

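import collections

def my_func(data):
    # Map each element to a Counter of the elements sharing a row with it.
    result = collections.defaultdict(collections.Counter)
    for l in data:
        for e in l:
            # Rescans the whole row once per element.
            result[e].update([el for el in l if el is not e])
    return result

my_func(data)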
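As a quick check, the question's function can be run on the sample input above. This minimal sketch repeats `my_func` so it is self-contained and compares the entry for 'a' against the expected Counter shown earlier.

```python
import collections

def my_func(data):
    # The question's version: for each element, count the other elements in its row.
    result = collections.defaultdict(collections.Counter)
    for l in data:
        for e in l:
            result[e].update([el for el in l if el is not e])
    return result

data = [
    ['a', 'b', 'c', 'd', 'b'],
    ['a', 'd', 'e', 'b'],
    ['f', 'a', 'g', 'e', 'b', 'h'],
]

result = my_func(data)
print(result['a'])
```

Both 'b's in the first row are the same string object, so `el is not e` also keeps 'b' out of its own Counter.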


This works, but I am not sure it's the smartest way: when updating the Counter, I loop over the elements of l again for every element.

EDIT: I should clarify that the elements in the lists are not necessarily single characters but arbitrary Python objects. I only used single letters for faster typing.
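Since the elements can be arbitrary (hashable) objects, it is worth noting that `el is not e` compares identity, while Counter keys group elements by equality and hash. The minimal sketch below uses a hypothetical frozen dataclass `Item` and a hypothetical helper `co_occur` with a pluggable "same element" test, to show that two equal but distinct objects in one row are counted as co-occurring with themselves under `is`, but not under `==`.

```python
import collections
from dataclasses import dataclass

@dataclass(frozen=True)
class Item:
    # Hypothetical element type: a frozen dataclass gets __eq__ and __hash__,
    # so equal instances collapse to a single Counter key.
    name: str

def co_occur(data, same):
    # Same loop as the question's my_func, with the "same element" test passed in.
    result = collections.defaultdict(collections.Counter)
    for row in data:
        for e in row:
            result[e].update(el for el in row if not same(el, e))
    return result

row = [Item('x'), Item('x'), Item('y')]  # two equal but distinct objects

by_identity = co_occur([row], lambda a, b: a is b)
by_equality = co_occur([row], lambda a, b: a == b)

print(by_identity[Item('x')])
print(by_equality[Item('x')])
```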

Solution

Consider the following potential row in your data:

'abbbbbbbbbbbbbbb'


The bare minimum of work would be to add 15 to both result['a']['b'] and result['b']['a']. With your loop as written, though, you add 15 to result['a']['b'] in one update, but then add 1 to result['b']['a'] fifteen separate times, rescanning the whole row each time. That's less than ideal.
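To make that cost concrete, this minimal sketch counts the work for a row of one 'a' followed by 15 'b's. It assumes `!=` in place of the question's `is not`, which gives the same result for these single-character strings.

```python
import collections
import itertools

row = 'a' + 'b' * 15  # one 'a' followed by 15 'b's

# Question's approach: one Counter increment per element handed to update().
naive_increments = sum(1 for e in row for el in row if el != e)
# The comprehension compares every position in the row with every position.
naive_comparisons = len(row) ** 2

# Counter-based approach: one addition per ordered pair of distinct keys.
counts = collections.Counter(row)
pair_updates = len(list(itertools.permutations(counts, 2)))

print(naive_increments, naive_comparisons, pair_updates)
```

The question's loop performs 30 single increments and 256 comparisons for this row. The Counter-based approach performs two additions, one for ('a', 'b') and one for ('b', 'a').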

So let's first condense each row into a Counter, and then update the result from those counts:

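import collections
import itertools

result = collections.defaultdict(lambda: collections.defaultdict(int))

for row in data:
    # Count each distinct element in the row once.
    counts = collections.Counter(row)
    # Each ordered pair of distinct elements co-occurs counts[from] * counts[to] times.
    for key_from, key_to in itertools.permutations(counts, 2):
        result[key_from][key_to] += counts[key_from] * counts[key_to]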


permutations(counts, 2) gives us every ordered pair of keys we need to update. Because the Counter's keys are unique, an element is never paired with itself. Of course, if the row is 'abcdefgh', this won't help: every element is distinct, so we still make roughly $N^2$ passes ($N(N-1)$ pair updates). But for a row like the one above, it's a big improvement.
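To check that the Counter-based loop computes the same numbers as the question's `my_func`, this minimal sketch runs both on the sample input plus the skewed row and compares the results. It wraps the answer's loop in a hypothetical `co_occurrences` function that uses `Counter` values, as the question asked, instead of `defaultdict(int)`.

```python
import collections
import itertools

def my_func(data):
    # The question's version: rescans the row for every element.
    result = collections.defaultdict(collections.Counter)
    for l in data:
        for e in l:
            result[e].update([el for el in l if el is not e])
    return result

def co_occurrences(data):
    # The answer's version: count each row once, then update each ordered pair.
    result = collections.defaultdict(collections.Counter)
    for row in data:
        counts = collections.Counter(row)
        for key_from, key_to in itertools.permutations(counts, 2):
            result[key_from][key_to] += counts[key_from] * counts[key_to]
    return result

data = [
    ['a', 'b', 'c', 'd', 'b'],
    ['a', 'd', 'e', 'b'],
    ['f', 'a', 'g', 'e', 'b', 'h'],
    list('a' + 'b' * 15),
]

old, new = my_func(data), co_occurrences(data)
print(old == new)
```

There is one small difference. An element that never shares a row with a *different* element gets an empty Counter from `my_func`, but no entry at all from the pair-based loop.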

Regardless, this:

result[e].update([el for el in l if el is not e])


can be improved by removing the square brackets, which turns the list comprehension into a generator expression:

result[e].update(el for el in l if el is not e)


Counter.update() accepts any iterable, so there is no need to build a full list first.
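The generator-expression form is a drop-in change. This minimal sketch applies it to the question's loop and checks on the sample input that the resulting Counter for 'a' is unchanged.

```python
import collections

def my_func(data):
    # The question's loop with the brackets removed: Counter.update()
    # consumes the generator directly, so no temporary list is built.
    result = collections.defaultdict(collections.Counter)
    for l in data:
        for e in l:
            result[e].update(el for el in l if el is not e)
    return result

data = [
    ['a', 'b', 'c', 'd', 'b'],
    ['a', 'd', 'e', 'b'],
    ['f', 'a', 'g', 'e', 'b', 'h'],
]
print(my_func(data)['a'])
```

This saves the intermediate list but not the rescan of the row. The Counter-based loop earlier in the answer is what removes that repeated work.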


Context

StackExchange Code Review Q#107413, answer score: 10
