Finding the index of elements based on a condition using python list comprehension

Posted by: admin, November 30, 2017


The following Python code seems very long-winded to someone coming from a Matlab background:

>>> a = [1, 2, 3, 1, 2, 3]
>>> [index for index,value in enumerate(a) if value > 2]
[2, 5]

When in Matlab I can write:

>> a = [1, 2, 3, 1, 2, 3];
>> find(a>2)
ans =
     3     6

Is there a shorthand way of writing this in Python, or do I just stick with the long version?

Thank you for all the suggestions and explanation of the rationale for Python’s syntax.

After finding the following on the numpy website, I think I have found a solution I like:


Applying the information from that website to my problem above would give the following:

>>> from numpy import array
>>> a = array([1, 2, 3, 1, 2, 3])
>>> b = a > 2
>>> b
array([False, False,  True, False, False,  True], dtype=bool)
>>> r = array(range(len(b)))
>>> r[b]
array([2, 5])

The following should then work (but I haven’t got a Python interpreter on hand to test it):

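import numpy

class my_array(numpy.ndarray):  # numpy.array is a function; numpy.ndarray is the class
    def find(self, b):
        r = numpy.arange(len(b))
        return r[b]  # index with the boolean mask (square brackets, not a call)

>>> # view() reinterprets an existing array as my_array; my_array([...]) would treat the list as a shape
>>> a = numpy.array([1, 2, 3, 1, 2, 3]).view(my_array)
>>> a.find(a > 2)
array([2, 5])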

  • In Python, you usually wouldn’t use indexes for this at all; you would just deal with the values: [value for value in a if value > 2]. Usually, dealing with indexes means you’re not doing something the best way.

  • If you do need an API similar to Matlab’s, you would use numpy, a package for multidimensional arrays and numerical math in Python, which is heavily inspired by Matlab. You would be using a numpy array instead of a list.

    >>> import numpy
    >>> a = numpy.array([1, 2, 3, 1, 2, 3])
    >>> a
    array([1, 2, 3, 1, 2, 3])
    >>> numpy.where(a > 2)
    (array([2, 5]),)
    >>> a > 2
    array([False, False,  True, False, False,  True], dtype=bool)
    >>> a[numpy.where(a > 2)]
    array([3, 3])
    >>> a[a > 2]
    array([3, 3])
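
If only the positions are needed, a closer analogue of Matlab’s find might be numpy.flatnonzero, which returns a flat array of indices rather than a tuple. The following is a minimal sketch, assuming NumPy is installed; the variable names are illustrative, and the + 1 is only there to show how the result lines up with Matlab’s 1-based indexing:

```python
import numpy as np

a = np.array([1, 2, 3, 1, 2, 3])

# Zero-based positions where the condition holds, as a flat integer array.
idx = np.flatnonzero(a > 2)
print(idx.tolist())           # [2, 5]

# Matlab counts from 1, so its find(a>2) corresponds to idx + 1.
matlab_style = idx + 1
print(matlab_style.tolist())  # [3, 6]
```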

Another way:

>>> [i for i in range(len(a)) if a[i] > 2]
[2, 5]

In general, remember that while find is a ready-made function, list comprehensions are a general, and thus very powerful, solution. Nothing prevents you from writing a find function in Python and using it later as you wish. For example:

>>> def find_indices(lst, condition):
...   return [i for i, elem in enumerate(lst) if condition(elem)]
>>> find_indices(a, lambda e: e > 2)
[2, 5]

Note that I’m using lists here to mimic Matlab. It would be more Pythonic to use generators and iterators.
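
The generator-based variant mentioned above could look roughly like this. It is only a sketch: iter_indices is a hypothetical name, not a built-in, and it produces indices lazily instead of building the whole list up front:

```python
def iter_indices(iterable, condition):
    """Lazily yield the index of each element for which condition(elem) is true."""
    return (i for i, elem in enumerate(iterable) if condition(elem))

a = [1, 2, 3, 1, 2, 3]

gen = iter_indices(a, lambda e: e > 2)
first = next(gen)    # scans only as far as the first match
rest = list(gen)     # consumes the remaining matches
print(first, rest)   # 2 [5]
```

This can help when only the first match is needed or when the input is too large to materialise, since nothing past the consumed matches is examined.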


Maybe another question is, “what are you going to do with those indices once you get them?” If you are going to use them to create another list, then in Python, they are an unnecessary middle step. If you want all the values that match a given condition, just use the builtin filter:

matchingVals = list(filter(lambda x: x > 2, a))  # list() is needed on Python 3, where filter returns an iterator

Or write your own list comprehension:

matchingVals = [x for x in a if x > 2]

If you want to remove them from the list, the Pythonic way is not necessarily to delete items one by one, but to write a list comprehension as if you were creating a new list and assign it back in place using listvar[:] on the left-hand side:

a[:] = [x for x in a if x <= 2]
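
The difference between slice assignment and plain assignment is easiest to see when a second name refers to the same list. A small sketch (alias, other and b are illustrative names only):

```python
a = [1, 2, 3, 1, 2, 3]
alias = a                         # second name for the same list object

a[:] = [x for x in a if x <= 2]   # replaces the contents in place
print(alias)                      # [1, 2, 1, 2]

b = [1, 2, 3, 1, 2, 3]
other = b
b = [x for x in b if x <= 2]      # rebinds b to a brand-new list
print(other)                      # [1, 2, 3, 1, 2, 3]
```

With a[:] every reference to the list sees the filtered contents, whereas plain assignment leaves the original object untouched.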

Matlab supplies find because its array-centric model works by selecting items using their array indices. You can do this in Python, certainly, but the more Pythonic way is using iterators and generators, as already mentioned by @EliBendersky.


Even if it’s a late answer: I think this is still a very good question and IMHO Python (without additional libraries or toolkits like numpy) is still lacking a convenient method to access the indices of list elements according to a manually defined filter.

You could manually define a function, which provides that functionality:

def indices(lst, filtr=bool):  # lst rather than list, to avoid shadowing the built-in
    return [i for i, x in enumerate(lst) if filtr(x)]

print(indices([1,0,3,5,1], lambda x: x==1))

Yields: [0, 4]

In my imagination, the perfect way would be to make a child class of list and add the indices function as a method. That way, only the filter would need to be passed:

class MyList(list):
    def __init__(self, *args):
        list.__init__(self, *args)
    def indices(self, filtr=lambda x: bool(x)):
        return [i for i,x in enumerate(self) if filtr(x)]

my_list = MyList([1,0,3,5,1])
my_list.indices(lambda x: x==1)
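
One caveat with this approach, shown in the sketch below: slicing a list subclass returns a plain list, so the extra method is lost unless the result is wrapped again. The class is repeated here only to keep the example self-contained, and using bool as the default filter is an assumption made for brevity:

```python
class MyList(list):
    def indices(self, filtr=bool):
        # Positions of the elements for which filtr(x) is truthy.
        return [i for i, x in enumerate(self) if filtr(x)]

my_list = MyList([1, 0, 3, 5, 1])
print(my_list.indices(lambda x: x == 1))       # [0, 4]
print(my_list.indices())                       # [0, 2, 3, 4]

# Slicing gives back a plain list, so re-wrap it if indices() is still needed.
tail = my_list[1:]
print(type(tail).__name__)                     # list
print(MyList(tail).indices(lambda x: x == 1))  # [3]
```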

I elaborated a bit more on that topic here:


For me it works well:

>>> import numpy as np
>>> a = np.array([1, 2, 3, 1, 2, 3])
>>> np.where(a > 2)[0]
array([2, 5])