Intersection between arrays
Problem
I'm a Python beginner (more used to code in R) and I'd like to optimize a function.
Two arrays are filled with integers from 0 to 100. A number can't appear twice in the same row, and each row is stored in ascending order.
Array1: nrow = 100 000; ncol = 5
Array2: nrow = 50 000; ncol = 5
For each row of Array1 and each row of Array2, I need to count the number of values they share and store the result in a third array.
Array3: nrow = 100 000; ncol = 50 000
Here is the current function, with a smaller array2 (50 rows instead of 50 000):
import time
import numpy as np
array1 = np.random.randint(0,100,(100000,5))
array2 = np.random.randint(0,100,(50,5))
def Intersection(array1, array2):
    # One cell per (row of array1, row of array2) pair
    Intersection = np.empty([ array1.shape[0] , array2.shape[0] ], dtype=np.int8)
    for i in range(0, array1.shape[0]):
        for j in range(0, array2.shape[0]):
            Intersection[i,j] = len( set(array1[i,]).intersection(array2[j,]) )
    return Intersection
start = time.time()
Intersection(array1,array2)
end = time.time()
print(end - start)
This takes 23.46 sec, so it should take hours if array2 has 50 000 rows.
How can I optimize this function while keeping it simple to understand?
Solution
The main optimization I can think of is precomputing the sets. If array1 has shape [m,n] and array2 is [q,p], the original code is creating a set 2mq times. You can get this down to m+q. For the test case in the question (m = 100 000, q = 50) that is 10 000 000 set constructions versus 100 050.
In addition, you could follow the Python style guide in terms of naming conventions and whitespace, which will help when working together with other Python programmers. It's also nice to loop over the data structure itself instead of over a range of indices.
This would result in, for example:
import numpy as np
def intersection(array1, array2):
    intersection = np.empty([array1.shape[0], array2.shape[0]], dtype=np.int8)
    # Build the sets of array2 once, as a list so they can be reused for every row of array1
    array2_sets = [set(row2) for row2 in array2]
    for i, row1 in enumerate(array1):
        set1 = set(row1)
        for j, set2 in enumerate(array2_sets):
            intersection[i,j] = len(set1.intersection(set2))
    return intersection
If you can forget about "keeping it simple to understand" you could also convert your data to sparse matrices and take the matrix product, e.g.:
import numpy as np
from scipy.sparse import csr_matrix
def intersection2(A, B, sparse=True):
    """The rows of A and B must contain unique, relatively
    small integers
    """
    assert min(A.shape[1], B.shape[1]) < 256 # to fit inside uint8
    z = max(A.max(), B.max()) + 1 # number of columns in the indicator matrices
    def sparsify(A):
        # Row i of the result has a 1 in every column listed in A[i]
        m, n = A.shape
        indptr = np.arange(0, m*n+1, n)
        data = np.ones(m*n, dtype=np.uint8)
        return csr_matrix((data, A.ravel(), indptr), shape=(m,z))
    # For csr_matrix operands, * is the matrix product
    intersection = sparsify(A) * sparsify(B).T
    return intersection if sparse else intersection.todense()
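To see why the matrix product counts shared values, here is a minimal sketch on made-up toy data. The helper name `indicator_matrix` and the arrays `a` and `b` are assumptions for illustration only; the helper mirrors `sparsify` above. Each row becomes a 0/1 vector with a 1 at every value it contains, so the dot product of two such vectors is the number of values both rows hold. The sketch also checks the result against the plain set-based definition.

```python
import numpy as np
from scipy.sparse import csr_matrix


def indicator_matrix(rows, n_values):
    """Return a sparse 0/1 matrix M with M[i, v] = 1 when value v occurs in rows[i].

    Assumes every row holds unique, non-negative integers below n_values.
    """
    m, n = rows.shape
    indptr = np.arange(0, m * n + 1, n)  # each row owns exactly n entries
    data = np.ones(m * n, dtype=np.uint8)
    return csr_matrix((data, rows.ravel(), indptr), shape=(m, n_values))


# Toy data (hypothetical); every row is sorted and has no repeated value.
a = np.array([[1, 3, 5],
              [0, 2, 4]])
b = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])

n_values = max(a.max(), b.max()) + 1

# counts[i, j] = number of values shared by a[i] and b[j]
counts = (indicator_matrix(a, n_values) @ indicator_matrix(b, n_values).T).toarray()

# Reference result computed directly from the set-based definition.
expected = np.array([[len(set(r1) & set(r2)) for r2 in b] for r1 in a])

print(counts)
print((counts == expected).all())
```

If a row contained the same value twice, the indicator entry would no longer be 0/1 and the product could overcount, which is why the docstring above requires unique values per row.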
Context
StackExchange Code Review Q#71157, answer score: 3