Spaces:
Runtime error
Runtime error
from PIL import Image | |
import numpy as np | |
import torch | |
from rembg import remove, new_session | |
from extras.GroundingDINO.util.inference import default_groundingdino | |
# Compute device chosen once, when the module is imported: CUDA if torch can
# see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def run_grounded_sam(input_image, text_prompt, box_threshold, text_threshold):
    """Detect regions described by ``text_prompt`` using GroundingDINO.

    Args:
        input_image: Image to search. It is converted with ``np.array``
            before being passed to the detector (e.g. a PIL image).
        text_prompt: Free-text caption describing what to detect.
        box_threshold: Box threshold forwarded unchanged to the detector.
        text_threshold: Text threshold forwarded unchanged to the detector.

    Returns:
        The ``xyxy`` attribute of the detections object returned by
        ``default_groundingdino``. Presumably this is an (N, 4) array of
        corner coordinates; confirm against the detector implementation.
    """
    pixels = np.array(input_image)
    detections, _unused = default_groundingdino(
        image=pixels,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    return detections.xyxy
def generate_mask_from_image(image, mask_model, extras):
    """Produce a mask for ``image`` with the rembg model named ``mask_model``.

    Args:
        image: ``None``, an image array, or a dict holding the array under
            the ``'image'`` key (the shape some UI image editors return).
        mask_model: Model name passed to rembg's ``new_session``. When it is
            ``'sam'``, ``run_grounded_sam`` is run first to build rectangle
            prompts for SAM.
        extras: Extra keyword options forwarded to both ``new_session`` and
            ``remove``. For ``'sam'`` it must provide ``'sam_prompt_text'``,
            ``'box_threshold'`` and ``'text_threshold'``. The caller's dict
            is not modified; the generated ``'sam_prompt'`` is added to a
            private copy. ``None`` is treated as no extra options.

    Returns:
        ``None`` when ``image`` is ``None``. Otherwise, the result of
        ``remove(..., only_mask=True)`` for the image array.
    """
    if image is None:
        return None

    # Only unwrap real dicts. On a bare ndarray, `'image' in image` is a
    # NumPy element-wise comparison against the string, not a key lookup.
    if isinstance(image, dict) and 'image' in image:
        image = image['image']

    # Work on a copy so the caller's dict is not mutated and no stale
    # 'sam_prompt' survives into later calls that reuse the same dict.
    options = dict(extras) if extras else {}

    if mask_model == 'sam':
        boxes = run_grounded_sam(
            Image.fromarray(image),
            options['sam_prompt_text'],
            box_threshold=options['box_threshold'],
            text_threshold=options['text_threshold'],
        )
        if len(boxes) == 0:
            # Nothing detected: prompt SAM with one box covering the whole image.
            height, width = image.shape[0], image.shape[1]
            boxes = np.array([[0, 0, width, height]])
        options['sam_prompt'] = [
            {"type": "rectangle", "data": box.tolist()} for box in boxes
        ]

    return remove(
        image,
        session=new_session(mask_model, **options),
        only_mask=True,
        **options,
    )