How to retrieve the labels used in a segmentation mask in AWS SageMaker

From a segmentation mask, I am trying to retrieve which labels are represented in it.

This is the image I am running through a semantic segmentation model in AWS SageMaker:

[Image: a motorbike, with everything else as background]

Code for making the prediction and displaying the mask:

import io

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sagemaker.predictor import RealTimePredictor  # SageMaker Python SDK v1 (renamed Predictor in v2)

# ss_model (the deployed model), sess (a sagemaker.Session) and img (the raw JPEG bytes)
# are assumed to be defined earlier in the notebook.
ss_predict = RealTimePredictor(endpoint=ss_model.endpoint_name,
                               sagemaker_session=sess,
                               content_type='image/jpeg',
                               accept='image/png')

# The endpoint returns the predicted mask as PNG bytes
return_img = ss_predict.predict(img)

num_labels = 21
# Each pixel of the decoded mask holds a class index in [0, num_labels - 1]
mask = np.array(Image.open(io.BytesIO(return_img)))
plt.imshow(mask, vmin=0, vmax=num_labels - 1, cmap='jet')
plt.show()

This is the resulting segmentation mask: the motorbike has one label and everything else is labeled as background.

[Image: segmented mask]

As the code shows, there are 21 possible labels, and 2 of them appear in the mask: one for the motorbike and one for the background. How can I print which of the 21 possible labels are actually used in this mask?

Please let me know if you need any further information. Any help is much appreciated.

asked May 28 '20 by Dre


1 Answer

Somewhere you should have a mapping from label integers to label classes, e.g.

label_map = {0: 'background', 1: 'motorbike', 2: 'train'}  # ... and so on for the remaining classes

If you are using the Pascal VOC dataset, the mapping would be (0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle, 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person, 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor). That accounts for the 21 labels in your code, so a motorbike image like yours would be expected to contain labels 0 and 14. See: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/segexamples/index.html
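A minimal sketch of that Pascal VOC mapping as a Python dict, assuming the SageMaker model was trained with the standard VOC 2012 class order (index 0 = background). The name `VOC_CLASSES` is chosen here for illustration only:

```python
# Pascal VOC 2012 segmentation classes in index order; index 0 is background.
# Assumption: the deployed model was trained on VOC with this exact ordering.
VOC_CLASSES = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow", "diningtable", "dog",
    "horse", "motorbike", "person", "pottedplant", "sheep", "sofa",
    "train", "tvmonitor",
]

# Map each integer label ID to its class name
label_map = dict(enumerate(VOC_CLASSES))

print(len(label_map))  # 21
print(label_map[14])   # motorbike
```

With this ordering, `num_labels = 21` in the question's code is simply `len(VOC_CLASSES)`. If your model was trained on a different dataset, build the list from that dataset's class order instead.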

Then you can simply use that map:

import numpy as np

used_classes = np.unique(mask)  # sorted array of the distinct label IDs present in the mask
for cls in used_classes:
    print("Found class: {}".format(label_map[int(cls)]))
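If you also want to know how much of the image each class covers, `np.unique` can return pixel counts via `return_counts=True`. The sketch below uses a small synthetic mask in place of the real SageMaker output and a trimmed, hypothetical `label_map`; the helper `classes_in_mask` is likewise a name made up for this example. IDs missing from the map (for example 255, which VOC uses for void/border pixels in its ground-truth masks) are reported as unknown instead of raising `KeyError`:

```python
import numpy as np

# Hypothetical stand-in for the decoded mask: 0 = background, 14 = motorbike (VOC IDs)
mask = np.array([
    [0, 0, 0, 0],
    [0, 14, 14, 0],
    [0, 14, 14, 0],
], dtype=np.uint8)

# Trimmed label map for illustration; use the full map for your model
label_map = {0: "background", 14: "motorbike"}


def classes_in_mask(mask, label_map):
    """Return (class_name, pixel_count) pairs for every label ID present in the mask."""
    ids, counts = np.unique(mask, return_counts=True)  # ids come back sorted
    return [
        (label_map.get(int(i), "unknown ({})".format(int(i))), int(c))
        for i, c in zip(ids, counts)
    ]


for name, count in classes_in_mask(mask, label_map):
    print("Found class: {} ({} pixels)".format(name, count))
# Found class: background (8 pixels)
# Found class: motorbike (4 pixels)
```

On the real output you would pass the `mask` decoded from `return_img` and the full label map; the counts can also be divided by `mask.size` to get the fraction of the image each class occupies.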
answered Nov 11 '22 by mpaepper