"""Gradio front end that answers image questions with LLaVA or LLaVA-Med.

The "LLaVA" choice runs ``llava.eval.run_llava.eval_model`` in-process; the
"LLaVA-Med" choice runs the same module as a subprocess and returns whatever
it writes to stdout.
"""

import contextlib
import io
import json
import os
import subprocess
import sys
import tempfile
import threading

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

from llava.eval.run_llava import eval_model
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

os.chdir("My_new_LLaVA/LLaVA")  # Update this if needed

# Verify the current working directory
print("Current Working Directory:", os.getcwd())

# Checkpoint used by the in-process "LLaVA" option.
llava_model_path = "/My_new_LLaVA/llava-fine_tune_model"

# Checkpoint used by the "LLaVA-Med" option.
# NOTE(review): this is the same path as llava_model_path, so both options
# currently load the same weights -- confirm whether a separate LLaVA-Med
# checkpoint was intended.
llava_med_model_path = "/My_new_LLaVA/llava-fine_tune_model"

# contextlib.redirect_stdout swaps sys.stdout for the whole process, so
# in-process inference is serialized to keep concurrent requests from
# capturing each other's output.
_EVAL_LOCK = threading.Lock()


class Args:
    """Plain attribute container passed to ``eval_model``.

    Each constructor argument is stored unchanged under an attribute of the
    same name.
    """

    def __init__(self, model_path, model_base, model_name, query, image_path,
                 conv_mode, image_file, sep, temperature, top_p, num_beams,
                 max_new_tokens):
        self.model_path = model_path
        self.model_base = model_base
        self.model_name = model_name
        self.query = query
        self.image_path = image_path
        self.conv_mode = conv_mode
        self.image_file = image_file
        self.sep = sep
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens


def _save_temp_image(image):
    """Write *image* to a uniquely named temporary JPEG and return its path.

    Raises:
        ValueError: if *image* is None (no image was supplied).
    """
    if image is None:
        raise ValueError("An input image is required.")
    fd, path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        # JPEG cannot store alpha or palette data, so normalize the mode
        # before saving.
        image.convert("RGB").save(path, format="JPEG")
    except Exception:
        _remove_quietly(path)
        raise
    return path


def _remove_quietly(path):
    """Delete *path*, ignoring the case where it is already gone."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


def predict_llava(image, question, temperature, max_tokens):
    """Answer *question* about *image* by calling ``eval_model`` in-process.

    Args:
        image: PIL image supplied by the Gradio image input.
        question: Prompt text.
        temperature: Sampling temperature forwarded to ``eval_model``.
        max_tokens: Upper bound on generated tokens; coerced to ``int``.

    Returns:
        The value returned by ``eval_model``; if that is ``None``, the text
        it printed to stdout (stripped).

    Raises:
        ValueError: if *image* is None.
    """
    image_path = _save_temp_image(image)
    captured = io.StringIO()
    try:
        args = Args(
            model_path=llava_model_path,
            model_base=None,
            model_name=get_model_name_from_path(llava_model_path),
            query=question,
            image_path=image_path,
            conv_mode=None,
            image_file=image_path,
            sep=",",
            temperature=temperature,
            top_p=None,
            num_beams=1,
            max_new_tokens=int(max_tokens),
        )
        with _EVAL_LOCK, contextlib.redirect_stdout(captured):
            output = eval_model(args)
    finally:
        _remove_quietly(image_path)

    # NOTE(review): upstream LLaVA's eval_model appears to print its answer
    # and return None -- confirm against the installed version. Falling back
    # to the captured stdout keeps the UI from showing an empty result.
    if output is None:
        output = captured.getvalue().strip()
    return output


def predict_llava_med(image, question, temperature, max_tokens):
    """Answer *question* about *image* by running ``llava.eval.run_llava``
    as a subprocess.

    Args:
        image: PIL image supplied by the Gradio image input.
        question: Prompt text.
        temperature: Sampling temperature passed on the command line.
        max_tokens: Upper bound on generated tokens; coerced to ``int``.

    Returns:
        The subprocess's stdout, stripped.

    Raises:
        ValueError: if *image* is None.
        RuntimeError: if the subprocess exits with a non-zero status; the
            message carries its stderr.
    """
    image_path = _save_temp_image(image)
    command = [
        # Use the interpreter running this app so the child sees the same
        # installed packages.
        sys.executable, "-m", "llava.eval.run_llava",
        # Flag names mirror the attributes the in-process path passes to
        # eval_model (model_path, max_new_tokens).
        # NOTE(review): confirm against the installed run_llava CLI.
        "--model-path", llava_med_model_path,
        "--image-file", image_path,
        # "--query=<text>" keeps a question starting with "-" from being
        # parsed as another option.
        f"--query={question}",
        "--temperature", str(temperature),
        "--max_new_tokens", str(int(max_tokens)),
    ]
    try:
        result = subprocess.run(command, capture_output=True, text=True)
    finally:
        _remove_quietly(image_path)

    if result.returncode != 0:
        raise RuntimeError(
            f"LLaVA-Med inference failed (exit code {result.returncode}): "
            f"{result.stderr.strip()}"
        )
    return result.stdout.strip()


def predict(model_name, image, text, temperature, max_tokens):
    """Dispatch a request to the predictor selected by *model_name*.

    Args:
        model_name: "LLaVA" or "LLaVA-Med".
        image: PIL image to query.
        text: Prompt text.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The text produced by the selected predictor.

    Raises:
        ValueError: if *model_name* is not one of the supported choices
            (including None when nothing is selected).
    """
    if model_name == "LLaVA":
        return predict_llava(image, text, temperature, max_tokens)
    if model_name == "LLaVA-Med":
        return predict_llava_med(image, text, temperature, max_tokens)
    raise ValueError(
        f"Unknown model {model_name!r}; select 'LLaVA' or 'LLaVA-Med'."
    )


# Define the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(choices=["LLaVA", "LLaVA-Med"], label="Select Model"),
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Input Text"),
        # Top-level Gradio components take their initial value via `value`.
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
        # step=1 keeps the token budget integral.
        gr.Slider(minimum=1, maximum=512, value=256, step=1,
                  label="Max Tokens"),
    ],
    outputs=gr.Textbox(label="Output Text"),
    title="Multimodal LLM Interface",
    description="Switch between models and adjust parameters.",
)

# Launch the Gradio interface
interface.launch()