Spaces:
Runtime error
Runtime error
api.py
Browse files
api.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import gradio as gr
|
4 |
+
from PIL import Image
|
5 |
+
from io import BytesIO
|
6 |
+
from ferretui.eval.model_UI import load_model, inference
|
7 |
+
|
8 |
+
class interface:
    """Gradio adapter around a loaded Ferret-UI model.

    Holds the parsed CLI args, tokenizer, model and image processor so the
    ``run`` method can be handed to ``gr.Interface`` as the callback.
    """

    def __init__(self, args, tokenizer, model, image_processor) -> None:
        # Stash every component that inference() will need later.
        self.image_processor = image_processor
        self.model = model
        self.tokenizer = tokenizer
        self.args = args

    def run(self, image, qs):
        """Run one inference pass.

        Args:
            image: input image (PIL image from the Gradio widget).
            qs: the user's question / prompt text.

        Returns:
            The ``(output, image)`` pair produced by ``inference`` —
            presumably the answer text and a (possibly annotated) image;
            exact semantics live in ``ferretui.eval.model_UI.inference``.
        """
        answer, rendered = inference(
            self.args,
            image,
            qs,
            self.tokenizer,
            self.model,
            self.image_processor,
        )
        return answer, rendered
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
if __name__ == "__main__":
    # ---- CLI -------------------------------------------------------------
    parser = argparse.ArgumentParser()

    # Checkpoint / data locations.
    parser.add_argument("--model_path", type=str, default="./gemma2b-anyres")
    parser.add_argument("--vision_model_path", type=str, default=None)
    parser.add_argument("--model_base", type=str, default=None)
    parser.add_argument("--image_path", type=str, default="")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--answers_file", type=str, default="")

    # Conversation template and sharding.
    parser.add_argument("--conv_mode", type=str, default="ferret_gemma_instruct",
                        help="[ferret_gemma_instruct,ferret_llama_3,ferret_vicuna_v1]")
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk_idx", type=int, default=0)

    # Image / region handling.
    parser.add_argument("--image_w", type=int, default=336)  # 224
    parser.add_argument("--image_h", type=int, default=336)  # 224
    parser.add_argument("--add_region_feature", action="store_true")
    parser.add_argument("--region_format", type=str, default="box", choices=["point", "box", "segment", "free_shape"])
    parser.add_argument("--no_coor", action="store_true")

    # Generation settings.
    parser.add_argument("--temperature", type=float, default=0.01)
    parser.add_argument("--top_p", type=float, default=0.3)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--data_type", type=str, default='fp16', choices=['fp16', 'bf16', 'fp32'])

    args = parser.parse_args()

    # Replace the dtype string with the actual torch dtype in place
    # (anything other than fp16/bf16 falls back to float32, as before).
    dtype_map = {'fp16': torch.float16, 'bf16': torch.bfloat16}
    args.data_type = dtype_map.get(args.data_type, torch.float32)

    # ---- Model + UI ------------------------------------------------------
    tokenizer, model, image_processor, context_len = load_model(args)
    gin = interface(args, tokenizer, model, image_processor)

    gr.Interface(
        fn=gin.run,
        inputs=[gr.Image(type="pil", label="Input image"), gr.Textbox(label="Question")],
        outputs=[gr.Textbox(label="Answer"), gr.Image(type="pil", label="Output image")],
    ).launch()
|