DEV Community

Discussion on: Understanding the Confusion Matrix

Manza (manzaari)

Hi, I have 101 classes and their per-class accuracies, and I want to draw a confusion matrix for them. My code is in PyTorch.

A portion of my code follows:

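```python
import numpy as np

# `output` is assumed to be a list of (class_scores, true_label) pairs and
# `num_class` the number of classes (101 here), both defined earlier.
video_pred = np.array([np.argmax(x[0]) for x in output])
video_labels = np.array([x[1] for x in output])

# Overall accuracy
print('Accuracy {:.02f}% ({})'.format(
    np.mean(video_pred == video_labels) * 100.0,
    len(video_pred)))

# Accuracy per class: correctly predicted samples / all samples of that class
correct = video_pred == video_labels  # computed once, not inside the loop
for i in range(num_class):
    class_total = np.sum(video_labels == i)
    if class_total == 0:  # avoid dividing by zero for classes with no samples
        print('Class {}: no samples'.format(i))
        continue
    class_correct = np.sum(correct & (video_labels == i))
    print('Class {}: Accuracy {:.02f}%'.format(i, class_correct / class_total * 100))
```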
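Since `video_pred` and `video_labels` already hold every prediction and ground-truth label, the confusion matrix can also be built directly from them with NumPy, without running the model a second time. A minimal sketch, assuming the labels are integers in `range(num_class)`; the helper names `confusion_from_labels` and `per_class_accuracy` are hypothetical:

```python
import numpy as np


def confusion_from_labels(labels, preds, num_class):
    """Return a (num_class, num_class) matrix: rows = true class, columns = predicted class."""
    labels = np.asarray(labels, dtype=np.int64)
    preds = np.asarray(preds, dtype=np.int64)
    # Encode each (true, pred) pair as a single index, then count occurrences.
    flat = np.bincount(labels * num_class + preds, minlength=num_class * num_class)
    return flat.reshape(num_class, num_class)


def per_class_accuracy(cm):
    """Diagonal (correct predictions) divided by row sums (samples per true class).

    Classes with no samples get nan instead of triggering a division warning.
    """
    totals = cm.sum(axis=1)
    acc = np.full(cm.shape[0], np.nan)
    np.divide(np.diag(cm), totals, out=acc, where=totals > 0)
    return acc


# Small example with 3 classes
labels = [0, 0, 1, 1, 1, 2]
preds = [0, 1, 1, 1, 0, 2]
cm = confusion_from_labels(labels, preds, 3)
print(cm)                      # [[1 1 0] [1 2 0] [0 0 1]]
print(per_class_accuracy(cm))  # class 0: 0.5, class 1: about 0.667, class 2: 1.0
```

For the 101-class case this would be `cm = confusion_from_labels(video_labels, video_pred, num_class)`, and the row-wise accuracies should agree with the per-class loop above.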



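```python
import torch

# `model_ft`, `data_loader` and `num_class` are assumed to be defined earlier.
# Assumes the model has parameters; inputs are moved to the same device.
device = next(model_ft.parameters()).device
confusion_matrix = torch.zeros(num_class, num_class)

model_ft.eval()        # disable dropout, use running batch-norm statistics
with torch.no_grad():  # gradients are not needed for evaluation
    for data, label in data_loader:
        outputs = model_ft(data.to(device))
        _, preds = torch.max(outputs, 1)
        # Rows are true classes, columns are predicted classes
        for t, p in zip(label.view(-1), preds.view(-1).cpu()):
            confusion_matrix[t.long(), p.long()] += 1

print(confusion_matrix)

# Per-class accuracy (recall): diagonal / number of samples of each true class.
# A class with no samples gives nan (0 / 0).
print(confusion_matrix.diag() / confusion_matrix.sum(1))
```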
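The per-sample Python loop works but gets slow on large datasets. A possible vectorized alternative, assuming `labels` and `preds` are integer tensors on the same device with values in `range(num_class)`, counts every (true, predicted) pair of a batch in one `torch.bincount` call; `update_confusion` is a hypothetical helper name:

```python
import torch


def update_confusion(confusion, labels, preds):
    """Add one batch's counts to `confusion` in place (rows = true, columns = predicted)."""
    num_class = confusion.shape[0]
    # Each (true, pred) pair becomes the single index true * num_class + pred.
    idx = labels.view(-1).long() * num_class + preds.view(-1).long()
    counts = torch.bincount(idx, minlength=num_class * num_class)
    confusion += counts.reshape(num_class, num_class).to(
        device=confusion.device, dtype=confusion.dtype)
    return confusion


confusion = torch.zeros(3, 3)
update_confusion(confusion, torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 1, 2]))
update_confusion(confusion, torch.tensor([2, 1]), torch.tensor([0, 1]))
print(confusion)
# Per-class accuracy; a class with no samples would give nan (0 / 0).
print(confusion.diag() / confusion.sum(1))
```

Inside the evaluation loop, this would replace the inner `for t, p in zip(...)` loop with a call such as `update_confusion(confusion_matrix, label, preds.cpu())`.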
Banso D. Wisdom (overrideveloper)

Hi, Manza. I haven't worked with PyTorch, but I believe the approach should be similar to the example I gave in the article.

From your snippet, you already have the confusion matrix (confusion_matrix), so what you need to do is create a dictionary that maps each class index to its name, like this:

```python
# Placeholder names ('Class 1' ... 'Class n'); use your real class names here.
class_names = {i: 'Class {}'.format(i + 1) for i in range(num_class)}
```

Then pass your confusion matrix (confusion_matrix) and the dictionary of classes (class_names) to the plot_confusion_matrix function from the example in the article, or to a similar function written to your own preferences.

That should output a visualization of your confusion matrix.
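If the article's function is not at hand, a rough matplotlib sketch follows. It is not the article's plot_confusion_matrix; it assumes `cm` is a NumPy array (a PyTorch tensor can be converted with `.numpy()`) and that `class_names` is a dictionary mapping class index to name, like the one described above. With 101 classes, per-cell number annotations become unreadable, so this version only draws colors and, by default, normalizes each row so that classes of different sizes are comparable. `plot_confusion` is a hypothetical name:

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_confusion(cm, class_names, normalize=True):
    """Draw a confusion matrix as a heat map and return the matplotlib Figure.

    cm: (n, n) array, rows = true class, columns = predicted class.
    class_names: dict mapping class index -> display name.
    """
    cm = np.asarray(cm, dtype=float)
    if normalize:
        totals = cm.sum(axis=1, keepdims=True)
        # Divide each row by its class size; rows with no samples stay all zero.
        cm = np.divide(cm, totals, out=np.zeros_like(cm), where=totals > 0)
    n = cm.shape[0]
    labels = [class_names[i] for i in range(n)]

    size = max(6, n * 0.2)  # grow the figure with the number of classes
    fig, ax = plt.subplots(figsize=(size, size))
    im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
    fig.colorbar(im, ax=ax)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(labels, rotation=90, fontsize=6)
    ax.set_yticklabels(labels, fontsize=6)
    ax.set_xlabel('Predicted class')
    ax.set_ylabel('True class')
    fig.tight_layout()
    return fig


# Example with random dummy counts for 101 classes
rng = np.random.default_rng(0)
names = {i: 'Class {}'.format(i + 1) for i in range(101)}
fig = plot_confusion(rng.integers(0, 20, size=(101, 101)), names)
```

The figure can then be saved with `fig.savefig('confusion_matrix.png', dpi=150)`, or shown with `plt.show()` in an interactive session.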