Varu96 committed
Commit 0abfefa · verified · 1 Parent(s): 817868c

Update app.py

Files changed (1)
  1. app.py +130 -0
app.py CHANGED
@@ -0,0 +1,130 @@
+
+ import subprocess
+
+ import gradio as gr
+ import torch  # only needed by the commented-out Idefics2 path below
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForVision2Seq  # Idefics2 path only
+
+ from llava.mm_utils import get_model_name_from_path
+ from llava.eval.run_llava import eval_model
+
+ # Path to the fine-tuned LLaVA checkpoint
+ llava_model_path = "/My_new_LLaVA/llava-fine_tune_model"
+
+ # Path to the LLaVA-Med checkpoint (the same fine-tuned checkpoint is used here)
+ llava_med_model_path = "/My_new_LLaVA/llava-fine_tune_model"
+
+ # Args container matching the argument names expected by LLaVA's eval_model
+ class Args:
+     def __init__(self, model_path, model_base, model_name, query, image_path,
+                  conv_mode, image_file, sep, temperature, top_p, num_beams,
+                  max_new_tokens):
+         self.model_path = model_path
+         self.model_base = model_base
+         self.model_name = model_name
+         self.query = query
+         self.image_path = image_path
+         self.conv_mode = conv_mode
+         self.image_file = image_file
+         self.sep = sep
+         self.temperature = temperature
+         self.top_p = top_p
+         self.num_beams = num_beams
+         self.max_new_tokens = max_new_tokens
+
+ # Function to predict using Idefics2 (disabled; idefics2_processor and
+ # idefics2_model are not loaded anywhere in this file)
+ # def predict_idefics2(image, question, temperature, max_tokens):
+ #     image = image.convert("RGB")
+ #     images = [image]
+ #
+ #     messages = [
+ #         {
+ #             "role": "user",
+ #             "content": [
+ #                 {"type": "image"},
+ #                 {"type": "text", "text": question}
+ #             ]
+ #         }
+ #     ]
+ #     input_text = idefics2_processor.apply_chat_template(messages, add_generation_prompt=False).strip()
+ #
+ #     inputs = idefics2_processor(text=[input_text], images=images, return_tensors="pt", padding=True).to("cuda:0")
+ #
+ #     with torch.no_grad():
+ #         outputs = idefics2_model.generate(**inputs, max_new_tokens=max_tokens, temperature=temperature)
+ #
+ #     predictions = idefics2_processor.decode(outputs[0], skip_special_tokens=True)
+ #
+ #     return predictions
+
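+ # Sketch (assumption, not in this commit): the two objects above would be
+ # created roughly like this before enabling the function:
+ # idefics2_processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
+ # idefics2_model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b").to("cuda:0")
+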
+ # Function to predict using LLaVA
+ def predict_llava(image, question, temperature, max_tokens):
+     # Save the image temporarily so eval_model can read it from disk
+     image.save("temp_image.jpg")
+
+     # Set up evaluation arguments
+     args = Args(
+         model_path=llava_model_path,
+         model_base=None,
+         model_name=get_model_name_from_path(llava_model_path),
+         query=question,
+         image_path="temp_image.jpg",
+         conv_mode=None,
+         image_file="temp_image.jpg",
+         sep=",",
+         temperature=temperature,
+         top_p=None,
+         num_beams=1,
+         max_new_tokens=max_tokens
+     )
+
+     # Generate the answer with the fine-tuned LLaVA checkpoint
+     output = eval_model(args)
+
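+     # NOTE (assumption): some LLaVA releases print the answer inside
+     # eval_model instead of returning it; if output comes back as None,
+     # wrap the call in contextlib.redirect_stdout to capture the text.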
+     return output
+
+ # Function to predict using LLaVA-Med
+ def predict_llava_med(image, question, temperature, max_tokens):
+     # Save the image temporarily
+     image_path = "temp_image_med.jpg"
+     image.save(image_path)
+
+     # Command to run the LLaVA-Med model in a separate process
+     command = [
+         "python", "-m", "llava.eval.run_llava",
+         "--model-name", llava_med_model_path,
+         "--image-file", image_path,
+         "--query", question,
+         "--temperature", str(temperature),
+         "--max-new-tokens", str(max_tokens)
+     ]
+
+     # Execute the command and capture the output
+     result = subprocess.run(command, capture_output=True, text=True)
+
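+     # NOTE (assumption): a failed run leaves its traceback in result.stderr;
+     # checking result.returncode before trusting stdout would surface errors.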
+     # Return everything the script printed as the answer text
+     return result.stdout.strip()
+
+ # Main prediction function: dispatch to the selected model
+ def predict(model_name, image, text, temperature, max_tokens):
+     if model_name == "LLaVA":
+         return predict_llava(image, text, temperature, max_tokens)
+     elif model_name == "LLaVA-Med":
+         return predict_llava_med(image, text, temperature, max_tokens)
+     return "Please select a model."
+
+ # Define the Gradio interface
+ interface = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Radio(choices=["LLaVA", "LLaVA-Med"], label="Select Model"),
+         gr.Image(type="pil", label="Input Image"),
+         gr.Textbox(label="Input Text"),
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
+         gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Max Tokens"),
+     ],
+     outputs=gr.Textbox(label="Output Text"),
+     title="Multimodal LLM Interface",
+     description="Switch between models and adjust parameters.",
+ )
+
+ # Launch the Gradio interface
+ interface.launch()
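A small refinement worth considering (an assumption on my part, not part of this commit): guarding the launch call lets app.py be imported, for example to test predict directly, without starting the server. A minimal sketch:

    # Hypothetical guard: start the Gradio server only when app.py is run
    # directly, so `from app import predict` has no side effects.
    if __name__ == "__main__":
        interface.launch()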