Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from transformers import AutoProcessor, AutoModelForVision2Seq
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
import gradio as gr
|
8 |
+
import subprocess
|
9 |
+
from llava.model.builder import load_pretrained_model
|
10 |
+
from llava.mm_utils import get_model_name_from_path
|
11 |
+
from llava.eval.run_llava import eval_model
|
12 |
+
|
13 |
+
# Path to the fine-tuned LLaVA checkpoint (only the path is set here; nothing is loaded yet)
|
14 |
+
llava_model_path = "/My_new_LLaVA/llava-fine_tune_model"
|
15 |
+
|
16 |
+
# Path to the LLaVA-Med checkpoint (only the path is set here; NOTE: currently identical to the LLaVA path)
|
17 |
+
llava_med_model_path = "/My_new_LLaVA/llava-fine_tune_model"
|
18 |
+
|
19 |
+
# Args class to store arguments for LLaVA models
|
20 |
+
class Args:
    """Plain attribute container for the settings handed to ``eval_model``.

    Every constructor argument is stored unchanged on the instance under the
    same name; no validation or conversion is performed here.
    """

    def __init__(self, model_path, model_base, model_name, query, image_path, conv_mode, image_file, sep, temperature, top_p, num_beams, max_new_tokens):
        # Model selection.
        self.model_path = model_path
        self.model_base = model_base
        self.model_name = model_name
        # Prompt and image inputs.
        self.query = query
        self.image_path = image_path
        self.conv_mode = conv_mode
        self.image_file = image_file
        self.sep = sep
        # Generation settings.
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

    def __repr__(self):
        """Return a ``Args(name=value, ...)`` string listing every stored attribute."""
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
|
34 |
+
|
35 |
+
# # Function to predict using Idefics2
|
36 |
+
# def predict_idefics2(image, question, temperature, max_tokens):
|
37 |
+
# image = image.convert("RGB")
|
38 |
+
# images = [image]
|
39 |
+
|
40 |
+
# messages = [
|
41 |
+
# {
|
42 |
+
# "role": "user",
|
43 |
+
# "content": [
|
44 |
+
# {"type": "image"},
|
45 |
+
# {"type": "text", "text": question}
|
46 |
+
# ]
|
47 |
+
# }
|
48 |
+
# ]
|
49 |
+
# input_text = idefics2_processor.apply_chat_template(messages, add_generation_prompt=False).strip()
|
50 |
+
|
51 |
+
# inputs = idefics2_processor(text=[input_text], images=images, return_tensors="pt", padding=True).to("cuda:0")
|
52 |
+
|
53 |
+
# with torch.no_grad():
|
54 |
+
# outputs = idefics2_model.generate(**inputs, max_length=max_tokens, max_new_tokens=max_tokens, temperature=temperature)
|
55 |
+
|
56 |
+
# predictions = idefics2_processor.decode(outputs[0], skip_special_tokens=True)
|
57 |
+
|
58 |
+
# return predictions
|
59 |
+
|
60 |
+
# Function to predict using LLaVA
|
61 |
+
def predict_llava(image, question, temperature, max_tokens):
    """Answer ``question`` about ``image`` using the checkpoint at ``llava_model_path``.

    Args:
        image: PIL image to query (the Gradio input is declared with ``type="pil"``).
        question: Prompt text forwarded as the query.
        temperature: Sampling temperature, forwarded unchanged.
        max_tokens: Upper bound on newly generated tokens; coerced to ``int``
            because a UI slider may deliver a float.

    Returns:
        The value returned by ``eval_model`` or, when that is ``None``, the
        text it wrote to standard output (stripped).
    """
    import contextlib
    import io
    import tempfile

    # A unique temp file keeps concurrent requests from overwriting each
    # other's image, which the fixed "temp_image.jpg" name allowed.
    handle = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    image_file = handle.name
    handle.close()

    captured = io.StringIO()
    try:
        # JPEG cannot store alpha/palette modes, so normalise to RGB first.
        image.convert("RGB").save(image_file)

        # Setup evaluation arguments
        args = Args(
            model_path=llava_model_path,
            model_base=None,
            model_name=get_model_name_from_path(llava_model_path),
            query=question,
            image_path=image_file,
            conv_mode=None,
            image_file=image_file,
            sep=",",
            temperature=temperature,
            top_p=None,
            num_beams=1,
            max_new_tokens=int(max_tokens),
        )

        # NOTE(review): upstream LLaVA's eval_model appears to print the answer
        # rather than return it -- confirm for the installed version. Stdout is
        # captured so the answer is available either way.
        with contextlib.redirect_stdout(captured):
            output = eval_model(args)
    finally:
        # Best-effort cleanup; a leftover temp file must not mask the real error.
        with contextlib.suppress(OSError):
            os.remove(image_file)

    if output is None:
        return captured.getvalue().strip()
    return output
|
85 |
+
|
86 |
+
# Function to predict using LLaVA-Med
|
87 |
+
def predict_llava_med(image, question, temperature, max_tokens):
    """Answer ``question`` about ``image`` by running ``llava.eval.run_llava`` in a subprocess.

    Args:
        image: PIL image to query (the Gradio input is declared with ``type="pil"``).
        question: Prompt text passed to the script as ``--query``.
        temperature: Sampling temperature, passed as a string.
        max_tokens: Upper bound on newly generated tokens; coerced to ``int``
            because a UI slider may deliver a float.

    Returns:
        The subprocess's standard output with surrounding whitespace removed.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message carries its standard error so the failure is visible
            instead of surfacing as an empty answer.
    """
    import contextlib
    import sys
    import tempfile

    # A unique temp file keeps concurrent requests from overwriting each
    # other's image, which the fixed "temp_image_med.jpg" name allowed.
    handle = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    image_path = handle.name
    handle.close()

    try:
        # JPEG cannot store alpha/palette modes, so normalise to RGB first.
        image.convert("RGB").save(image_path)

        # Command to run the LLaVA-Med model
        # NOTE(review): the flag names are kept exactly as originally written;
        # verify them against the installed run_llava script's argument parser
        # (e.g. "--model-path" / "--max_new_tokens" in some versions).
        command = [
            # Use the interpreter running this app so the same environment is used.
            sys.executable or "python", "-m", "llava.eval.run_llava",
            "--model-name", llava_med_model_path,
            "--image-file", image_path,
            # "--opt=value" form so a question starting with "-" is not parsed as an option.
            f"--query={question}",
            "--temperature", str(temperature),
            "--max-new-tokens", str(int(max_tokens)),
        ]

        # Execute the command and capture the output
        result = subprocess.run(command, capture_output=True, text=True)
    finally:
        # Best-effort cleanup; a leftover temp file must not mask the real error.
        with contextlib.suppress(OSError):
            os.remove(image_path)

    if result.returncode != 0:
        raise RuntimeError(
            f"LLaVA-Med inference failed (exit code {result.returncode}): "
            f"{result.stderr.strip()}"
        )

    return result.stdout.strip()  # Return the output as text
|
106 |
+
|
107 |
+
# Main prediction function
|
108 |
+
def predict(model_name, image, text, temperature, max_tokens):
    """Dispatch a prediction request to the backend selected by ``model_name``.

    Args:
        model_name: ``"LLaVA"`` or ``"LLaVA-Med"``.
        image: Input image forwarded to the backend.
        text: Question/prompt forwarded to the backend.
        temperature: Sampling temperature forwarded to the backend.
        max_tokens: Generation length limit forwarded to the backend.

    Returns:
        Whatever the selected backend function returns.

    Raises:
        ValueError: If ``model_name`` is not a supported model (including
            ``None`` when nothing was selected) or if no image was supplied.
    """
    # Resolve the backend lazily so an unsupported name fails with a clear
    # message instead of silently producing an empty result.
    if model_name == "LLaVA":
        backend = predict_llava
    elif model_name == "LLaVA-Med":
        backend = predict_llava_med
    else:
        raise ValueError(
            f"Unsupported model {model_name!r}; choose 'LLaVA' or 'LLaVA-Med'."
        )

    # Both backends immediately save the image, so reject a missing one up front.
    if image is None:
        raise ValueError("An input image is required.")

    return backend(image, text, temperature, max_tokens)
|
113 |
+
|
114 |
+
# Define the Gradio interface
|
115 |
+
interface = gr.Interface(
    fn=predict,
    # Input order must match predict's parameters:
    # (model_name, image, text, temperature, max_tokens).
    inputs=[
        gr.Radio(choices=["LLaVA", "LLaVA-Med"], label="Select Model"),
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Input Text"),
        # `value` sets the slider's initial position; the legacy `default`
        # keyword is not the supported name in current Gradio releases.
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
        # step=1 keeps the token budget on whole numbers.
        gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Max Tokens"),
    ],
    outputs=gr.Textbox(label="Output Text"),
    title="Multimodal LLM Interface",
    description="Switch between models and adjust parameters.",
)

# Launch the Gradio interface
interface.launch()
|