cweigendev committed on
Commit
842b83a
·
verified ·
1 Parent(s): db96b99
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import shutil
import subprocess

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
5
+
# Hugging Face repository the weights are cloned from, and the local checkout
# that both the model and the processor are loaded from.
model_repo_url = "https://huggingface.co/DAMO-NLP-SG/VideoLLaMA3-7B"
model_path = "./VideoLLaMA3-7B"

# Clone the model if not already present.
if not os.path.exists(model_path):
    # Tool installation is best-effort: the host may already provide git and
    # git-lfs, or may not permit apt-get at all. Stop at the first failing
    # step (mirroring the original `a && b && c` chain) and let the clone
    # below be the step that decides success or failure.
    for setup_cmd in (
        ["apt-get", "update"],
        ["apt-get", "install", "-y", "git", "git-lfs"],
        ["git", "lfs", "install"],
    ):
        try:
            if subprocess.run(setup_cmd, check=False).returncode != 0:
                break
        except OSError:
            # The executable itself is missing or cannot be started.
            break

    # The clone is mandatory. Fail loudly here rather than later inside
    # from_pretrained, and remove any partial checkout so the next start
    # retries instead of skipping the clone because the directory exists.
    try:
        subprocess.run(["git", "clone", model_repo_url, model_path], check=True)
    except (OSError, subprocess.CalledProcessError):
        shutil.rmtree(model_path, ignore_errors=True)
        raise

# Load model and processor from the local clone.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,  # the checkpoint ships its own modeling code
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
22
+
def describe_video(video, question):
    """Answer a question about an uploaded video using the loaded model.

    Args:
        video: Path to the video file as passed in by the UI, or a falsy
            value (``None`` / empty string) when nothing was uploaded.
        question: The user's prompt. Blank or missing input falls back to
            "Describe this video in detail.".

    Returns:
        The first decoded sequence produced by ``model.generate`` with
        special tokens skipped and surrounding whitespace stripped, or a
        short instruction message when no video was provided.
    """
    # Without this guard a missing upload would reach the processor as
    # `video_path=None` and fail deep inside video loading.
    if not video:
        return "Please upload a video before asking a question."

    prompt = (question or "").strip() or "Describe this video in detail."

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": {"video_path": video, "fps": 1, "max_frames": 128},
                },
                {"type": "text", "text": prompt},
            ],
        },
    ]
    inputs = processor(conversation=conversation, return_tensors="pt")

    # Follow the model's actual placement instead of assuming the default
    # CUDA device; non-tensor entries are passed through untouched.
    target_device = model.device
    inputs = {
        key: value.to(target_device) if isinstance(value, torch.Tensor) else value
        for key, value in inputs.items()
    }
    # Pixel data is cast to bfloat16, the dtype the model is loaded with.
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

    output_ids = model.generate(**inputs, max_new_tokens=128)
    return processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
40
+
# Gradio UI: a video upload and a free-form question go in, the text returned
# by describe_video comes out.
video_input = gr.Video(label="Upload a video")
question_input = gr.Textbox(label="Question", value="Describe this video in detail.")
response_output = gr.Textbox(label="Response")

demo = gr.Interface(
    fn=describe_video,
    inputs=[video_input, question_input],
    outputs=response_output,
)

demo.launch()