# Ferret-UI / app.py
# Commit 1d1cae9 by LeoKE178 (verified): "Rename api.py to app.py"
import argparse
import torch
import gradio as gr
from PIL import Image
from io import BytesIO
from ferretui.eval.model_UI import load_model, inference
class interface:
    """Adapter that exposes a loaded Ferret-UI model as a Gradio callback.

    Holds the parsed CLI arguments together with the tokenizer, model and image
    processor produced by ``load_model`` so that ``run`` can forward each request
    to ``inference``.
    """

    # Returned in place of a model answer when the request carries no image.
    NO_IMAGE_MESSAGE = "Please upload an image before asking a question."

    def __init__(self, args, tokenizer, model, image_processor) -> None:
        """Store the inference configuration and model components.

        Args:
            args: Parsed command-line namespace, forwarded unchanged to ``inference``.
            tokenizer: Tokenizer returned by ``load_model``.
            model: Model returned by ``load_model``.
            image_processor: Image processor returned by ``load_model``.
        """
        self.args = args
        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor

    def run(self, image, qs: str):
        """Answer the question ``qs`` about ``image``.

        Args:
            image: Value of the ``gr.Image(type="pil")`` input component, or
                ``None`` when the user submitted without uploading an image.
            qs: Question text from the ``Question`` textbox.

        Returns:
            A 2-tuple ``(answer_text, output_image)`` as produced by
            ``inference``, or ``(NO_IMAGE_MESSAGE, None)`` when no image was given.
        """
        # Gradio hands over None for an empty image input; answer with a hint
        # instead of forwarding None to inference, which presumably expects an
        # actual image (TODO confirm against ferretui.eval.model_UI.inference).
        if image is None:
            return self.NO_IMAGE_MESSAGE, None
        # Use a distinct name for the returned image so the input is not shadowed.
        answer, output_image = inference(
            self.args, image, qs, self.tokenizer, self.model, self.image_processor
        )
        return answer, output_image
# Maps each allowed ``--data_type`` choice to the torch dtype handed to load_model.
_DTYPES = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the Ferret-UI Gradio demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="./gemma2b-anyres")
    parser.add_argument("--vision_model_path", type=str, default=None)
    parser.add_argument("--model_base", type=str, default=None)
    parser.add_argument("--image_path", type=str, default="")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--answers_file", type=str, default="")
    parser.add_argument("--conv_mode", type=str, default="ferret_gemma_instruct",
                        help="[ferret_gemma_instruct,ferret_llama_3,ferret_vicuna_v1]")
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk_idx", type=int, default=0)
    parser.add_argument("--image_w", type=int, default=336)  # 224
    parser.add_argument("--image_h", type=int, default=336)  # 224
    parser.add_argument("--add_region_feature", action="store_true")
    parser.add_argument("--region_format", type=str, default="box", choices=["point", "box", "segment", "free_shape"])
    parser.add_argument("--no_coor", action="store_true")
    parser.add_argument("--temperature", type=float, default=0.01)
    parser.add_argument("--top_p", type=float, default=0.3)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--data_type", type=str, default='fp16', choices=['fp16', 'bf16', 'fp32'])
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    """Parse CLI arguments and resolve ``--data_type`` to a ``torch.dtype``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace, with ``data_type`` replaced by ``torch.float16``,
        ``torch.bfloat16`` or ``torch.float32``.
    """
    args = build_arg_parser().parse_args(argv)
    # argparse ``choices`` guarantees the key exists in the table.
    args.data_type = _DTYPES[args.data_type]
    return args


def main(argv=None) -> None:
    """Load the model and serve it through a Gradio web interface.

    Args:
        argv: Optional argument list; ``None`` parses the process command line.
    """
    args = parse_args(argv)
    # load_model returns four values; the context length is unused by this demo.
    tokenizer, model, image_processor, _context_len = load_model(args)
    gin = interface(args, tokenizer, model, image_processor)
    iface = gr.Interface(
        fn=gin.run,
        inputs=[gr.Image(type="pil", label="Input image"), gr.Textbox(label="Question")],
        outputs=[gr.Textbox(label="Answer"), gr.Image(type="pil", label="Output image")],
    )
    iface.launch()


if __name__ == "__main__":
    main()