Extracting an arbitrary element from each row of a np.array
Problem
I have a large NumPy array of shape (n, m). I need to extract one element from each row, and I have another array of shape (n,) that gives the column index of the element I need. The following code does this, but it requires an explicit loop (in the form of a list comprehension):
import numpy as np
arr = np.array(range(12))
arr = arr.reshape((4,3))
keys = np.array([1,0,1,2])
#This is the line that I'd like to optimize
answers = np.array([arr[i,keys[i]] for i in range(len(keys))])
print(answers)
# [ 1  3  7 11]
Is there a built-in NumPy (or pandas?) function that could do this more efficiently?
Solution
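The best way to do this is what @GarethRees suggested in the comments: index with two integer arrays at once. NumPy pairs the index arrays element-wise, so arr[np.arange(4), keys] returns arr[k, keys[k]] for each k: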
>>> arr[np.arange(4), keys]
array([ 1,  3,  7, 11])

There is another (not as good) solution that is still better than writing arr[i, keys[i]] for i in range(len(keys)). Whenever you want both the index and the item from an iterable, you should generally use the enumerate function:

>>> np.array([arr[i, item] for i, item in enumerate(keys)])
array([ 1,  3,  7, 11])
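The one-liner hard-codes the row count 4 to match the example. Below is a minimal sketch of a general version, assuming a 2-D array and a 1-D integer key array with one entry per row. `pick_per_row` and `pick_per_row_take` are hypothetical helper names, and `np.take_along_axis` (NumPy 1.15 and later) is a swapped-in alternative, not something the answer itself proposed.

```python
import numpy as np


def pick_per_row(arr: np.ndarray, keys: np.ndarray) -> np.ndarray:
    """Return arr[i, keys[i]] for every row i, without a Python loop.

    Hypothetical helper: generalizes arr[np.arange(4), keys] by taking the
    row count from the array instead of hard-coding it.
    """
    if arr.ndim != 2 or keys.shape != (arr.shape[0],):
        raise ValueError("expected arr of shape (n, m) and keys of shape (n,)")
    rows = np.arange(arr.shape[0])  # 0, 1, ..., n-1
    # Integer-array indexing pairs rows[k] with keys[k] element-wise.
    return arr[rows, keys]


def pick_per_row_take(arr: np.ndarray, keys: np.ndarray) -> np.ndarray:
    """Same result via np.take_along_axis (hypothetical helper name)."""
    # keys[:, None] has shape (n, 1), so the result has shape (n, 1); flatten it.
    return np.take_along_axis(arr, keys[:, None], axis=1).ravel()


arr = np.arange(12).reshape(4, 3)
keys = np.array([1, 0, 1, 2])

# Reference result from the loop-based version in the answer.
loop = np.array([arr[i, k] for i, k in enumerate(keys)])
print(pick_per_row(arr, keys))       # [ 1  3  7 11]
print(pick_per_row_take(arr, keys))  # [ 1  3  7 11]
assert np.array_equal(pick_per_row(arr, keys), loop)
assert np.array_equal(pick_per_row_take(arr, keys), loop)
```

The explicit shape check in `pick_per_row` is a design choice: without it, a length-1 `keys` array would silently broadcast against the row indices and return one column for every row.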
Context
StackExchange Code Review Q#109425, answer score: 4