LeoKE178 committed
Commit f414ee8 · 1 Parent(s): b099765
Files changed (1)
  1. api.py +60 -0
api.py ADDED
@@ -0,0 +1,60 @@
+ import argparse
+ import torch
+ import gradio as gr
+ from PIL import Image
+ from io import BytesIO
+ from ferretui.eval.model_UI import load_model, inference
+
+
+ class Interface:
+     """Thin wrapper holding the loaded model so Gradio can call inference per request."""
+
+     def __init__(self, args, tokenizer, model, image_processor) -> None:
+         self.args = args
+         self.tokenizer = tokenizer
+         self.model = model
+         self.image_processor = image_processor
+
+     def run(self, image, qs):
+         # Run one inference pass: answer question `qs` about `image` and return
+         # the generated text together with the (possibly annotated) image.
+         output, image = inference(self.args, image, qs, self.tokenizer, self.model, self.image_processor)
+         return output, image
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_path", type=str, default="./gemma2b-anyres")
+     parser.add_argument("--vision_model_path", type=str, default=None)
+     parser.add_argument("--model_base", type=str, default=None)
+     parser.add_argument("--image_path", type=str, default="")
+     parser.add_argument("--data_path", type=str, default="")
+     parser.add_argument("--answers_file", type=str, default="")
+     parser.add_argument("--conv_mode", type=str, default="ferret_gemma_instruct",
+                         help="[ferret_gemma_instruct,ferret_llama_3,ferret_vicuna_v1]")
+     parser.add_argument("--num_chunks", type=int, default=1)
+     parser.add_argument("--chunk_idx", type=int, default=0)
+     parser.add_argument("--image_w", type=int, default=336)  # 224
+     parser.add_argument("--image_h", type=int, default=336)  # 224
+     parser.add_argument("--add_region_feature", action="store_true")
+     parser.add_argument("--region_format", type=str, default="box", choices=["point", "box", "segment", "free_shape"])
+     parser.add_argument("--no_coor", action="store_true")
+     parser.add_argument("--temperature", type=float, default=0.01)
+     parser.add_argument("--top_p", type=float, default=0.3)
+     parser.add_argument("--num_beams", type=int, default=1)
+     parser.add_argument("--max_new_tokens", type=int, default=128)
+     parser.add_argument("--data_type", type=str, default='fp16', choices=['fp16', 'bf16', 'fp32'])
+     args = parser.parse_args()
+
+     # Map the dtype flag to the corresponding torch dtype before loading the model.
+     if args.data_type == 'fp16':
+         args.data_type = torch.float16
+     elif args.data_type == 'bf16':
+         args.data_type = torch.bfloat16
+     else:
+         args.data_type = torch.float32
+
+     # Load the model once and expose it through a simple Gradio demo.
+     tokenizer, model, image_processor, context_len = load_model(args)
+     gin = Interface(args, tokenizer, model, image_processor)
+
+     iface = gr.Interface(
+         fn=gin.run,
+         inputs=[gr.Image(type="pil", label="Input image"), gr.Textbox(label="Question")],
+         outputs=[gr.Textbox(label="Answer"), gr.Image(type="pil", label="Output image")]
+     )
+     iface.launch()
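Usage note: the demo is started with the script itself, e.g. "python api.py --model_path ./gemma2b-anyres --data_type fp16", after which the Gradio UI is served locally. The snippet below is a minimal client-side sketch for querying that running demo programmatically; it assumes a recent gradio_client release, the default local port 7860, the default "/predict" endpoint name that gr.Interface registers, and placeholder values for the image path and question.

from gradio_client import Client, handle_file

# Connect to the locally launched Gradio demo (default port assumed).
client = Client("http://127.0.0.1:7860/")

# The interface takes (image, question) and returns (answer text, annotated image).
answer, annotated_image = client.predict(
    handle_file("screenshot.png"),           # placeholder path to a UI screenshot
    "Describe the widgets on this screen.",  # placeholder question
    api_name="/predict",                     # default endpoint for a gr.Interface
)
print(answer)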