DEV Community

Discussion on: Export Segment Anything neural network to ONNX: the missing parts

Andrey Germanov • Edited

Hi, could you provide more context? If you share your code and the input image that resulted in this error, I can say more.

At first glance, it looks like the input tensor for the SAM decoder (the image embeddings returned by the encoder) has an incorrect shape. It should be (1, 256, 64, 64), but I can confirm this only after seeing the code.
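As a quick sanity check, you can validate the embeddings shape before passing it to the decoder. This is a minimal sketch; `check_embeddings_shape` is a hypothetical helper, not part of any SAM API:

```python
import numpy as np

def check_embeddings_shape(embeddings: np.ndarray) -> None:
    """Raise a descriptive error if the embeddings tensor does not match
    the (1, 256, 64, 64) shape that the SAM decoder expects."""
    expected = (1, 256, 64, 64)
    if embeddings.shape != expected:
        raise ValueError(
            f"SAM decoder expects image embeddings of shape {expected}, "
            f"got {embeddings.shape}"
        )

# A correctly shaped dummy tensor passes the check silently
check_embeddings_shape(np.zeros((1, 256, 64, 64), dtype=np.float32))
```

Running this right after the encoder makes shape mismatches fail with a clear message instead of an opaque ONNX Runtime error.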

Bald man

Hello, I have solved this problem. May I ask how SAM supports the input of multiple bounding boxes?

Andrey Germanov • Edited

As far as I know, it's impossible now due to limitations of the SAM model. You can read about some workarounds here: github.com/facebookresearch/segmen...

However, you can partially emulate this by specifying 4 points that belong to the object and are located close to its corners. So, you can specify 8 points to define two boxes. It's not the same as two real bounding boxes, but it's better than nothing.

If you follow the sample provided in this article to detect both the dog and the cat in the image, using the top-left and bottom-right corner points of these objects, you can use the following code:

# Imports (already present in the sam_onnx_inference.ipynb notebook;
# `embeddings`, `orig_width`/`orig_height` and `resized_width`/`resized_height`
# also come from that notebook)
import numpy as np
import onnxruntime as ort
from copy import deepcopy
from PIL import Image

# ENCODE PROMPT (two boxes, defined as points that belong to the objects)
input_box = np.array([155, 246, 241, 297, 279, 110, 444, 287]).reshape(4,2)
# Labels are set as for points that belong to the objects (not as box corner labels)
input_labels = np.array([1,1,1,1])

onnx_coord = input_box[None, :, :]
onnx_label = input_labels[None, :].astype(np.float32)

coords = deepcopy(onnx_coord).astype(float)
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)

onnx_coord = coords.astype("float32")

# RUN DECODER TO GET MASK
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

decoder = ort.InferenceSession("vit_b_decoder.onnx")
masks,_,_ = decoder.run(None,{
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32)
})

# POSTPROCESS MASK
mask = masks[0][0]
mask = (mask > 0).astype('uint8')*255

# VISUALIZE MASK
img_mask = Image.fromarray(mask, "L")
img_mask
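An alternative to packing all eight points into one prompt is to run the decoder once per box and then union the resulting binary masks. A minimal sketch of the merging step, using a toy example instead of real decoder output (`merge_masks` is a hypothetical helper, not part of the SAM API; the per-box decoder calls would follow the same pattern as the code above):

```python
import numpy as np

def merge_masks(masks):
    """Union several binary uint8 masks (values 0 or 255), one per box,
    into a single mask by taking the element-wise maximum."""
    return np.maximum.reduce(masks)

# Toy example: two small masks, each covering a different region
mask_a = np.array([[255, 0], [0, 0]], dtype=np.uint8)
mask_b = np.array([[0, 0], [0, 255]], dtype=np.uint8)
merged = merge_masks([mask_a, mask_b])
```

Running one decoder pass per box gives each object its own mask, which is closer to true multi-box behavior than the 8-point emulation, at the cost of extra decoder runs.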

If you run this in the sam_onnx_inference.ipynb notebook, the output should look like this:

(output mask image)