# Install FlashAttention at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells
# the flash-attn setup to skip compiling the CUDA extension, which is required
# on Spaces where no GPU is visible at build time.
import subprocess

subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

# Main code
import random
import threading
import warnings
from typing import Generator

warnings.simplefilter("ignore")

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_REPO = "yuki-imajuku/phi-3v-rakuten-recipe-all-finetune-mq"
MAX_LENGTH = 2048
MAX_SEED = np.iinfo(np.int32).max

if DEVICE == "cpu":
    model = None
else:
    # Load the processor, a streamer for incremental decoding, and the model.
    processor = AutoProcessor.from_pretrained(MODEL_REPO, trust_remote_code=True)
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        device_map=DEVICE,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        _attn_implementation="flash_attention_2",
    )


@spaces.GPU
@torch.inference_mode()
def inference_fn(
    image: Image.Image | None,
    food_name: str,
    temperature: float = 0.0,
    top_p: float = 0.9,
    num_beams: int = 1,
    seed: int = 0,
    randomize_seed: bool = False,
) -> Generator[tuple[str, int], None, tuple[str, int]]:
    """Stream a generated recipe for the uploaded dish image, yielding (text, seed)."""
    if image is None:
        gr.Warning("Please upload an image!", duration=10)
        yield "Please upload an image!", seed
        return "Please upload an image!", seed
    if model is None:
        gr.Warning("Please run this demo on a GPU instance!", duration=10)
        yield "Please run this demo on a GPU instance!", seed
        return "Please run this demo on a GPU instance!", seed

    # Ask for a recipe, optionally naming the dish.
    first_message = "この写真の料理のレシピを教えてください。"  # "Please tell me the recipe for the dish in this photo."
    if len(food_name) > 0:
        first_message = f"この写真の料理は{food_name}です。レシピを教えてください。"  # "The dish in this photo is {food_name}. Please tell me the recipe."
    messages = [
        {"role": "user", "content": f"<|image_1|>\n{first_message}"},
    ]
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(prompt, [image], return_tensors="pt").to(DEVICE)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    torch.manual_seed(seed)

    generation_kwargs = dict(
        inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        num_beams=num_beams,
        max_new_tokens=MAX_LENGTH,
        use_cache=True,
        streamer=streamer,
    )
    # Generate in a background thread and stream partial text back to the UI.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text, seed
    return generated_text, seed


DESCRIPTION = """# FoodMLLM-JP Demo Space
[![arXiv](https://img.shields.io/badge/arXiv-2409.18459-b31b1b.svg)](https://arxiv.org/abs/2409.18459)

This demo is for research and development purposes only. It is not suitable for commercial use or for environments where failure could cause significant harm.
Differences in environment, library versions, etc. may produce output that differs from the results reported in the paper.

Please upload an image of a dish, and the model will generate a recipe based on the image.
"""

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", image_mode="RGB", type="pil")
            input_button = gr.Button(value="Submit")
            with gr.Accordion("Advanced Options", open=False):
                food_name = gr.Textbox(
                    label="Food Name",
                    placeholder="食事名を入れてください。(省略可)",  # "Enter the dish name. (Optional.)"
                    show_label=False,
                    value="",
                    visible=True,
                )
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.1, value=0.0)
                top_p = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=1.0)
                # The streamer does not support beam search (num_beams > 1) yet; hidden until it is supported.
                num_beams = gr.Slider(label="Num Beams", minimum=1, maximum=10, step=1, value=1, visible=False)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
        with gr.Column():
            output = gr.Textbox(label="Recipe Text")

    input_button.click(
        fn=inference_fn,
        inputs=[input_image, food_name, temperature, top_p, num_beams, seed, randomize_seed],
        outputs=[output, seed],
    )
    img2txt_examples = gr.Examples(
        examples=[
            ["examples/example1.jpg"],
            ["examples/example2.jpg"],
            ["examples/example3.jpg"],
            ["examples/example4.jpg"],
        ],
        fn=inference_fn,
        inputs=[input_image],
        outputs=[output, seed],
        cache_examples=False,
    )

demo.queue().launch()