pattern · python · Minor

Extracting an arbitrary element from each row of a np.array

Submitted by: @import:stackexchange-codereview

Problem

I have a large NumPy array of shape (n, m). I need to extract one element from each row, and I have another array of shape (n,) that gives, for each row, the column index of the element I need. The following code does this, but it requires an explicit loop (in the form of a list comprehension):

import numpy as np

arr = np.array(range(12))
arr = arr.reshape((4,3))
keys = np.array([1,0,1,2])
# This is the line that I'd like to optimize
answers = np.array([arr[i,keys[i]] for i in range(len(keys))])
print(answers)
# [ 1  3  7 11]


Is there a built-in numpy (or pandas?) function that could do this more efficiently?

Solution

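The best way to do this is integer-array ("fancy") indexing, as @GarethRees suggested in the comments: pass an array of row indices together with the array of column indices, and NumPy pairs them up elementwise.

>>> arr[np.arange(len(keys)), keys]
array([ 1,  3,  7, 11])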
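This works because two integer index arrays of the same length are paired elementwise: arr[rows, cols] returns arr[rows[k], cols[k]] for each k. Below is a minimal sketch that wraps this in a reusable function with an input check, assuming arr is 2-D and cols holds one valid column index per row; the helper name pick_per_row is hypothetical. It also shows np.take_along_axis (a built-in NumPy function, available since NumPy 1.15), which gives the same result; that alternative was not part of the original answer.

```python
import numpy as np


def pick_per_row(arr, cols):
    """Hypothetical helper: return arr[i, cols[i]] for every row i.

    Assumes arr is 2-D and cols is a 1-D integer array with one
    column index per row.
    """
    cols = np.asarray(cols)
    if arr.ndim != 2 or cols.shape != (arr.shape[0],):
        raise ValueError("expected a 2-D array and one column index per row")
    # Row indices 0..n-1 are paired elementwise with the column indices.
    rows = np.arange(arr.shape[0])
    return arr[rows, cols]


arr = np.arange(12).reshape(4, 3)
keys = np.array([1, 0, 1, 2])

fancy = pick_per_row(arr, keys)

# Same selection with np.take_along_axis: keys[:, None] has shape (4, 1),
# so one element is taken from each row along axis 1; ravel() flattens
# the (4, 1) result back to shape (4,).
along = np.take_along_axis(arr, keys[:, None], axis=1).ravel()

print(fancy)  # [ 1  3  7 11]
print(along)  # [ 1  3  7 11]
```

Both versions avoid a Python-level loop; take_along_axis is convenient when the index array already has the same number of dimensions as arr.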


There is another solution that is not as good, because it still loops in Python, but it is a more idiomatic alternative to arr[i, keys[i]] for i in range(len(keys)). Whenever you need both the index and the item from an iterable, you should generally use the built-in enumerate function:

>>> np.array([arr[i, item] for i, item in enumerate(keys)])
array([ 1,  3,  7, 11])
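The enumerate version still performs one Python iteration per row, whereas fancy indexing does the whole selection in a single vectorized NumPy operation, which is why it is preferred for large arrays. A rough sketch for comparing the approaches on your own machine follows, assuming randomly generated data (the sizes are arbitrary choices, not from the original question); no timings are claimed here because they depend on hardware and NumPy version.

```python
import timeit

import numpy as np

# Arbitrary test size (an assumption); adjust to match your data.
n, m = 20_000, 50
rng = np.random.default_rng(0)
arr = rng.integers(0, 1000, size=(n, m))
keys = rng.integers(0, m, size=n)


def with_range():
    # Original approach from the question.
    return np.array([arr[i, keys[i]] for i in range(len(keys))])


def with_enumerate():
    # More idiomatic loop, but still one Python iteration per row.
    return np.array([arr[i, item] for i, item in enumerate(keys)])


def with_fancy_indexing():
    # Vectorized: row and column index arrays are paired elementwise.
    return arr[np.arange(len(keys)), keys]


# All three approaches must agree before comparing their speed.
reference = with_fancy_indexing()
assert np.array_equal(with_range(), reference)
assert np.array_equal(with_enumerate(), reference)

for fn in (with_range, with_enumerate, with_fancy_indexing):
    seconds = timeit.timeit(fn, number=5)
    print(f"{fn.__name__}: {seconds:.4f} s for 5 runs")
```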

Context

StackExchange Code Review Q#109425, answer score: 4