|
# --- Runtime environment bootstrap (container/Spaces-style install hack) ---
# Installs a prebuilt detectron2 wheel and a pinned deepspeed at import time.
# NOTE(review): os.system at module import ignores failures and assumes the
# wheel path exists — consider subprocess.run([...], check=True) instead.
import os

os.system("cd detectron2 && pip install detectron2-0.6-cp310-cp310-linux_x86_64.whl")

os.system("pip install deepspeed==0.7.0")

# Reload the site module so the packages installed above become importable
# in this same interpreter session without restarting it.
import site

from importlib import reload

reload(site)
|
|
|
from PIL import Image |
|
from io import BytesIO |
|
import argparse |
|
import sys |
|
import numpy as np |
|
import torch |
|
import gradio as gr |
|
|
|
from detectron2.config import get_cfg |
|
from detectron2.data.detection_utils import read_image |
|
from detectron2.utils.logger import setup_logger |
|
|
|
# Make the vendored CenterNet2 project importable before the grit/centernet imports below.
sys.path.insert(0, "third_party/CenterNet2/projects/CenterNet2/")
|
from centernet.config import add_centernet_config |
|
from grit.config import add_grit_config |
|
|
|
from grit.predictor import VisualizationDemo |
|
|
|
def get_parser():
    """Build the command-line argument parser for the GRiT demo.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config-file``,
        ``--cpu``, ``--confidence-threshold``, ``--test-task`` and
        ``--opts`` options with the demo's defaults.
    """
    arg_parser = argparse.ArgumentParser(
        description="Detectron2 demo for builtin configs"
    )
    arg_parser.add_argument(
        "--config-file",
        metavar="FILE",
        default="configs/GRiT_B_DenseCap_ObjectDet.yaml",
        help="path to config file",
    )
    arg_parser.add_argument("--cpu", action="store_true", help="Use CPU only.")
    arg_parser.add_argument(
        "--confidence-threshold",
        default=0.5,
        type=float,
        help="Minimum score for instance predictions to be shown",
    )
    arg_parser.add_argument(
        "--test-task",
        default="",
        type=str,
        help="Choose a task to have GRiT perform",
    )
    # REMAINDER swallows every trailing token after --opts as raw KEY VALUE pairs.
    arg_parser.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=["MODEL.WEIGHTS", "./models/grit_b_densecap_objectdet.pth"],
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return arg_parser
|
|
|
def setup_cfg(args):
    """Build a frozen detectron2 config for GRiT from parsed CLI args.

    Args:
        args: argparse.Namespace produced by ``get_parser()`` — uses
            ``config_file``, ``cpu``, ``confidence_threshold``,
            ``test_task`` and ``opts``.

    Returns:
        A frozen detectron2 ``CfgNode``.
    """
    cfg = get_cfg()
    # NOTE(review): DEVICE is set before merge_from_file/merge_from_list,
    # so a MODEL.DEVICE key in the config file or in --opts would override
    # --cpu — confirm this ordering is intended.
    if args.cpu:
        cfg.MODEL.DEVICE = "cpu"
    add_centernet_config(cfg)
    add_grit_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # Apply the CLI confidence threshold to both ROI heads and panoptic FPN.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
        args.confidence_threshold
    )
    if args.test_task:
        cfg.MODEL.TEST_TASK = args.test_task
    # Greedy decoding and plain NMS; activation checkpointing off for inference.
    cfg.MODEL.BEAM_SIZE = 1
    cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
    cfg.USE_ACT_CHECKPOINT = False
    cfg.freeze()
    return cfg
|
|
|
def predict(image_file):
    """Run GRiT dense captioning on a single image for the Gradio UI.

    Args:
        image_file: input image as a PIL.Image (Gradio ``type='pil'``).

    Returns:
        tuple: ``(annotated PIL.Image, results dict)`` where the dict maps
        each predicted object description to a list of
        ``{xmin, ymin, xmax, ymax, score}`` boxes.
    """
    # Force 3-channel RGB first: grayscale/palette uploads have no channel
    # axis to flip, and RGBA would leak alpha into the flip. The ::-1 then
    # converts RGB -> BGR, the channel order the detectron2 demo expects.
    image_array = np.array(image_file.convert("RGB"))[:, :, ::-1]
    predictions, visualized_output = dense_captioning_demo.run_on_image(image_array)

    # Render the visualizer's matplotlib figure into an in-memory PNG so it
    # can be handed back to Gradio as a PIL image.
    buffer = BytesIO()
    visualized_output.fig.savefig(buffer, format='png')
    buffer.seek(0)

    # Group predicted boxes/scores by their object description string.
    detections = {}
    instances = predictions["instances"].to(torch.device("cpu"))
    for box, description, score in zip(
        instances.pred_boxes,
        instances.pred_object_descriptions.data,
        instances.scores,
    ):
        detections.setdefault(description, []).append(
            {
                "xmin": float(box[0]),
                "ymin": float(box[1]),
                "xmax": float(box[2]),
                "ymax": float(box[3]),
                "score": float(score),
            }
        )

    output = {
        "dense_captioning_results": {
            "detections": detections,
        }
    }

    return Image.open(buffer), output
|
|
|
|
|
|
|
# --- Module-level app startup: parse args, build the model, launch the UI ---
args = get_parser().parse_args()

# Force dense-captioning mode regardless of what was passed on the CLI.
args.test_task = "DenseCap"

setup_logger(name="fvcore")

logger = setup_logger()

logger.info("Arguments: " + str(args))

cfg = setup_cfg(args)

# Global predictor consumed by predict(); built once at startup.
dense_captioning_demo = VisualizationDemo(cfg)

# Gradio UI: one image in, annotated image + JSON detections out.
demo = gr.Interface(

    title="Dense Captioning - GRiT",

    fn=predict,

    inputs=gr.Image(type='pil', label="Original Image"),

    outputs=[gr.Image(type="pil",label="Output Image"), "json"],

    examples=["example_1.jpg", "example_2.jpg"],

)

demo.launch()
|
|