The Segment Anything Model (SAM) has become popular lately.
To try out its interactive usage, I made a simple matplotlib script.
The script below uses the WebAgg backend, so a new browser tab opens after you run it.
You can draw a bounding box by dragging the mouse; when you release the button, SAM automatically runs on that box and overlays the predicted mask on the image.
Installation
conda create --name sam pip python=3.10
conda activate sam
conda install -y pytorch torchvision torchaudio -c pytorch
pip install transformers
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install datasets
pip install matplotlib
# tornado is required by matplotlib's WebAgg backend (the browser-based UI).
pip install tornado
pip install jupyter
pip install notebook
# The script does `import cv2`, which is provided by OpenCV.
pip install opencv-python
Load libraries
import matplotlib
import numpy as np
from datasets import load_dataset
matplotlib.use('WebAgg')
import random
import cv2
import datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.widgets import RectangleSelector
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
Download the model weights
# Fetch the ViT-B SAM checkpoint into the current directory; the file name must match sam_checkpoint in the "load model" cell.
wget "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
Load the model
# %% load model
# The model type must match the checkpoint file: "sam_vit_b_*.pth" -> "vit_b".
# To use ViT-H instead, download the ViT-H checkpoint and change BOTH values.
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Pick the best available backend instead of hard-coding one:
# CUDA GPUs first, then Apple-silicon MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
Load an image
# Load the first example of the dummy COCO 2017 training split.
ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017", data_dir="./dummy_data/", split='train[:10]')
example = ds[0]
# SAM's predictor expects an HxWx3 uint8 RGB array. Some COCO images are
# grayscale or palette-based, so force RGB. The context manager closes the
# underlying file once the converted (fully loaded) copy has been made.
with Image.open(example['image_path']) as src:
    img = src.convert("RGB")
image = np.array(img)
# Bare expression: displays the image when this cell runs in a notebook.
img
Run on a single image
# Compute the image embedding once; every box prompt issued later reuses it.
predictor.set_image(image, image_format="RGB")
# %% select function
def on_select(eclick, erelease):
    """Run SAM on the box the user just dragged and overlay the predicted mask.

    Callback for ``RectangleSelector``. It builds a box prompt from the data
    coordinates of the press/release events and asks the global ``predictor``
    (which must already have an image set) for a single mask. The mask is then
    drawn on the current axes as a semi-transparent, randomly colored layer.

    Parameters
    ----------
    eclick, erelease : matplotlib.backend_bases.MouseEvent
        Press and release events supplied by ``RectangleSelector``.
    """
    x1, y1 = eclick.xdata, eclick.ydata
    x2, y2 = erelease.xdata, erelease.ydata
    if any(v is None for v in (x1, y1, x2, y2)):
        # An event without data coordinates carries no usable box.
        return
    print(f"Rectangle selected from ({x1}, {y1}) to ({x2}, {y2})")

    # The user may drag in any direction, but SAM's box prompt is XYXY
    # (x_min, y_min, x_max, y_max), so normalize the corner order.
    x_min, x_max = sorted((x1, x2))
    y_min, y_max = sorted((y1, y2))
    if x_max <= x_min or y_max <= y_min:
        # Zero-area box (e.g. a click without a drag): nothing to segment.
        return
    input_box = np.array([x_min, y_min, x_max, y_max])

    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    mask = masks[0]

    # Random RGB color plus a fixed alpha of 0.6 so the image stays visible.
    color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    ax = plt.gca()
    ax.imshow(mask_image)
    # Request a redraw explicitly rather than relying on interactive-mode
    # auto-redraw, so the new layer is rendered in the browser.
    ax.figure.canvas.draw_idle()
# Start from a clean slate, then show the image on a fresh single-axes figure.
plt.close('all')
fig = plt.figure()
ax = fig.add_subplot()
ax.imshow(image)

# Keep a module-level reference: matplotlib widgets must stay referenced to
# remain responsive. Dragging with the left button (1) invokes on_select; drags
# spanning 5 screen pixels or less in either direction are not treated as a
# selection.
rect_selector = RectangleSelector(
    ax,
    on_select,
    useblit=True,
    button=[1],
    minspanx=5,
    minspany=5,
    spancoords='pixels',
    interactive=True,
)
plt.show()
Top comments (0)