{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "XICISQU4VQ7j" }, "outputs": [], "source": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "VsFOxVleVRpA" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import json\n", "import os\n", "from transformers import AutoProcessor, AutoModelForVision2Seq\n", "import torch\n", "from PIL import Image\n", "import gradio as gr\n", "import subprocess\n", "from llava.model.builder import load_pretrained_model\n", "from llava.mm_utils import get_model_name_from_path\n", "from llava.eval.run_llava import eval_model\n", "\n", "# Load the LLaVA model and processor\n", "llava_model_path = \"/workspace/LLaVA/LLaVA/llava-fine_tune_model\"\n", "\n", "# Load the LLaVA-Med model and processor\n", "llava_med_model_path = \"/workspace/LLaVA-Med/Model/fine_tuned-med-llava\"\n", "\n", "# Args class to store arguments for LLaVA models\n", "class Args:\n", " def __init__(self, model_path, model_base, model_name, query, image_path, conv_mode, image_file, sep, temperature, top_p, num_beams, max_new_tokens):\n", " self.model_path = model_path\n", " self.model_base = model_base\n", " self.model_name = model_name\n", " self.query = query\n", " self.image_path = image_path\n", " self.conv_mode = conv_mode\n", " self.image_file = image_file\n", " self.sep = sep\n", " self.temperature = temperature\n", " self.top_p = top_p\n", " self.num_beams = num_beams\n", " self.max_new_tokens = max_new_tokens\n", "\n", "# Function to predict using Idefics2\n", "def predict_idefics2(image, question, temperature, max_tokens):\n", " image = image.convert(\"RGB\")\n", " images = [image]\n", "\n", " messages = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\"type\": \"image\"},\n", " {\"type\": \"text\", \"text\": question}\n", " ]\n", " }\n", " ]\n", " input_text = idefics2_processor.apply_chat_template(messages, add_generation_prompt=False).strip()\n", "\n", " inputs = idefics2_processor(text=[input_text], images=images, return_tensors=\"pt\", padding=True).to(\"cuda:0\")\n", "\n", " with torch.no_grad():\n", " outputs = idefics2_model.generate(**inputs, max_length=max_tokens, max_new_tokens=max_tokens, temperature=temperature)\n", "\n", " predictions = idefics2_processor.decode(outputs[0], skip_special_tokens=True)\n", "\n", " return predictions\n", "\n", "# Function to predict using LLaVA\n", "def predict_llava(image, question, temperature, max_tokens):\n", " # Save the image temporarily\n", " image.save(\"temp_image.jpg\")\n", "\n", " # Setup evaluation arguments\n", " args = Args(\n", " model_path=llava_model_path,\n", " model_base=None,\n", " model_name=get_model_name_from_path(llava_model_path),\n", " query=question,\n", " image_path=\"temp_image.jpg\",\n", " conv_mode=None,\n", " image_file=\"temp_image.jpg\",\n", " sep=\",\",\n", " temperature=temperature,\n", " top_p=None,\n", " num_beams=1,\n", " max_new_tokens=max_tokens\n", " )\n", "\n", " # Generate the answer using the selected model\n", " output = eval_model(args)\n", "\n", " return output\n", "\n", "# Function to predict using LLaVA-Med\n", "def predict_llava_med(image, question, temperature, max_tokens):\n", " # Save the image temporarily\n", " image_path = \"temp_image_med.jpg\"\n", " image.save(image_path)\n", "\n", " # Command to run the 
LLaVA-Med model\n", " command = [\n", " \"python\", \"-m\", \"llava.eval.run_llava\",\n", " \"--model-name\", llava_med_model_path,\n", " \"--image-file\", image_path,\n", " \"--query\", question,\n", " \"--temperature\", str(temperature),\n", " \"--max-new-tokens\", str(max_tokens)\n", " ]\n", "\n", " # Execute the command and capture the output\n", " result = subprocess.run(command, capture_output=True, text=True)\n", "\n", " return result.stdout.strip() # Return the output as text\n", "\n", "# Main prediction function\n", "def predict(model_name, image, text, temperature, max_tokens):\n", " if model_name == \"LLaVA\":\n", " return predict_llava(image, text, temperature, max_tokens)\n", " elif model_name == \"LLaVA-Med\":\n", " return predict_llava_med(image, text, temperature, max_tokens)\n", "\n", "# Define the Gradio interface\n", "interface = gr.Interface(\n", " fn=predict,\n", " inputs=[\n", " gr.Radio(choices=[\"LLaVA\", \"LLaVA-Med\"], label=\"Select Model\"),\n", " gr.Image(type=\"pil\", label=\"Input Image\"),\n", " gr.Textbox(label=\"Input Text\"),\n", " gr.Slider(minimum=0.1, maximum=1.0, default=0.7, label=\"Temperature\"),\n", " gr.Slider(minimum=1, maximum=512, default=256, label=\"Max Tokens\"),\n", " ],\n", " outputs=gr.Textbox(label=\"Output Text\"),\n", " title=\"Multimodal LLM Interface\",\n", " description=\"Switch between models and adjust parameters.\",\n", ")\n", "\n", "# Launch the Gradio interface\n", "interface.launch()\n" ], "metadata": { "id": "pCJxQjryVRrh" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "YBSsgQNwVRto" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "UjB_xxubVRu7" }, "execution_count": null, "outputs": [] } ] }