How to get the index of a specific value in a PyTorch tensor

Python, PyTorch

Python Problem Overview


With python lists, we can do:

a = [1, 2, 3]
assert a.index(2) == 1

Is there a direct way to get the .index() of a value in a PyTorch tensor?

Python Solutions


Solution 1 - Python

I think there is no direct translation of list.index() to a PyTorch function. However, you can achieve similar results by comparing the tensor with the value (t == value) and then calling nonzero() on the result. For example:

import torch

t = torch.Tensor([1, 2, 3])
print((t == 2).nonzero(as_tuple=True)[0])

This prints tensor([1]), a 1-D LongTensor holding the index of every match.
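If you want behaviour closer to list.index() (a plain Python int for the first match, and a ValueError when the value is absent), a small wrapper around the same nonzero() idea could look like the sketch below. The helper name tensor_index is hypothetical, not a PyTorch API, and it assumes a 1-D tensor.

```python
import torch


def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper mimicking list.index() for a 1-D tensor."""
    if t.dim() != 1:
        raise ValueError("tensor_index expects a 1-D tensor")
    # 1-D LongTensor with the position of every match, in ascending order
    matches = (t == value).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        # list.index() raises ValueError when the value is missing
        raise ValueError(f"{value} is not in tensor")
    # first match, returned as a plain Python int like list.index()
    return int(matches[0])


t = torch.tensor([1, 2, 3, 2])
print(tensor_index(t, 2))  # 1
```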

Solution 2 - Python

For multidimensional tensors you can do:

(tensor == target_value).nonzero(as_tuple=False)

The result is a 2-D tensor of shape number_of_matches x tensor_dimension, with one row of indices per match. For example, if tensor is 3 x 4 (i.e. it has 2 dimensions), each row of the result holds the (row, column) index of one match:

import torch

tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
print((tensor == 2).nonzero(as_tuple=False))
# tensor([[0, 1],
#         [0, 2],
#         [1, 2]])
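For comparison, as_tuple=True returns the same matches as a tuple with one 1-D index tensor per dimension (row indices, then column indices), which is convenient for indexing back into the tensor. A short sketch using the same 3 x 4 example:

```python
import torch

tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])

# as_tuple=False: one row per match, one column per dimension -> shape [3, 2]
coords = (tensor == 2).nonzero(as_tuple=False)

# as_tuple=True: a tuple (row_indices, col_indices), each of shape [3]
rows, cols = (tensor == 2).nonzero(as_tuple=True)
print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([1, 2, 2])

# The tuple form can index the original tensor directly
print(tensor[rows, cols])  # tensor([2., 2., 2.])

# Both forms describe the same coordinates
assert torch.equal(torch.stack([rows, cols], dim=1), coords)
```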

Solution 3 - Python

Based on others' answers:

import torch

t = torch.Tensor([1, 2, 3])
print((t == 1).nonzero().item())  # 0
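.item() only works when the result holds exactly one element, so this approach breaks as soon as the value appears more than once (or not at all). A sketch of a guarded version; the numel() checks are an addition, not part of the original answer:

```python
import torch

t = torch.Tensor([1, 2, 1])
matches = (t == 1).nonzero()  # shape [2, 1]: the value occurs twice

item_failed = False
try:
    matches.item()
except RuntimeError as err:
    # .item() refuses tensors holding more than one element
    item_failed = True
    print("item() failed:", err)

# Guarded version: branch on how many matches there are
if matches.numel() == 0:
    print("no match")
elif matches.numel() == 1:
    print(matches.item())
else:
    print(f"{matches.numel()} matches; first is {matches[0].item()}")
```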

Solution 4 - Python

This can also be done by converting to NumPy:

import numpy as np
import torch

x = torch.arange(1.0, 5.0)  # torch.range is deprecated
print(x)
# tensor([1., 2., 3., 4.])
nx = x.numpy()
print(np.where(nx == 3)[0][0])
# 2
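.numpy() only works on CPU tensors that do not require grad; otherwise it raises an error. A sketch of the safer conversion, assuming x may live on a GPU or be part of an autograd graph:

```python
import numpy as np
import torch

# A tensor that requires grad: calling x.numpy() directly would raise
x = torch.arange(1.0, 5.0, requires_grad=True)  # tensor([1., 2., 3., 4.])

# detach() drops the autograd graph; cpu() is a no-op for CPU tensors
nx = x.detach().cpu().numpy()

# np.where returns a tuple of index arrays; take the first match
idx = int(np.where(nx == 3)[0][0])
print(idx)  # 2
```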

Solution 5 - Python

The answers already given are great, but they don't handle the case where there is no match. This function does:

from typing import Union

import torch
from torch import Tensor


def index(tensor: Tensor, value, ith_match: int = 0) -> Union[Tensor, int]:
    """
    Returns the generalized index (i.e. location/coordinate) of an occurrence of value
    in tensor. For flat (1-D) tensors the result is a 1-element tensor holding the position;
    for N-D tensors it is the coordinate of the match as a length-N tensor.
    If there are multiple occurrences, choose which one you want with ith_match,
    e.g. ith_match=0 gives the first occurrence.

    Reference: https://stackoverflow.com/a/67175757/1601580
    :return: the index/coordinate tensor, or -1 if value does not occur
    """
    # bool tensor of where value occurred
    places_where_value_occurs = (tensor == value)
    # get matches as a "coordinate list" where occurrences happened
    matches = places_where_value_occurs.nonzero()  # [number_of_matches, tensor_dimension]
    if matches.size(0) == 0:  # no matches
        return -1
    else:
        # get index/coordinate of the occurrence you want (e.g. 1st occurrence ith_match=0)
        index = matches[ith_match]
        return index

Credit to this great answer: https://stackoverflow.com/a/67175757/1601580
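An alternative sketch for the 1-D case that avoids materializing the full list of matches: torch.argmax is documented to return the index of the first maximal value, so on the 0/1 mask it returns the first True. This is a different technique from the nonzero() approach above; the name first_index is hypothetical, and like the index() function above it returns -1 when there is no match.

```python
import torch


def first_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper: first position of value in a 1-D tensor, or -1."""
    mask = t == value
    if not bool(mask.any()):
        return -1  # no match, same convention as index() above
    # argmax returns the first maximal index; on the 0/1 mask that is the
    # first True. Cast to int32 rather than relying on argmax over bool.
    return int(mask.to(torch.int32).argmax())


t = torch.tensor([4, 7, 7, 1])
print(first_index(t, 7))  # 1
print(first_index(t, 9))  # -1
```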

Solution 6 - Python

import torch

x = torch.Tensor([11, 22, 33, 22])
print((x == 22).nonzero().squeeze())

> tensor([1, 3])
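One caveat: with a single match, squeeze() collapses the [1, 1] result to a 0-d tensor, which cannot be iterated or passed to len(). torch.where(condition) with only a condition behaves like nonzero(as_tuple=True) and always returns 1-D index tensors, so the shape stays predictable:

```python
import torch

x = torch.Tensor([11, 22, 33, 22])

# Single match: squeeze() yields a 0-d tensor, e.g. tensor(2)
single = (x == 33).nonzero().squeeze()
print(single.dim())  # 0

# torch.where with one argument returns a tuple of index tensors (one per dim)
print(torch.where(x == 22)[0])  # tensor([1, 3])
print(torch.where(x == 33)[0])  # tensor([2])
```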

Solution 7 - Python

To find the index of an element in a 1-D tensor/array, you can loop over it. For example, to find the index of 5:

import torch

mat = torch.tensor([1, 8, 5, 3])
target = 5

for o in range(len(mat)):
    if mat[o] == target:
        print(torch.tensor([o]))  # prints tensor([2])

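To find the index of an element in a 2-D/3-D tensor, first flatten it to 1-D, e.g. with example.view(number_of_elements). The result is then the index in the flattened tensor, not a (row, column) coordinate.

Example: to find the index of 2

import torch

mat = torch.tensor([[1, 2], [4, 3]])
target = 2
mat = mat.view(4)  # flatten the 2x2 tensor to 4 elements
for o in range(len(mat)):
    if mat[o] == target:
        print(torch.tensor([o]))  # prints tensor([1])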
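The flattened index can be mapped back to (row, column) coordinates with integer division and modulo by the number of columns. A vectorized sketch of the same 2-D example, replacing the Python loop with nonzero(); it assumes a 2-D tensor stored in row-major order:

```python
import torch

mat = torch.tensor([[1, 2], [4, 3]])
target = 2
n_cols = mat.shape[1]

# Flat positions of every match (same numbering as mat.view(-1))
flat_idx = (mat.view(-1) == target).nonzero(as_tuple=True)[0]
print(flat_idx)  # tensor([1])

# Map flat positions back to 2-D coordinates (row-major order)
rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
cols = flat_idx % n_cols
print(rows, cols)  # tensor([0]) tensor([1])
```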

Solution 8 - Python

For floating-point tensors, where an exact == comparison can miss values because of rounding, I use a tolerance to get the index of an element:

print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())

Here I get the index of the maximum value in the float tensor. You can substitute any value to get the index of the matching elements:

print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())
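A similar tolerance check can be written with torch.isclose, which tests |a - b| <= atol + rtol * |b| (so it adds a small relative term to the absolute threshold used above), and for the maximum specifically, torch.argmax returns the index directly. A sketch assuming a 1-D float tensor; the tolerances are illustrative, not recommendations:

```python
import torch

your_tensor = torch.tensor([0.1, 0.7, 0.3])
value = 0.30005  # slightly off from the stored 0.3

# Exact comparison finds nothing here
exact = (your_tensor == value).nonzero()

# isclose compares elementwise; full_like builds a same-shaped tensor of value
target = torch.full_like(your_tensor, value)
close = torch.isclose(your_tensor, target, rtol=1e-5, atol=1e-4)
print(close.nonzero())  # tensor([[2]])

# Index of the maximum without any tolerance handling
print(torch.argmax(your_tensor))  # tensor(1)
```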

Solution 9 - Python

import torch
from torch.autograd import Variable  # deprecated; plain tensors behave the same

x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0]]))
print(x_data.data[0])
# tensor([1.])

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stackoverflow
Question | Han Bing | View Question on Stackoverflow
Solution 1 - Python | Manuel Lagunas | View Answer on Stackoverflow
Solution 2 - Python | dopexxx | View Answer on Stackoverflow
Solution 3 - Python | Frank Xu | View Answer on Stackoverflow
Solution 4 - Python | vlad | View Answer on Stackoverflow
Solution 5 - Python | Charlie Parker | View Answer on Stackoverflow
Solution 6 - Python | Minions | View Answer on Stackoverflow
Solution 7 - Python | RANJITH T | View Answer on Stackoverflow
Solution 8 - Python | Giang Nguyễn | View Answer on Stackoverflow
Solution 9 - Python | Mohanraj | View Answer on Stackoverflow