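"""Gradio demo for Ferret-UI.

Parses the inference arguments, loads a Ferret-UI checkpoint via
ferretui.eval.model_UI.load_model, and serves a simple web UI that takes an
image plus a question and returns the model's answer and an output image.
"""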
import argparse

import torch
import gradio as gr

from ferretui.eval.model_UI import load_model, inference

class InferenceWrapper:
    """Binds the loaded Ferret-UI model and its preprocessing state to a
    callable that Gradio can use as the demo handler."""

    def __init__(self, args, tokenizer, model, image_processor) -> None:
        self.args = args
        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor

    def run(self, image, qs):
        # Single inference pass: returns the answer text and the
        # (possibly annotated) output image.
        output, image = inference(self.args, image, qs, self.tokenizer,
                                  self.model, self.image_processor)
        return output, image
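
# Illustrative standalone use (names mirror the __main__ block below):
#   wrapper = InferenceWrapper(args, tokenizer, model, image_processor)
#   answer, annotated_image = wrapper.run(pil_image, "Describe this screen")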



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="./gemma2b-anyres")
    parser.add_argument("--vision_model_path", type=str, default=None)
    parser.add_argument("--model_base", type=str, default=None)
    parser.add_argument("--image_path", type=str, default="")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--answers_file", type=str, default="")
    parser.add_argument("--conv_mode", type=str, default="ferret_gemma_instruct",
                        help="[ferret_gemma_instruct,ferret_llama_3,ferret_vicuna_v1]")
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk_idx", type=int, default=0)
    parser.add_argument("--image_w", type=int, default=336)  #  224
    parser.add_argument("--image_h", type=int, default=336)  #  224
    parser.add_argument("--add_region_feature", action="store_true")
    parser.add_argument("--region_format", type=str, default="box", choices=["point", "box", "segment", "free_shape"])
    parser.add_argument("--no_coor", action="store_true")
    parser.add_argument("--temperature", type=float, default=0.01)
    parser.add_argument("--top_p", type=float, default=0.3)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--data_type", type=str, default='fp16', choices=['fp16', 'bf16', 'fp32'])
    args = parser.parse_args()

    # Map the CLI dtype flag to the corresponding torch dtype.
    if args.data_type == 'fp16':
        args.data_type = torch.float16
    elif args.data_type == 'bf16':
        args.data_type = torch.bfloat16
    else:
        args.data_type = torch.float32
    
    # Load tokenizer/model/processor once and reuse them across requests.
    tokenizer, model, image_processor, context_len = load_model(args)
    gin = InferenceWrapper(args, tokenizer, model, image_processor)

    # Two inputs (image + question) map to two outputs (answer + output image).
    iface = gr.Interface(
        fn=gin.run,
        inputs=[gr.Image(type="pil", label="Input image"), gr.Textbox(label="Question")],
        outputs=[gr.Textbox(label="Answer"), gr.Image(type="pil", label="Output image")],
    )
    iface.launch()
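
# Example launch (the script filename and checkpoint path are illustrative):
#   python gradio_demo.py --model_path ./gemma2b-anyres \
#       --conv_mode ferret_gemma_instruct --data_type fp16
# Gradio then serves the demo at its default address, http://127.0.0.1:7860.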