HiveBrain v1.2.0

Intersection between arrays

Submitted by: @import:stackexchange-codereview

Problem

I'm a Python beginner (more used to coding in R) and I'd like to optimize a function.

Two arrays are filled with integers from 0 to 100. A number can't appear twice in the same row, and the values in each row are stored in ascending order.

  • Array1: nrow = 100 000; ncol = 5
  • Array2: nrow = 50 000; ncol = 5

For each row of Array1 and each row of Array2, I need to count the number of values they have in common and store the result in a third array.

  • Array3: nrow = 100 000; ncol = 50 000
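
To make the goal concrete, here is a tiny sketch with made-up values (the numbers are illustrative assumptions, not data from the post). Each entry of the result counts the values shared by one row of the first array and one row of the second:

```python
import numpy as np

# Made-up toy data: each row holds unique integers in ascending order.
array1 = np.array([[1, 2, 3, 4, 5],
                   [10, 20, 30, 40, 50]])
array2 = np.array([[2, 4, 6, 8, 10],
                   [1, 2, 3, 4, 5],
                   [60, 70, 80, 90, 99]])

# array3[i, j] = number of values shared by row i of array1 and row j of array2
array3 = np.array([[len(set(r1) & set(r2)) for r2 in array2] for r1 in array1])
print(array3)
# [[2 5 0]
#  [1 0 0]]
```

The result has one row per row of array1 and one column per row of array2, which is why the full-size Array3 is 100 000 by 50 000.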



Here is the current function, timed with a smaller array2 (50 rows instead of 50 000):

import time

import numpy as np

array1 = np.random.randint(0, 100, (100000, 5))
array2 = np.random.randint(0, 100, (50, 5))

def Intersection(array1, array2):
    Intersection = np.empty([array1.shape[0], array2.shape[0]], dtype=np.int8)
    for i in range(0, array1.shape[0]):
        for j in range(0, array2.shape[0]):
            # count the distinct values shared by row i and row j
            Intersection[i, j] = len(set(array1[i, ]).intersection(array2[j, ]))
    return Intersection

start = time.time()
Intersection(array1, array2)
end = time.time()
print(end - start)


23.46 sec. So it should take hours if array2 has 50 000 rows.
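
As a rough check of that estimate, the arithmetic can be sketched as follows. It assumes the runtime grows linearly with the number of rows in array2, which is an assumption rather than a measurement; the size figure simply multiplies out the dimensions of Array3 at one byte per int8 cell:

```python
# Figures taken from the post above; linear scaling is an assumption.
measured_seconds = 23.46   # timing with 50 rows in array2
measured_rows = 50
target_rows = 50000

estimated_seconds = measured_seconds * target_rows / measured_rows
estimated_hours = estimated_seconds / 3600.0
print(round(estimated_hours, 1))  # about 6.5 hours

# Size of the full result: 100 000 x 50 000 cells at 1 byte each (int8)
result_bytes = 100000 * 50000 * 1
print(result_bytes / 1e9)  # 5.0 (gigabytes)
```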

How can I optimize this function while keeping it simple to understand?

Solution

The main optimization I can think of is precomputing the sets. If array1 has shape [m,n] and array2 has shape [q,p], the original code rebuilds the row sets inside the inner loop, on the order of mq times. Building each row's set once up front gets this down to m+q.

In addition, you could follow the Python style guide (PEP 8) for naming conventions and whitespace, which helps when working with other Python programmers. It's also nicer to loop over the data structure itself instead of over a range of indices.

This would result in, for example:

import numpy as np

def intersection(array1, array2):
    intersection = np.empty([array1.shape[0], array2.shape[0]], dtype=np.int8)
    # build the sets for array2 once; a list (rather than map) can be reused
    # on every pass of the outer loop, also under Python 3
    array2_sets = [set(row) for row in array2]
    for i, row1 in enumerate(array1):
        set1 = set(row1)
        for j, set2 in enumerate(array2_sets):
            intersection[i, j] = len(set1.intersection(set2))
    return intersection
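
One way to gain confidence in such a rewrite is to compare it against the straightforward version on small random input. The sketch below is illustrative only: the function names `intersection_naive` and `intersection_precomputed`, the array sizes, and the seed are all assumptions made for this check, not part of the original answer.

```python
import numpy as np

def intersection_naive(array1, array2):
    # Same logic as the question's function: sets are rebuilt inside the loops.
    result = np.empty((array1.shape[0], array2.shape[0]), dtype=np.int8)
    for i in range(array1.shape[0]):
        for j in range(array2.shape[0]):
            result[i, j] = len(set(array1[i]).intersection(array2[j]))
    return result

def intersection_precomputed(array1, array2):
    # Build each row's set once, then reuse it for every pair.
    sets1 = [set(row) for row in array1]
    sets2 = [set(row) for row in array2]
    result = np.empty((len(sets1), len(sets2)), dtype=np.int8)
    for i, set1 in enumerate(sets1):
        for j, set2 in enumerate(sets2):
            result[i, j] = len(set1 & set2)
    return result

rng = np.random.RandomState(0)  # fixed seed, arbitrary choice
a = rng.randint(0, 100, (200, 5))
b = rng.randint(0, 100, (30, 5))
same = np.array_equal(intersection_naive(a, b), intersection_precomputed(a, b))
print(same)  # True
```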


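If you can forget about "keeping it simple to understand", you could also convert your data to sparse matrices and take the matrix product. Each row becomes a 0/1 indicator vector over the possible values, so the dot product of two such vectors counts the values the rows share, e.g.: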

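import numpy as np
from scipy.sparse import csr_matrix

def intersection2(A, B, sparse=True):
    """The rows of A and B must contain unique, relatively
    small non-negative integers.
    """
    # each count is at most the row length, so it must fit inside uint8
    assert min(A.shape[1], B.shape[1]) < 256
    z = max(A.max(), B.max()) + 1  # number of indicator columns

    def sparsify(A):
        # one-hot encode each row: entry (i, v) is 1 if value v is in row i
        m, n = A.shape
        indptr = np.arange(0, m*n+1, n)
        data = np.ones(m*n, dtype=np.uint8)
        return csr_matrix((data, A.ravel(), indptr), shape=(m,z))

    # for csr_matrix, * is the matrix product, not elementwise multiplication
    intersection = sparsify(A) * sparsify(B).T
    return intersection if sparse else intersection.todense()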
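
To see why a matrix product yields the counts, here is a minimal dense sketch of the same idea on made-up toy rows. The helper `indicator` and the data are hypothetical and chosen only for illustration; the sparse version stores the same 0/1 matrix without the zeros.

```python
import numpy as np

# Made-up toy rows with unique values in 0..5
A = np.array([[0, 2, 3],
              [1, 4, 5]])
B = np.array([[2, 3, 5],
              [0, 1, 4]])
z = max(A.max(), B.max()) + 1  # number of distinct possible values

def indicator(rows, width):
    # out[i, v] == 1 exactly when value v occurs in rows[i]
    out = np.zeros((rows.shape[0], width), dtype=np.uint8)
    out[np.arange(rows.shape[0])[:, None], rows] = 1
    return out

# Row i of A dotted with row j of B sums a 1 for every shared value.
counts = indicator(A, z).dot(indicator(B, z).T)
print(counts)
# [[2 1]
#  [1 2]]
```

The dense form needs one column per possible value for every row, so it is only meant to show the principle on small inputs.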

Context

StackExchange Code Review Q#71157, answer score: 3
