Intersection of N lists, or of an N-d array
Problem
I often find myself calling `np.intersect1d()` multiple times because I need the intersection of a number of list-like or array-like things, so I ended up writing a function for it. If possible, I'd like to simplify the code and/or speed it up for when the input becomes rather large.

    import numpy as np

    def overlap(splits):
        uniques = {item: 0 for item in list(set([i for x in splits for i in x]))}
        for idx in range(len(splits)):
            for val in list(set(splits[idx])):
                if val in uniques.keys():
                    uniques[val] += 1
        intersection = np.array([x for x in uniques.keys() if uniques[x] == len(splits)])
        return intersection

Solution
- You can loop through a set directly; there is no need to wrap it in a list.
- You can loop through `splits` directly; `for i in range(len(splits)): splits[i]` is just long-winded. Just use `for split in splits:`.
- When checking whether a key is in a dictionary, don't compare against its `.keys()`; just do `'a' in dict`.
- Notice that you're forcing the dictionary to be pre-populated; alternatively, you could use `collections.defaultdict`.
- After changing the code to use a `defaultdict`, you should notice that all you're doing is counting this:

      (item for split in splits for item in set(split))

  And so you could use `collections.Counter` instead.
- In your final loop you use `.keys()` and then `uniques[x]`. Instead you should use `.items()` with tuple unpacking.
- Finally, you should aim to improve readability, so I generate all the items on one line and build the np array on another.
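The intermediate `defaultdict` step described in the list above is not shown in the answer. A minimal sketch of what it might look like follows; the name `overlap_defaultdict` and the sample input are assumptions made for illustration.

```python
from collections import defaultdict

import numpy as np


def overlap_defaultdict(splits):
    """Hypothetical intermediate step: count how many inputs contain each item."""
    counts = defaultdict(int)  # missing keys start at 0, so no pre-population
    for split in splits:  # iterate over the inputs directly, no index needed
        for item in set(split):  # de-duplicate within a single input
            counts[item] += 1
    # An item belongs to the intersection if every input contained it.
    return np.array([k for k, v in counts.items() if v == len(splits)])


result = overlap_defaultdict([[1, 3, 4, 3], [3, 1, 2, 1], [6, 3, 4, 2]])
print(result)  # [3]
```

Written this way, the two nested loops do nothing but count the items of a flattened, de-duplicated stream, which is the job `collections.Counter` does in one call.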
This can result in:

    from collections import Counter

    import numpy as np

    def overlap(splits):
        items = (item for split in splits for item in set(split))
        return np.array([k for k, v in Counter(items).items() if v == len(splits)])

However, with two arrays as input this does the same as `np.intersect1d`. The docs also say that if you wish to intersect more than two arrays you can use:

    >>> from functools import reduce
    >>> reduce(np.intersect1d, ([1, 3, 4, 3], [3, 1, 2, 1], [6, 3, 4, 2]))
    array([3])
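To check that the two approaches agree, here is a small sketch that runs both on the same input. The sample data is made up for illustration. One difference to keep in mind: `np.intersect1d` returns sorted unique values, while the `Counter` version returns items in the order they were first counted, so the comparison below sorts both results.

```python
from collections import Counter
from functools import reduce

import numpy as np


def overlap(splits):
    items = (item for split in splits for item in set(split))
    return np.array([k for k, v in Counter(items).items() if v == len(splits)])


# Illustrative data, not taken from the original post.
splits = [[5, 1, 9, 5], [9, 5, 1, 7], [1, 9, 5, 0]]

via_counter = overlap(splits)
via_reduce = reduce(np.intersect1d, splits)  # pairwise intersection, left to right

# Same elements either way; sorting makes the comparison order-independent.
assert sorted(via_counter.tolist()) == sorted(via_reduce.tolist())
```

If a sorted result matters downstream, the `reduce` form provides it without extra work; the `Counter` form would need an explicit sort.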

Context
StackExchange Code Review Q#145205, answer score: 6