# from .demo_modelpart import InferenceDemo
import gradio as gr
import os
from threading import Thread

# import time
import cv2
import datetime

# import copy
import torch
import spaces
import numpy as np

from llava import conversation as conversation_lib
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)
from serve_constants import html_header

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer, TextIteratorStreamer
import hashlib
import PIL
import base64
import json
import gradio_client
import subprocess
import sys
external_log_dir = "./logs"
LOGDIR = external_log_dir
def install_gradio_4_35_0():
    current_version = gr.__version__
    if current_version != "4.35.0":
        print(f"Current Gradio version: {current_version}")
        print("Installing Gradio 4.35.0...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "gradio==4.35.0", "--force-reinstall"]
        )
        print("Gradio 4.35.0 installed successfully.")
    else:
        print("Gradio 4.35.0 is already installed.")


# Pin Gradio to 4.35.0 if a different version is installed. Note that the
# re-import below does not reload a module already imported in this process;
# the pin fully takes effect once the Space restarts.
install_gradio_4_35_0()
import gradio as gr
import gradio_client

print(f"Gradio version: {gr.__version__}")
print(f"Gradio-client version: {gradio_client.__version__}")
def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    return name
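
# Each line appended to this file is a standalone JSON object (one record per
# completed turn, written by bot() below), so a day's log can be replayed line
# by line. A minimal reader sketch; `read_conv_log` is a hypothetical helper,
# not part of the original app:
def read_conv_log(path=None):
    path = path or get_conv_log_filename()
    entries = []
    if os.path.isfile(path):
        with open(path) as fin:
            for line in fin:
                entries.append(json.loads(line))
    return entries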
class InferenceDemo(object):
    def __init__(
        self, args, model_path, tokenizer, model, image_processor, context_len
    ) -> None:
        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = (
            tokenizer,
            model,
            image_processor,
            context_len,
        )

        # Infer the conversation template from the checkpoint name instead of
        # relying on the `model_name` global defined in __main__.
        model_name = get_model_name_from_path(model_path)
        if "llama-2" in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "v1" in model_name.lower():
            conv_mode = "llava_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        elif "qwen" in model_name.lower():
            conv_mode = "qwen_1_5"
        elif "pangea" in model_name.lower():
            conv_mode = "qwen_1_5"
        else:
            conv_mode = "llava_v0"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}; using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
            conv_mode = args.conv_mode
        else:
            args.conv_mode = conv_mode
        self.conv_mode = conv_mode
        self.conversation = conv_templates[args.conv_mode].copy()
        self.num_frames = args.num_frames
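
# A minimal usage sketch: add_message() below builds one InferenceDemo per
# fresh chat from the globals loaded in __main__, roughly:
#
#     our_chatbot = InferenceDemo(
#         args, model_path, tokenizer, model, image_processor, context_len
#     )
#     print(our_chatbot.conv_mode)  # "qwen_1_5" for Pangea checkpoints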
def is_valid_video_filename(name):
    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
    ext = name.split(".")[-1].lower()
    return ext in video_extensions


def is_valid_image_filename(name):
    image_extensions = [
        "jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp",
        "heic", "heif", "jfif", "svg", "eps", "raw",
    ]
    ext = name.split(".")[-1].lower()
    return ext in image_extensions
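
# The extension whitelists above are heuristics keyed on the file suffix. For
# reference, a standard-library alternative would look like the sketch below
# (illustrative only; mimetypes does not know some of the suffixes listed
# above, e.g. "raw"):
#
#     import mimetypes
#
#     def media_kind(name):
#         mime, _ = mimetypes.guess_type(name)
#         return mime.split("/")[0] if mime else None  # "image", "video", ...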
def sample_frames(video_file, num_frames):
    video = cv2.VideoCapture(video_file)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against clips shorter than num_frames; an interval of 0 would
    # crash the modulo below.
    interval = max(total_frames // num_frames, 1)
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            continue
        if i % interval == 0:
            # Convert BGR (OpenCV) to RGB only for frames we keep.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    video.release()
    return frames
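
# Example: sample_frames("clip.mp4", 16) reads the clip sequentially and keeps
# every (total_frames // 16)-th frame as an RGB PIL image, i.e. roughly 16
# evenly spaced frames per clip.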
def load_image(image_file):
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        if response.status_code != 200:
            raise ValueError(f"Failed to load image from {image_file}")
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        print("Load image from local file")
        print(image_file)
        image = Image.open(image_file).convert("RGB")
    return image
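
# Example usage (paths/URLs illustrative):
#
#     img = load_image("./examples/dog.jpg")             # local file
#     img = load_image("https://example.com/photo.jpg")  # remote file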
def clear_history(history):
    # Reset the conversation template; `our_chatbot` is None until the first
    # message has been sent.
    if our_chatbot is not None:
        our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
    return None


def clear_response(history):
    for index_conv in range(1, len(history)):
        # Walk backwards until we reach the most recent user text turn.
        conv = history[-index_conv]
        if conv[0] is not None:
            break
    question = history[-index_conv][0]
    history = history[:-index_conv]
    return history, question
# def print_like_dislike(x: gr.LikeData):
#     print(x.index, x.value, x.liked)


def add_message(history, message):
    global our_chatbot
    if len(history) == 0:
        our_chatbot = InferenceDemo(
            args, model_path, tokenizer, model, image_processor, context_len
        )

    for x in message["files"]:
        history.append(((x,), None))
    if message["text"] is not None:
        history.append((message["text"], None))
    return history, gr.MultimodalTextbox(value=None, interactive=False)
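
# In the tuple-format Chatbot history used here, a user turn is either a plain
# string or a one-element tuple holding a file path, e.g.:
#
#     [(("/tmp/cat.jpg",), None), ("What breed is this?", None)]
#
# bot() below relies on that shape to separate uploaded images from the
# question text.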
def bot(history):
    text = history[-1][0]
    images_this_term = []
    num_new_images = 0
    for i, message in enumerate(history[:-1]):
        if type(message[0]) is tuple:
            images_this_term.append(message[0][0])
            if is_valid_video_filename(message[0][0]):
                # Videos are not accepted.
                raise ValueError("Video is not supported")
                # num_new_images += our_chatbot.num_frames
            elif is_valid_image_filename(message[0][0]):
                print("#### Load image from local file", message[0][0])
                num_new_images += 1
            else:
                raise ValueError("Invalid image file")
        else:
            # Only images uploaded since the last text turn receive fresh
            # <image> tokens; earlier ones are already part of the
            # conversation prompt.
            num_new_images = 0

    assert len(images_this_term) > 0, "must have an image"

    # Hash every image and persist a copy under LOGDIR for the conversation log.
    all_image_hash = []
    all_image_path = []
    for image_path in images_this_term:
        with open(image_path, "rb") as image_file:
            image_data = image_file.read()
        image_hash = hashlib.md5(image_data).hexdigest()
        all_image_hash.append(image_hash)
        image = PIL.Image.open(image_path).convert("RGB")
        t = datetime.datetime.now()
        filename = os.path.join(
            LOGDIR,
            "serve_images",
            f"{t.year}-{t.month:02d}-{t.day:02d}",
            f"{image_hash}.jpg",
        )
        all_image_path.append(filename)
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            print("image saved to", filename)
            image.save(filename)

    image_list = []
    for f in images_this_term:
        if is_valid_video_filename(f):
            image_list += sample_frames(f, our_chatbot.num_frames)
        elif is_valid_image_filename(f):
            image_list.append(load_image(f))
        else:
            raise ValueError("Invalid image file")

    image_tensor = [
        our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0]
        .half()
        .to(our_chatbot.model.device)
        for f in image_list
    ]
    image_tensor = torch.stack(image_tensor)
    image_token = DEFAULT_IMAGE_TOKEN * num_new_images

    # if our_chatbot.model.config.mm_use_im_start_end:
    #     inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
    # else:
    inp = text
    inp = image_token + "\n" + inp

    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
    prompt = our_chatbot.conversation.get_prompt()

    input_ids = (
        tokenizer_image_token(
            prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        )
        .unsqueeze(0)
        .to(our_chatbot.model.device)
    )

    stop_str = (
        our_chatbot.conversation.sep
        if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
        else our_chatbot.conversation.sep2
    )
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(
        keywords, our_chatbot.tokenizer, input_ids
    )
    streamer = TextIteratorStreamer(
        our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Sanity check: the model, prompt, and images should share one device.
    print(our_chatbot.model.device)
    print(input_ids.device)
    print(image_tensor.device)

    # Non-streaming alternative, kept for reference:
    # with torch.inference_mode():
    #     output_ids = our_chatbot.model.generate(
    #         input_ids,
    #         images=image_tensor,
    #         do_sample=True,
    #         temperature=0.7,
    #         top_p=1.0,
    #         max_new_tokens=4096,
    #         streamer=streamer,
    #         use_cache=False,
    #         stopping_criteria=[stopping_criteria],
    #     )
    # outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
    # if outputs.endswith(stop_str):
    #     outputs = outputs[: -len(stop_str)]
    # our_chatbot.conversation.messages[-1][-1] = outputs
    # history[-1] = [text, outputs]
    # return history

    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        images=image_tensor,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.2,
        num_beams=1,
        use_cache=False,
        stopping_criteria=[stopping_criteria],
    )

    # Run generation on a worker thread so this function can yield partial
    # output as the streamer produces it.
    t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    # Name the chunk `new_text` so it does not shadow the user prompt `text`
    # used in the history update below.
    for new_text in streamer:
        outputs.append(new_text)
        our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
        history[-1] = [text, "".join(outputs)]
        yield history

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "type": "chat",
            "model": "Pangea-7b",
            "state": history,
            "images": all_image_hash,
            "images_path": all_image_path,
        }
        print("#### conv log", data)
        fout.write(json.dumps(data) + "\n")
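
# bot() is a generator: each yield pushes the partially decoded reply into the
# Chatbot component, producing the token-by-token streaming effect. A logged
# record looks roughly like this (values illustrative):
#
#     {"type": "chat", "model": "Pangea-7b", "state": [...],
#      "images": ["<md5>"], "images_path": ["./logs/serve_images/<date>/<md5>.jpg"]}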
# Placeholder text box (unused; the MultimodalTextbox below is the active input).
txt = gr.Textbox(
    scale=4,
    show_label=False,
    placeholder="Enter text and press enter.",
    container=False,
)
with gr.Blocks(
    css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}",
) as demo:
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    # gr.Markdown(title_markdown)
    gr.HTML(html_header)

    with gr.Column():
        with gr.Row():
            chatbot = gr.Chatbot([], elem_id="Pangea", bubble_full_width=False, height=750)

        with gr.Row():
            upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
            downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
            flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
            # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
            regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
        submit_btn="🚀",
    )
    print(cur_dir)
    gr.Examples(
        examples_per_page=20,
        examples=[
            [
                {
                    "files": [f"{cur_dir}/examples/user_example_07.jpg"],
                    "text": "那要我问问你,你这个是什么🐱?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/user_example_05.jpg"],
                    "text": "この猫の目の大きさは、どのような理由で他の猫と比べて特に大きく見えますか?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/172197131626056_P7966202.png"],
                    "text": "Why is this image funny?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/norway.jpg"],
                    "text": "Analysieren, in welchem Land diese Szene höchstwahrscheinlich gedreht wurde.",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/totoro.jpg"],
                    "text": "¿En qué anime aparece esta escena? ¿Puedes presentarlo?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/africa.jpg"],
                    "text": "इस तस्वीर में हर एक दृश्य तत्व का क्या प्रतिनिधित्व करता है?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/hot_ballon.jpg"],
                    "text": "ฉากบอลลูนลมร้อนในภาพนี้อาจอยู่ที่ไหน? สถานที่นี้มีความพิเศษอย่างไร?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/bar.jpg"],
                    "text": "Você pode me dar ideias de design baseadas no tema de coquetéis deste letreiro?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/pink_lake.jpg"],
                    "text": "Обясни защо езерото на този остров е в този цвят.",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/hanzi.jpg"],
                    "text": "Can you describe in Hebrew the evolution process of these four Chinese characters from pictographs to modern characters?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/ballon.jpg"],
                    "text": "இந்த காட்சியை விவரிக்கவும், மேலும் இந்த படத்தின் அடிப்படையில் துருக்கியில் இந்த காட்சியுடன் தொடர்பான சில பிரபலமான நிகழ்வுகள் என்ன?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/pie.jpg"],
                    "text": "Décrivez ce graphique. Quelles informations pouvons-nous en tirer?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/camera.jpg"],
                    "text": "Apa arti dari dua angka di sebelah kiri yang ditampilkan di layar kamera?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/dog.jpg"],
                    "text": "이 강아지의 표정을 보고 어떤 기분이나 감정을 느끼고 있는지 설명해 주시겠어요?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/book.jpg"],
                    "text": "What language is the text in, and what does the title mean in English?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/food.jpg"],
                    "text": "Unaweza kunipa kichocheo cha kutengeneza hii pancake?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/line chart.jpg"],
                    "text": "Hãy trình bày những xu hướng mà bạn quan sát được từ biểu đồ và hiện tượng xã hội tiềm ẩn từ đó.",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/south africa.jpg"],
                    "text": "Waar is hierdie plek? Help my om ’n reisroete vir hierdie land te beplan.",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/girl.jpg"],
                    "text": "لماذا هذه الصورة مضحكة؟",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/eagles.jpg"],
                    "text": "Какой креатив должен быть в этом логотипе?",
                },
            ],
        ],
        inputs=[chat_input],
        label="Image",
    )
    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    )
    bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

    # chatbot.like(print_like_dislike, None, None)
    clear_btn.click(
        fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
    )

demo.queue()
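
# Event chain: submitting the MultimodalTextbox appends the message and
# disables the input (add_message), then streams the model reply (bot), then
# re-enables the input. queue() is needed so generator callbacks like bot()
# can stream their partial results to the browser.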
if __name__ == "__main__":
    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument("--server_name", default="0.0.0.0", type=str)
    argparser.add_argument("--port", default=6123, type=int)
    argparser.add_argument("--model_path", default="neulab/Pangea-7B", type=str)
    # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    argparser.add_argument("--model-base", type=str, default=None)
    argparser.add_argument("--num-gpus", type=int, default=1)
    argparser.add_argument("--conv-mode", type=str, default=None)
    argparser.add_argument("--temperature", type=float, default=0.7)
    argparser.add_argument("--max-new-tokens", type=int, default=4096)
    argparser.add_argument("--num_frames", type=int, default=16)
    argparser.add_argument("--load-8bit", action="store_true")
    argparser.add_argument("--load-4bit", action="store_true")
    argparser.add_argument("--debug", action="store_true")
    args = argparser.parse_args()

    model_path = args.model_path
    filt_invalid = "cut"
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
    )
    model = model.to(torch.device("cuda"))
    our_chatbot = None
    # Honor the --server_name/--port CLI flags.
    demo.launch(server_name=args.server_name, server_port=args.port)