MATLAB-style find() function in Python

Python Problem Overview


In MATLAB it is easy to find the indices of values that meet a particular condition:

>> a = [1,2,3,1,2,3,1,2,3];
>> find(a > 2)     % find the indices where this condition is true
[3, 6, 9]          % (MATLAB uses 1-based indexing)
>> a(find(a > 2))  % get the values at those locations
[3, 3, 3]

What would be the best way to do this in Python?

So far, I have come up with the following. To just get the values:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> [val for val in a if val > 2]
[3, 3, 3]

But if I want the index of each of those values it's a bit more complicated:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> inds = [i for (i, val) in enumerate(a) if val > 2]
>>> inds
[2, 5, 8]
>>> [val for (i, val) in enumerate(a) if i in inds]
[3, 3, 3]
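Once the indices are known, the matching values can be read off by indexing the list directly. This avoids the "i in inds" test, which scans inds once for every element of a. A minimal sketch:

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# indices where the condition holds (0-based, unlike MATLAB)
inds = [i for (i, val) in enumerate(a) if val > 2]

# look the values up by index instead of re-scanning a with "i in inds"
vals = [a[i] for i in inds]

print(inds)  # [2, 5, 8]
print(vals)  # [3, 3, 3]
```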

Is there a better way to do this in Python, especially for arbitrary conditions (not just 'val > 2')?

I found functions equivalent to MATLAB 'find' in NumPy but I currently do not have access to those libraries.

Python Solutions


Solution 1 - Python

In NumPy you have numpy.where:

>>> import numpy as np
>>> x = np.random.randint(0, 20, 10)
>>> x
array([14, 13,  1, 15,  8,  0, 17, 11, 19, 13])
>>> np.where(x > 10)
(array([0, 1, 3, 6, 7, 8, 9], dtype=int64),)
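With a single argument, np.where returns a tuple holding one index array per dimension, which is why the output above is wrapped in (..., ). A short sketch of unpacking it for a 1-D array and of getting the values, either from the indices or from the boolean mask directly. The array is fixed here (copied from the output above) so the result is reproducible:

```python
import numpy as np

x = np.array([14, 13, 1, 15, 8, 0, 17, 11, 19, 13])

# one-argument np.where returns a tuple with one array per dimension;
# for a 1-D array, unpack the single element
(inds,) = np.where(x > 10)
print(inds)        # [0 1 3 6 7 8 9]

# the values themselves, like MATLAB's x(find(x > 10))
print(x[inds])     # [14 13 15 17 11 19 13]
print(x[x > 10])   # same values, using the boolean mask directly
```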

Solution 2 - Python

You can make a function that takes a callable parameter which will be used in the condition part of your list comprehension. Then you can use a lambda or other function object to pass your arbitrary condition:

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)

>>> inds
[2, 5, 8]

It's a little closer to your Matlab example, without having to load up all of numpy.
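Because indices only needs a callable, the condition does not have to be a lambda. A small sketch with two other kinds of callables (the string input and the operator-based predicate are illustrations, not part of the original answer):

```python
import operator
from functools import partial

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# partial(operator.lt, 2)(x) evaluates 2 < x, i.e. x > 2
print(indices(a, partial(operator.lt, 2)))   # [2, 5, 8]

# any iterable works; here an unbound str method is the predicate
print(indices("aBcDe", str.isupper))          # [1, 3]
```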

Solution 3 - Python

Or use numpy's nonzero function:

>>> import numpy as np
>>> a = np.array([1,2,3,4,5])
>>> inds = np.nonzero(a > 2)
>>> a[inds]
array([3, 4, 5])
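One difference from MATLAB shows up on 2-D arrays: np.nonzero returns separate row and column index arrays, and np.flatnonzero numbers elements in row-major (C) order, while MATLAB's find numbers them in column-major order starting at 1. A sketch of getting MATLAB-style linear indices, assuming column-major order is what you want to match (the matrix m is made up for illustration):

```python
import numpy as np

m = np.array([[1, 9, 1],
              [1, 1, 1]])

rows, cols = np.nonzero(m > 2)     # rows [0], cols [1]

# row-major linear index (NumPy's natural order)
print(np.flatnonzero(m > 2))                          # [1]

# column-major linear index, matching MATLAB's ordering (0-based)
print(np.flatnonzero((m > 2).ravel(order="F")))       # [2]

# add 1 to get the number MATLAB's find(m > 2) would print
print(np.flatnonzero((m > 2).ravel(order="F")) + 1)   # [3]
```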

Solution 4 - Python

Why not just use this:

[i for i in range(len(a)) if a[i] > 2]

or for arbitrary conditions, define a function f for your condition and do:

[i for i in range(len(a)) if f(a[i])]
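For example, with a named condition function (the name in_range and its bounds are made up for illustration). The enumerate form gives the same result without indexing back into a:

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

def in_range(v):
    # hypothetical condition: strictly between 1 and 3
    return 1 < v < 3

print([i for i in range(len(a)) if in_range(a[i])])   # [1, 4, 7]

# equivalent, iterating the values directly
print([i for i, v in enumerate(a) if in_range(v)])    # [1, 4, 7]
```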

Solution 5 - Python

The NumPy routine more commonly used for this is numpy.where(); called with only a condition, it is equivalent to numpy.nonzero(condition).

import numpy
a    = numpy.array([1,2,3,4,5])
inds = numpy.where(a>2)

To get the values, you can store the indices and index the array with them:

a[inds]

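Alternatively, to build a new array of the same shape rather than extract values, pass two more arguments: numpy.where then takes elements from the first where the condition is true and from the second elsewhere. The two must be given together; passing only one raises a ValueError.

numpy.where(a>2, a, 0)    # array([0, 0, 3, 4, 5])

b = numpy.array([11,22,33,44,55])
numpy.where(a>2, a, b)    # array([11, 22,  3,  4,  5])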
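Since the question mentions not having NumPy available, here is a pure-Python sketch of the three-argument form for flat lists. where_select is a hypothetical name, and it only handles equal-length 1-D sequences (no broadcasting of scalars):

```python
def where_select(cond, x, y):
    """Pure-Python analogue of numpy.where(cond, x, y) for equal-length lists."""
    if not (len(cond) == len(x) == len(y)):
        raise ValueError("cond, x and y must have the same length")
    # take from x where the condition holds, otherwise from y
    return [xi if c else yi for c, xi, yi in zip(cond, x, y)]

a = [1, 2, 3, 4, 5]
b = [11, 22, 33, 44, 55]
print(where_select([v > 2 for v in a], a, b))   # [11, 22, 3, 4, 5]
```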

Solution 6 - Python

To get values with arbitrary conditions, you could use filter() with a lambda function:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> list(filter(lambda x: x > 2, a))
[3, 3, 3]

One possible way to get the indices would be to use enumerate() to build a tuple with both indices and values, and then filter that:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> aind = tuple(enumerate(a))
>>> print(aind)
((0, 1), (1, 2), (2, 3), (3, 1), (4, 2), (5, 3), (6, 1), (7, 2), (8, 3))
>>> tuple(filter(lambda x: x[1] > 2, aind))
((2, 3), (5, 3), (8, 3))
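To split that filtered result into separate index and value sequences, the pairs can be unzipped. Note that zip(*...) over an empty result yields nothing to unpack, so the two-name assignment below raises ValueError when no element matches. A sketch:

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# (index, value) pairs whose value passes the condition
matches = list(filter(lambda p: p[1] > 2, enumerate(a)))

# unzip the pairs; this fails if matches is empty
inds, vals = zip(*matches)
print(inds)   # (2, 5, 8)
print(vals)   # (3, 3, 3)
```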

Solution 7 - Python

I've been trying to figure out a fast way to do this exact thing, and here is what I stumbled upon (uses numpy for its fast vector comparison):

import numpy
a_bool = numpy.array(a) > 2
inds = [i for (i, val) in enumerate(a_bool) if val]

It turns out that this is much faster than:

inds = [i for (i, val) in enumerate(a) if val > 2]

It seems the comparison is faster when done as a vectorized operation on a NumPy array, and/or the list comprehension is faster when it only checks truth values rather than performing a comparison.

Edit:

I was revisiting my code and came across a possibly less memory-intensive, slightly faster, and very concise way of doing this in one line:

inds = numpy.arange(len(a))[numpy.array(a) > 2]
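How much faster each variant is depends on the input size and the machine, so it is worth measuring with timeit rather than taking it on trust. A sketch, assuming a made-up input of 90,000 elements; each NumPy variant includes the list-to-array conversion, which a fair comparison against the pure-Python version needs. No timings are shown because they vary by system:

```python
import timeit
import numpy

a = [1, 2, 3, 1, 2, 3, 1, 2, 3] * 10_000  # hypothetical larger input

def plain_comprehension():
    return [i for (i, val) in enumerate(a) if val > 2]

def numpy_compare_then_comprehension():
    a_bool = numpy.array(a) > 2
    return [i for (i, val) in enumerate(a_bool) if val]

def numpy_arange_mask():
    return numpy.arange(len(a))[numpy.array(a) > 2]

# the variants must agree before their timings mean anything
assert plain_comprehension() == numpy_compare_then_comprehension() == list(numpy_arange_mask())

for fn in (plain_comprehension, numpy_compare_then_comprehension, numpy_arange_mask):
    seconds = timeit.timeit(fn, number=20)
    print(f"{fn.__name__}: {seconds:.4f} s for 20 runs")
```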

Solution 8 - Python

I think I may have found one quick and simple substitute. BTW, I find the np.where() function not very satisfactory, in the sense that for a 2-D input like the one below it also returns an annoying array of all-zero row indices.

import numpy as np
import matplotlib.mlab as mlab

a = np.random.randn(1, 5)
print(a)
# [[ 1.36406736  1.45217257 -0.06896245  0.98429727 -0.59281957]]

idx = mlab.find(a < 0)
print(idx)
print(type(idx))
# [2 4]
# <class 'numpy.ndarray'>

Best, Da
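matplotlib.mlab.find was later deprecated and removed from Matplotlib, so the code above fails on current versions. It flattened the condition and returned the indices of the true elements, which is what numpy.flatnonzero does. A sketch comparing the two NumPy calls on the values printed above:

```python
import numpy as np

a = np.array([[1.36406736, 1.45217257, -0.06896245, 0.98429727, -0.59281957]])

# np.where on a 2-D array returns (row_indices, col_indices)
rows, cols = np.where(a < 0)
print(rows)   # [0 0]  <- the all-zero row indices
print(cols)   # [2 4]

# flatnonzero flattens first, as mlab.find did
idx = np.flatnonzero(a < 0)
print(idx)    # [2 4]
```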

Solution 9 - Python

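MATLAB's find function also takes an optional second argument, n, which returns only the first n indices where the condition holds (and an optional third, 'first' or 'last', choosing which end to search from). John's code accounts for the first argument but not the second. For instance, if you only want the first index where the condition is satisfied, the MATLAB call would be: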

find(x>2,1)

Using John's code, you can index or slice the list it returns: [0] gives the first matching index (as a plain number, where MATLAB's find(x>2,1) gives a one-element result), and [:n] gives the first n, like find(x>2,n).

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)[0]  # [0] takes the first match, like MATLAB's find(a > 2, 1)

which gives 2, the (0-based) index of the first value greater than 2.
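The indices function builds the full list before [0] or [:n] is applied. A lazy sketch that stops after the first n matches and also imitates the 'last' direction; find here is a hypothetical helper with 0-based results, its 'last' branch needs a sequence that supports len() and indexing, and it reports indices in ascending order:

```python
from itertools import islice

def find(seq, pred, n=None, direction="first"):
    """Hypothetical MATLAB-like find(X, n, direction) with 0-based indices.

    n=None returns every match; otherwise at most n matches.
    """
    if direction == "first":
        hits = (i for i, v in enumerate(seq) if pred(v))
        return list(islice(hits, n))          # islice(..., None) takes everything
    if direction == "last":
        # walk backwards so we can stop after n matches
        hits = (i for i in range(len(seq) - 1, -1, -1) if pred(seq[i]))
        return sorted(islice(hits, n))        # report in ascending order
    raise ValueError("direction must be 'first' or 'last'")

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
print(find(a, lambda x: x > 2))               # [2, 5, 8]
print(find(a, lambda x: x > 2, 1))            # [2]
print(find(a, lambda x: x > 2, 2, "last"))    # [5, 8]
```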

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author
Question | user344226
Solution 1 - Python | joaquin
Solution 2 - Python | John
Solution 3 - Python | vincentv
Solution 4 - Python | JasonFruit
Solution 5 - Python | ryanjdillon
Solution 6 - Python | Blair
Solution 7 - Python | Nate
Solution 8 - Python | DidasW
Solution 9 - Python | Clayton Pipkin