import os
import subprocess

import torch
import gradio as gr
from PIL import Image

from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# Work from the cloned LLaVA repository so its modules and configs resolve correctly
os.chdir("My_new_LLaVA/LLaVA")  # Update this if needed

# Verify the current working directory
print("Current Working Directory:", os.getcwd())
# Path to the fine-tuned LLaVA checkpoint
llava_model_path = "/My_new_LLaVA/llava-fine_tune_model"

# Path to the fine-tuned LLaVA-Med checkpoint
llava_med_model_path = "/My_new_LLaVA/llava-fine_tune_model"
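# Optional sanity check (not in the original Space): warn early if the checkpoint
# directories are missing, since a bad path otherwise only surfaces later as a runtime error.
for _model_dir in (llava_model_path, llava_med_model_path):
    if not os.path.isdir(_model_dir):
        print(f"Warning: model directory not found: {_model_dir}")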
# Args class to store arguments for the LLaVA models
class Args:
    def __init__(self, model_path, model_base, model_name, query, image_path,
                 conv_mode, image_file, sep, temperature, top_p, num_beams, max_new_tokens):
        self.model_path = model_path
        self.model_base = model_base
        self.model_name = model_name
        self.query = query
        self.image_path = image_path
        self.conv_mode = conv_mode
        self.image_file = image_file
        self.sep = sep
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens
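# Note: Args mirrors the argparse-style namespace that llava.eval.run_llava.eval_model
# reads its settings from, so the attribute names above must match what it expects.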
# Run inference with the fine-tuned LLaVA model
def predict_llava(image, question, temperature, max_tokens):
    # Save the uploaded image temporarily so it can be passed to the model by path
    image.save("temp_image.jpg")

    # Set up the evaluation arguments
    args = Args(
        model_path=llava_model_path,
        model_base=None,
        model_name=get_model_name_from_path(llava_model_path),
        query=question,
        image_path="temp_image.jpg",
        conv_mode=None,
        image_file="temp_image.jpg",
        sep=",",
        temperature=temperature,
        top_p=None,
        num_beams=1,
        max_new_tokens=max_tokens
    )

    # Generate the answer; this assumes eval_model returns the generated text
    # (some versions of llava.eval.run_llava only print it)
    output = eval_model(args)
    return output
# Run inference with the LLaVA-Med model via its CLI entry point
def predict_llava_med(image, question, temperature, max_tokens):
    # Save the uploaded image temporarily
    image_path = "temp_image_med.jpg"
    image.save(image_path)

    # Command to run the LLaVA-Med model
    command = [
        "python", "-m", "llava.eval.run_llava",
        "--model-name", llava_med_model_path,
        "--image-file", image_path,
        "--query", question,
        "--temperature", str(temperature),
        "--max-new-tokens", str(max_tokens)
    ]

    # Execute the command and capture the output
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        # Surface the error in the Gradio output box instead of failing silently
        return f"LLaVA-Med inference failed:\n{result.stderr.strip()}"
    return result.stdout.strip()  # Return the generated answer as text
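# Optional local smoke test (assumes a sample image at "sample.jpg"; not part of the
# original Space). Handy for checking the subprocess call outside of Gradio:
# if __name__ == "__main__":
#     print(predict_llava_med(Image.open("sample.jpg"), "Describe this image.", 0.7, 128))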
# Route the request to the selected model
def predict(model_name, image, text, temperature, max_tokens):
    if model_name == "LLaVA":
        return predict_llava(image, text, temperature, max_tokens)
    elif model_name == "LLaVA-Med":
        return predict_llava_med(image, text, temperature, max_tokens)
    return "Please select a model."
# Define the Gradio interface
# Note: Gradio 3+ sets a slider's initial value with `value=`, not the old `default=` keyword
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(choices=["LLaVA", "LLaVA-Med"], label="Select Model"),
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Input Text"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Max Tokens"),
    ],
    outputs=gr.Textbox(label="Output Text"),
    title="Multimodal LLM Interface",
    description="Switch between models and adjust parameters.",
)
# Launch the Gradio interface
interface.launch()
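# Alternative launch options (assumptions, not in the original Space): queue requests so
# long-running generations don't time out, or create a public share link for local testing.
# interface.queue().launch(share=True)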