import os
import csv
import json
import re
import argparse
from collections import defaultdict

import torch
import torch.nn as nn
import pandas as pd
import gradio as gr
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from peft import LoraConfig, get_peft_model
from huggingface_hub import hf_hub_download

from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
from data_utils.xgpt3_dataset import MultiModalDataset
from utils import batchify
from entailment_inference import get_scores
from nle_inference import VideoCaptionDataset, get_nle
def modify_keys(state_dict):
    # Remap checkpoint keys to the naming used by PEFT's LoRA wrappers: for
    # every LoRA-target projection weight, insert 'base_layer' before the
    # final '.weight'. Non-matching keys are copied through unchanged so the
    # strict load below sees a complete state dict.
    new_state_dict = defaultdict()
    pattern = re.compile(r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)\.weight')
    for key, value in state_dict.items():
        if pattern.match(key):
            key = key.split('.')
            key.insert(-1, 'base_layer')
            key = '.'.join(key)
        new_state_dict[key] = value
    return new_state_dict
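# Illustrative before/after of the remapping above (hypothetical key, shown
# only to document the transformation):
#   before: base_model.model.language_model.model.layers.0.self_attn.q_proj.weight
#   after:  base_model.model.language_model.model.layers.0.self_attn.q_proj.base_layer.weight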
# Pretrained mPLUG-Owl video backbone and the fine-tuned Owl-Con checkpoint.
pretrained_ckpt = "MAGAer13/mplug-owl-llama-7b-video"
trained_ckpt = hf_hub_download(repo_id="videocon/owl-con", filename="pytorch_model.bin", repo_type="model")

tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)
# Instantiate the base model on CPU in bfloat16; it is moved to the GPU only
# after the LoRA adapters and fine-tuned weights have been loaded.
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
    device_map={'': 'cpu'}
)
# Wrap the language model's projection layers with LoRA adapters; the
# target_modules regex matches the same layers that modify_keys remaps.
peft_config = LoraConfig(
    target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
    inference_mode=True,
    r=32,
    lora_alpha=16,
    lora_dropout=0.05
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# Load the fine-tuned Owl-Con weights into the PEFT-wrapped model.
with open(trained_ckpt, 'rb') as f:
    ckpt = torch.load(f, map_location=torch.device("cpu"))
ckpt = modify_keys(ckpt)
model.load_state_dict(ckpt)
model = model.to("cuda:0").to(torch.bfloat16)
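# `load_state_dict` is strict by default, so it doubles as a sanity check: it
# raises if `modify_keys` produced a key the wrapped model does not expect.
# To inspect mismatches without raising, a non-strict load could be used
# instead (a sketch):
#   incompatible = model.load_state_dict(ckpt, strict=False)
#   print(incompatible.missing_keys, incompatible.unexpected_keys)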
def inference(videopath, text):
    PROMPT = """The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: Does this video entail the description: "{caption}"?
AI: """
    # Score how strongly the video entails the caption.
    valid_data = MultiModalDataset(videopath, PROMPT.format(caption=text), tokenizer, processor, max_length=256, loss_objective='sequential')
    dataloader = DataLoader(valid_data, pin_memory=True, collate_fn=batchify)
    score = get_scores(model, tokenizer, dataloader)
    # Generate a natural language explanation only for likely contradictions.
    if score < 0.5:
        dataset = VideoCaptionDataset(videopath, text)
        dataloader = DataLoader(dataset)
        nle = get_nle(model, processor, tokenizer, dataloader)
    else:
        nle = "None (NLE is only triggered when entailment score < 0.5)"
    return score, nle
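# Direct usage outside the Gradio UI (illustrative; assumes the bundled
# example videos are on disk):
#   score, nle = inference("examples/820.mp4", "We see the group making cookies.")
#   print(score, nle)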
demo = gr.Interface(
    inference,
    title="Owl-Con Demo",
    description="Owl-Con Demo (Code: https://github.com/Hritikbansal/videocon | Paper: https://arxiv.org/abs/2311.10111)",
    inputs=[gr.Video(label='input_video'), gr.Textbox(label='input_caption')],
    outputs=[gr.Number(label='Entailment Score'), gr.Textbox(label='Natural Language Explanation')],
    examples=[
        ["examples/820.mp4", "We see the group making cookies."],
        ["examples/820.mp4", "We see the group eating cookies."],
        ["examples/244.mp4", "She throws a bowling ball while talking on the phone."],
        ["examples/244.mp4", "She throws a baseball while talking on the phone."],
    ],
)
if __name__ == "__main__":
    demo.launch()
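# On Hugging Face Spaces the app is served automatically; when running
# locally, a shareable public URL can be requested with
# `demo.launch(share=True)` (standard Gradio option).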