How to find the corresponding class in clf.predict_proba()

Python, Machine Learning, Scikit Learn

Python Problem Overview


I have a number of classes and corresponding feature vectors, and when I run predict_proba() I get this:

classes = ['one','two','three','one','three']

feature = [[0,1,1,0],[0,1,0,1],[1,1,0,0],[0,0,0,0],[0,1,1,1]]

from sklearn.naive_bayes import BernoulliNB

clf = BernoulliNB()
clf.fit(feature,classes)
clf.predict_proba([[0,1,1,0]])  # predict_proba expects a 2-D input: one row per sample
>> array([[ 0.48247836,  0.40709111,  0.11043053]])

I would like to know which probability corresponds to which class. This page says that they are ordered by arithmetical order, but I'm not 100% sure what that means: http://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html#sklearn.svm.SVC.predict_proba

Does it mean that I have to go through my training examples and assign the corresponding index to the first occurrence of each class, or is there a command like

clf.getClasses() = ['one','two','three']?

Python Solutions


Solution 1 - Python

Just use the .classes_ attribute of the classifier to recover the mapping. In your example that gives:

>>> clf.classes_
array(['one', 'three', 'two'], 
      dtype='|S5')

And thanks for putting a minimal reproduction script in your question; it makes answering really easy, since it can just be copied and pasted into an IPython shell :)
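A minimal sketch, reusing the toy data from the question and assuming a recent scikit-learn, that pairs each probability with its class label via clf.classes_. It also shows what "arithmetical order" means in practice: classes_ holds the distinct training labels in sorted order (strings alphabetically, numbers numerically), not in order of first appearance. The helper proba_by_class is hypothetical, written only for this illustration.

```python
from sklearn.naive_bayes import BernoulliNB

classes = ['one', 'two', 'three', 'one', 'three']
feature = [[0, 1, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0], [0, 0, 0, 0], [0, 1, 1, 1]]

clf = BernoulliNB()
clf.fit(feature, classes)


def proba_by_class(model, sample):
    """Hypothetical helper: map each class label to its predicted probability."""
    # predict_proba expects a 2-D input (one row per sample); take the single row back out.
    row = model.predict_proba([sample])[0]
    # Column i of predict_proba corresponds to model.classes_[i].
    return {str(label): float(p) for label, p in zip(model.classes_, row)}


print(clf.classes_.tolist())  # ['one', 'three', 'two'] (sorted, not first-appearance order)
print(proba_by_class(clf, [0, 1, 1, 0]))

# Numeric labels are sorted numerically, so 10 comes after 7 here.
numeric = BernoulliNB().fit(feature, [10, 2, 7, 10, 7])
print(numeric.classes_.tolist())  # [2, 7, 10]
```

Building a dict keyed by label avoids relying on column positions anywhere else in your code.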

Solution 2 - Python

import pandas as pd
test = [[0,1,1,0],[1,1,1,0]]
pd.DataFrame(clf.predict_proba(test), columns=clf.classes_)

Out[2]:
        one     three       two
0  0.542815  0.361876  0.095309
1  0.306431  0.612863  0.080706
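Because the columns follow clf.classes_, the same mapping also recovers the most likely label per row: index classes_ with the argmax of each probability row. A minimal sketch on the two test rows above, assuming the toy data from the question; it checks the result against clf.predict.

```python
from sklearn.naive_bayes import BernoulliNB

classes = ['one', 'two', 'three', 'one', 'three']
feature = [[0, 1, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0], [0, 0, 0, 0], [0, 1, 1, 1]]
clf = BernoulliNB().fit(feature, classes)

test = [[0, 1, 1, 0], [1, 1, 1, 0]]
proba = clf.predict_proba(test)

# Column j of proba belongs to clf.classes_[j], so the argmax indexes into classes_.
best = clf.classes_[proba.argmax(axis=1)]
print(best.tolist())               # ['one', 'three']
print(clf.predict(test).tolist())  # should agree with the line above
```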

Solution 3 - Python

As a rule, any attribute of a scikit-learn estimator whose name ends with an underscore (_) is learned during fitting. In your case you're looking for clf.classes_.

Generally in Python, you can use the dir function to find out which attributes an object has.
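As a minimal sketch of that tip, filtering dir(clf) for public names ending in an underscore lists what the fitted estimator learned. The exact set varies between scikit-learn versions, so treat the printed list as illustrative; the toy data is the one from the question.

```python
from sklearn.naive_bayes import BernoulliNB

classes = ['one', 'two', 'three', 'one', 'three']
feature = [[0, 1, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0], [0, 0, 0, 0], [0, 1, 1, 1]]
clf = BernoulliNB().fit(feature, classes)

# Public attributes whose names end in "_" are set by fit() (scikit-learn convention).
learned = [name for name in dir(clf) if name.endswith('_') and not name.startswith('_')]
print(learned)

# An unfitted estimator has no classes_ yet.
print('classes_' in dir(BernoulliNB()))  # False
```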

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content type: original author (original content on Stack Overflow)
Question: user1506145
Solution 1 - Python: ogrisel
Solution 2 - Python: pomber
Solution 3 - Python: lazy1