File size: 3,578 Bytes
0abfefa
 
 
 
 
 
 
 
0ce6e9d
 
 
0abfefa
435eee3
 
 
 
 
 
0abfefa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

import json
import os
import subprocess
import tempfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# Make the LLaVA repo checkout the working directory so its relative
# resource paths resolve; relative path, so this fails unless the script
# is launched from the parent of My_new_LLaVA.
os.chdir("My_new_LLaVA/LLaVA")  # Update this if needed

# Verify the current working directory
print("Current Working Directory:", os.getcwd())


# Load the LLaVA model and processor
llava_model_path = "/My_new_LLaVA/llava-fine_tune_model"

# Load the LLaVA-Med model and processor
# NOTE(review): this is byte-identical to llava_model_path — both UI choices
# currently hit the same checkpoint. Confirm whether LLaVA-Med should point
# at a separate fine-tuned model directory.
llava_med_model_path = "/My_new_LLaVA/llava-fine_tune_model"

# Plain argument container mimicking the argparse.Namespace that
# llava.eval.run_llava.eval_model expects.
class Args:
    def __init__(self, model_path, model_base, model_name, query, image_path, conv_mode, image_file, sep, temperature, top_p, num_beams, max_new_tokens):
        """Bind every generation/evaluation option onto the instance verbatim."""
        # One setattr per option keeps the attribute names identical to the
        # constructor parameter names, which is the contract eval_model relies on.
        options = {
            "model_path": model_path,
            "model_base": model_base,
            "model_name": model_name,
            "query": query,
            "image_path": image_path,
            "conv_mode": conv_mode,
            "image_file": image_file,
            "sep": sep,
            "temperature": temperature,
            "top_p": top_p,
            "num_beams": num_beams,
            "max_new_tokens": max_new_tokens,
        }
        for attr_name, attr_value in options.items():
            setattr(self, attr_name, attr_value)


# Function to predict using LLaVA
def predict_llava(image, question, temperature, max_tokens):
    """Run the fine-tuned LLaVA model on one image/question pair.

    Args:
        image: PIL image supplied by the Gradio widget.
        question: user prompt to ask about the image.
        temperature: sampling temperature forwarded to eval_model.
        max_tokens: cap on generated tokens (max_new_tokens).

    Returns:
        The text produced by ``eval_model``.
    """
    # Use a unique temp file rather than a fixed "temp_image.jpg" so
    # concurrent Gradio requests cannot clobber each other's image.
    tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    tmp_path = tmp.name
    tmp.close()
    try:
        image.save(tmp_path)

        # Setup evaluation arguments for the LLaVA eval entry point.
        args = Args(
            model_path=llava_model_path,
            model_base=None,
            model_name=get_model_name_from_path(llava_model_path),
            query=question,
            image_path=tmp_path,
            conv_mode=None,
            image_file=tmp_path,
            sep=",",
            temperature=temperature,
            top_p=None,
            num_beams=1,
            max_new_tokens=max_tokens,
        )

        # Generate the answer using the selected model.
        return eval_model(args)
    finally:
        # Always remove the temp image, even if inference raises,
        # so repeated calls don't leak files.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# Function to predict using LLaVA-Med
def predict_llava_med(image, question, temperature, max_tokens):
    """Run the LLaVA-Med model via a subprocess call to llava.eval.run_llava.

    Args:
        image: PIL image supplied by the Gradio widget.
        question: user prompt to ask about the image.
        temperature: sampling temperature passed on the command line.
        max_tokens: cap on generated tokens (--max-new-tokens).

    Returns:
        The subprocess stdout (stripped), or an error message describing
        the failure when the subprocess exits non-zero.
    """
    # Save the image temporarily for the CLI invocation.
    image_path = "temp_image_med.jpg"
    image.save(image_path)

    # Command to run the LLaVA-Med model. Passing a list (shell=False)
    # avoids any shell-injection risk from the user-supplied question.
    command = [
        "python", "-m", "llava.eval.run_llava",
        "--model-name", llava_med_model_path,
        "--image-file", image_path,
        "--query", question,
        "--temperature", str(temperature),
        "--max-new-tokens", str(max_tokens)
    ]

    try:
        # Execute the command and capture the output.
        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0:
            # Previously a failed run silently returned an empty string;
            # surface stderr so the UI shows why inference failed.
            return f"LLaVA-Med invocation failed: {result.stderr.strip()}"
        return result.stdout.strip()  # Return the output as text
    finally:
        # Clean up the temp image regardless of success or failure.
        if os.path.exists(image_path):
            os.remove(image_path)

# Main prediction function
def predict(model_name, image, text, temperature, max_tokens):
    """Dispatch a request to the model selected in the UI.

    Args:
        model_name: either "LLaVA" or "LLaVA-Med" (the Radio choices).
        image: PIL image from the Gradio widget.
        text: user question.
        temperature: sampling temperature.
        max_tokens: generation length cap.

    Returns:
        The chosen model's text output.

    Raises:
        ValueError: if ``model_name`` is not a recognized choice.
    """
    if model_name == "LLaVA":
        return predict_llava(image, text, temperature, max_tokens)
    if model_name == "LLaVA-Med":
        return predict_llava_med(image, text, temperature, max_tokens)
    # Previously an unrecognized choice fell through and returned None,
    # which Gradio rendered as an empty output box; fail loudly instead.
    raise ValueError(f"Unknown model selection: {model_name!r}")

# Define the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(choices=["LLaVA", "LLaVA-Med"], label="Select Model"),
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Input Text"),
        # Gradio 3+ takes the initial slider position via `value=`;
        # the old `default=` keyword raises TypeError on modern releases.
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=1, maximum=512, value=256, label="Max Tokens"),
    ],
    outputs=gr.Textbox(label="Output Text"),
    title="Multimodal LLM Interface",
    description="Switch between models and adjust parameters.",
)

# Launch the Gradio interface
interface.launch()