
Export Segment Anything neural network to ONNX: the missing parts

Andrey Germanov on November 15, 2023

Matt Cummins

Thanks so much for sharing this!

kieunHyeong

Thank you so much!

david • Edited

Hi, I tried to export the ONNX file for the vit_h model by modifying the line:

sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")

to

sam = sam_model_registry["vit_h"](checkpoint="./checkpoint/sam_vit_h_4b8939.pth")

It then generated 455 files, and the encoder ONNX file is only about 1 MB (vs. 350 MB for vit_b).

Did I miss changing anything in the script?

Filip • Edited

Encountered the same issue; it turns out it's an ONNX limitation: models over 2 GB are exported split into multiple files like this. It's still usable, however; I followed the rest of the tutorial and everything worked. As a workaround, I quantized the split model and got a nice 600 MB encoder in ONNX format; as far as I know, the quality loss should be minimal. You can give this a try:

from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

out_onnx_model_quantized_path = "vit_h_encoder_quantized.onnx"
input_onnx_encoder_path = "vit_h_encoder.onnx"  # Path to the encoder ONNX file; all the other block files must be in the same dir!

quantize_dynamic(
    model_input=input_onnx_encoder_path,   # pass the variable, not a string literal
    model_output=out_onnx_model_quantized_path,
    optimize_model=True,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)

Or simply use one of the smaller SAM models.
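
The quantized file can then be loaded as a drop-in replacement for the original encoder. A quick sketch (the file name matches the snippet above):

import onnxruntime as ort

# Load the quantized encoder instead of the original one;
# the rest of the pipeline stays the same
encoder = ort.InferenceSession("vit_h_encoder_quantized.onnx")
print([inp.name for inp in encoder.get_inputs()])  # verify the expected input name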

david

Thank you!

Bald man

I have a problem, and I wonder if you can help me solve it.

  File "C:\Users\ASUS\Desktop\avatarget\TinySAM-main\TinySAM-main\onnx.py", line 157, in <module>
    outputs = decoder.run(None,{
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_151' Status Message: D:\a\_work\1\s\onnxruntime\core\providers\cpu\tensor\reshape_helper.h:41 onnxruntime::ReshapeHelper::ReshapeHelper gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,8,256}, requested shape:{1,10,8,32}
Andrey Germanov • Edited

Hi, could you provide more context? If you share your code and the input image that resulted in this error, I can say more.

At first glance, it looks like the input tensor for the SAM decoder (the image embeddings returned by the encoder) has an incorrect shape. It should be (1,256,64,64), but I can confirm this only after seeing the code.
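
For example, a quick sanity check like this (a sketch; `embeddings` is whatever your encoder step returned) would tell you immediately whether the embeddings are the problem:

import numpy as np

def check_embeddings(embeddings: np.ndarray) -> None:
    # The SAM decoder expects image embeddings of shape (1, 256, 64, 64);
    # anything else fails inside the decoder graph, e.g. in its Reshape nodes
    expected = (1, 256, 64, 64)
    if embeddings.shape != expected:
        raise ValueError(f"Embeddings shape {embeddings.shape} != expected {expected}")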

Bald man

Hello, I have solved this problem. May I ask how SAM supports the input of multiple bounding boxes?

Andrey Germanov • Edited

As far as I know, it's impossible now due to SAM model limitations. You can read about some workarounds here: github.com/facebookresearch/segmen...

However, you can partially emulate this by specifying 4 points that belong to the object and are located close to its corners. So, you can specify 8 points to define two boxes. It's not the same as two real bounding boxes, but better than nothing.

If you follow the sample provided in this article to detect both the dog and the cat in the image, using the points of the top-left and bottom-right corners of these objects, you can use the following code:

import numpy as np
import onnxruntime as ort
from copy import deepcopy
from PIL import Image

# (`embeddings`, `orig_width`, `orig_height`, `resized_width` and `resized_height`
# come from the image encoding steps of the article)

# ENCODE PROMPT (two boxes, defined as points that belong to the objects)
input_box = np.array([155, 246, 241, 297, 279, 110, 444, 287]).reshape(4,2)
# Labels defined as for points that belong to the objects (not as box corner points)
input_labels = np.array([1,1,1,1])

onnx_coord = input_box[None, :, :]
onnx_label = input_labels[None, :].astype(np.float32)

# Scale the coordinates from the original image size to the resized input image size
coords = deepcopy(onnx_coord).astype(float)
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)

onnx_coord = coords.astype("float32")

# RUN DECODER TO GET MASK
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

decoder = ort.InferenceSession("vit_b_decoder.onnx")
masks,_,_ = decoder.run(None,{
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32)
})

# POSTPROCESS MASK
mask = masks[0][0]
mask = (mask > 0).astype('uint8')*255

# VISUALIZE MASK
img_mask = Image.fromarray(mask, "L")
img_mask

If you run this in the sam_onnx_inference.ipynb notebook, the output should look like this:

[output mask image]

david • Edited

This is great! Thank you!

One question: in the section Encode the prompt, what does the line input_labels = np.array([2,3]) mean when the input is a bounding box? In the official instructions, I didn't see any label required for box input.

Andrey Germanov • Edited

Each coordinate (x,y) should have a label, so it means that the top-left corner of the bounding box has label 2 and the bottom-right corner has label 3.
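
For example, a box prompt could be encoded like this (a sketch; the corner coordinates are made up, the label values are SAM's prompt convention):

import numpy as np

# SAM prompt label convention:
#   0 = background point, 1 = foreground point,
#   2 = top-left box corner, 3 = bottom-right box corner
input_box = np.array([[100, 100], [300, 250]])  # hypothetical box corners (x, y)
input_labels = np.array([2, 3])

onnx_coord = input_box[None, :, :].astype(np.float32)  # shape (1, 2, 2)
onnx_label = input_labels[None, :].astype(np.float32)  # shape (1, 2)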

david

Thanks!

fragakos

Great work, thanks.

Can the automatic mask generator be exported to ONNX, though?

Andrey Germanov

Hello, thank you.

The automatic mask generator is not a separate model that can be exported to ONNX. It's a Python class that runs the same model many times for different points of the image and combines the resulting masks.
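
For illustration, a rough sketch of that idea using the ONNX decoder from this article (it assumes `decoder`, `embeddings` and the size variables from the article; the real SamAutomaticMaskGenerator additionally scores, filters and deduplicates the masks):

import numpy as np

# Probe a grid of foreground points and collect a mask for each one
masks = []
for x in np.linspace(0, orig_width, 8):
    for y in np.linspace(0, orig_height, 8):
        # Scale the probe point to the resized image, as for any other prompt
        coord = np.array([[[x * resized_width / orig_width,
                            y * resized_height / orig_height]]], dtype=np.float32)
        label = np.array([[1]], dtype=np.float32)  # 1 = foreground point
        out, _, _ = decoder.run(None, {
            "image_embeddings": embeddings,
            "point_coords": coord,
            "point_labels": label,
            "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
            "has_mask_input": np.zeros(1, dtype=np.float32),
            "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32)
        })
        masks.append((out[0][0] > 0).astype(np.uint8))
# The real class then scores, filters and deduplicates this list of masks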

fragakos

Is there a guide on how to use that in the best way?

Andrey Germanov

github.com/facebookresearch/segmen...

Also, the source code of the SamAutomaticMaskGenerator class can help you understand how exactly it works: github.com/facebookresearch/segmen...

fragakos

Yes, I saw that, but how is it done using the ONNX files? Maybe you can make a guide on that too?

fragakos

Is it possible to optimize the encoder for GPU?

fragakos

I think it is time to try the new SAM2; we need your guidance! <3

Abhinav Kumar

Wanted to ask if the SAM model can take tiled images (like OpenSeadragon images). If yes, can you provide some resources or references on applying that?
Thanks.

Andrey Germanov • Edited

SAM does not take images directly; it takes tensors of size 1024x1024x3. The image should be converted to a tensor before being passed to the SAM model. I am not familiar with OpenSeadragon images, but if they can be exported to standard PNG or JPG and then converted to tensors, as described in this article, then yes.
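
For reference, a condensed sketch of that conversion (the normalization constants are SAM's pixel mean and std; the file name is hypothetical):

import numpy as np
from PIL import Image

# Resize the longest side to 1024, normalize with SAM's pixel mean/std,
# pad to 1024x1024, and reorder to NCHW
img = Image.open("cat_dog.jpg").convert("RGB")  # hypothetical file name
scale = 1024 / max(img.size)
resized_width, resized_height = int(img.width * scale), int(img.height * scale)
img = img.resize((resized_width, resized_height))

tensor = np.array(img).astype(np.float32)
tensor = (tensor - [123.675, 116.28, 103.53]) / [58.395, 57.12, 57.375]
tensor = np.pad(tensor, ((0, 1024 - resized_height), (0, 1024 - resized_width), (0, 0)))
input_tensor = tensor.transpose(2, 0, 1)[None, :, :, :]  # shape (1, 3, 1024, 1024)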

Vincenzo La Forgia

Hi, great guide. Do you think it is possible to export SegGPT to ONNX? Could you make a guide for it too?