# !pip install --ignore-installed flask flask-ngrok
# !wget https://bin.equinox.io/c/bNyj1mQVY4c/ngrok-v3-stable-linux-amd64.tgz
# !apt update && apt upgrade -y
# !apt-get install p7zip-full -y
# !tar -xvzf ngrok-v3-stable-linux-amd64.tgz
# !./ngrok authtoken YOUR_NGROK_TOKEN
# !pip install -r requirements.txt
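# To expose the Streamlit app through ngrok (assumes this file is saved as
# app.py -- a hypothetical name -- and Streamlit's default port 8501):
# !streamlit run app.py &>/dev/null &
# !./ngrok http 8501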

import streamlit as st
import tempfile
import os
import time
import re
import numpy as np
import torch
from PIL import Image
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Set page configuration
st.set_page_config(page_title="Omni DeepSeek Video Analysis", layout="wide")

# Constants
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Add CSS for text wrapping and vertical scrollbar for the expander
st.markdown("""
<style>
.output-text {
    white-space: pre-wrap !important;
    word-wrap: break-word !important;
}
.streamlit-expanderContent {
    white-space: pre-wrap !important;
    word-wrap: break-word !important;
    max-height: 100px;  /* adjust height as needed */
    overflow-y: auto;   /* add a vertical scrollbar */
}
</style>
""", unsafe_allow_html=True)

# Model loading utilities
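# st.cache_resource keeps a single model/tokenizer instance alive across
# Streamlit reruns, so the weights below are loaded only once per process.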
@st.cache_resource
def load_model_and_tokenizer():
    """Load and cache the model and tokenizer"""
    path = 'AlphaTok/omni-deepseek-v0'
    
    with st.spinner("Loading model (this may take a minute)..."):
        model = AutoModel.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            use_flash_attn=True,
            trust_remote_code=True
        ).eval()
        
        # Move to GPU if available
        if torch.cuda.is_available():
            model = model.cuda()
            st.success("Model loaded on GPU")
        else:
            st.warning("GPU not available, running on CPU (inference will be slow)")
            
        tokenizer = AutoTokenizer.from_pretrained(
            path, 
            trust_remote_code=True, 
            use_fast=False
        )
        
    return model, tokenizer

# Video processing functions
def build_transform(input_size):
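    """Build a torchvision pipeline: RGB conversion, bicubic resize to input_size, tensor conversion, and ImageNet normalization."""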
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
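    """Split an image into up to max_num tiles of image_size x image_size, using the grid that best matches its aspect ratio; optionally append a global thumbnail."""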
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Calculate the target aspect ratio
    def find_closest_aspect_ratio(aspect_ratio, target_ratios):
        best_ratio_diff = float('inf')
        best_ratio = (1, 1)
        for ratio in target_ratios:
            target_aspect_ratio = ratio[0] / ratio[1]
            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
            if ratio_diff < best_ratio_diff:
                best_ratio_diff = ratio_diff
                best_ratio = ratio
        return best_ratio

    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios)
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
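# Worked example (hand-checked against the logic above): a 1280x720 frame has
# aspect ratio ~1.78; with max_num=12 the closest grid is (2, 1), so the frame
# is resized to 896x448 and split into two 448x448 tiles, plus a thumbnail when
# use_thumbnail=True. With max_num=1, as used in load_video below, every frame
# yields exactly one 448x448 tile and no thumbnail.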

def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
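    """Sample num_segments frame indices, one from the midpoint of each equal segment of [start, end] seconds (or of the whole clip if bound is None)."""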
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array([
        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
        for idx in range(num_segments)
    ])
    return frame_indices
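# Worked example: with no bound, max_frame=300 and num_segments=4 give
# seg_size=75.0 and indices [37, 112, 187, 262] -- the midpoint frame of
# each quarter of the clip.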

def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
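    """Decode evenly sampled frames, tile each with dynamic_preprocess, and return the stacked pixel tensor plus per-frame tile counts."""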
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    max_frame = len(vr) - 1
    fps = float(vr.get_avg_fps())

    pixel_values_list, num_patches_list = [], []
    transform = build_transform(input_size=input_size)
    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
    for frame_index in frame_indices:
        img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
        img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
        pixel_values = [transform(tile) for tile in img]
        pixel_values = torch.stack(pixel_values)
        num_patches_list.append(pixel_values.shape[0])
        pixel_values_list.append(pixel_values)
    pixel_values = torch.cat(pixel_values_list)
    return pixel_values, num_patches_list
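# With input_size=448 and max_num=1 each frame contributes one tile, so
# pixel_values has shape (num_segments, 3, 448, 448) and num_patches_list
# is [1] * num_segments.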

# Save uploaded file to a temporary location
def save_uploaded_file(uploaded_file):
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp:
        tmp.write(uploaded_file.getvalue())
        return tmp.name

def process_video_and_run_inference(video_path, prompt, model, tokenizer):
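    """Run streaming inference on a video, yielding {"type": "think"|"regular", "content": ...} dicts that separate <think>...</think> spans from the final answer."""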
    # Load and preprocess the video into frame tiles
    with st.spinner("Processing video..."):
        pixel_values, num_patches_list = load_video(
            video_path, 
            num_segments=16, 
            max_num=1
        )
        if torch.cuda.is_available():
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
        else:
            pixel_values = pixel_values.to(torch.bfloat16)
    
    # Initialize the streamer for incremental text generation
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
    generation_config = dict(max_new_tokens=1024, do_sample=False, streamer=streamer)
    
    # Run model.chat in a background thread so we can consume the stream here
    thread = Thread(
        target=model.chat, 
        kwargs=dict(
            tokenizer=tokenizer,
            pixel_values=pixel_values,
            question=prompt,
            history=None,
            return_history=False,
            generation_config=generation_config,
        )
    )
    thread.start()
    
    # Buffer that accumulates the raw model output for logging
    raw_output = ""
    
    # State variables for splitting the stream into think and regular parts
    think_mode = False
    think_content = ""
    regular_content = ""
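    # Example of the splitting below: the chunk "Hmm<think>user is coding</think>Done"
    # routes "Hmm" and "Done" to the regular channel and "user is coding" to the
    # think channel; each yield carries the accumulated content for its channel.
    # Note that a tag split across two chunks (e.g. "<thi" then "nk>") is not
    # detected by this simple scanner.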
    
    # Process each text chunk received from the streamer
    for new_text in streamer:
        # Append the raw chunk to raw_output
        raw_output += new_text
        
        pos = 0
        while pos < len(new_text):
            idx_think = new_text.find("<think>", pos)
            idx_think_close = new_text.find("</think>", pos)
            # No tags in this chunk: append the remainder to the current mode and stop scanning
            if idx_think == -1 and idx_think_close == -1:
                if think_mode:
                    think_content += new_text[pos:]
                    yield {"type": "think", "content": think_content}
                else:
                    regular_content += new_text[pos:]
                    yield {"type": "regular", "content": regular_content}
                break
            # <think> appears first (or </think> is absent)
            if idx_think != -1 and (idx_think_close == -1 or idx_think < idx_think_close):
                # Flush the text that precedes the tag
                if think_mode:
                    think_content += new_text[pos:idx_think]
                    yield {"type": "think", "content": think_content}
                else:
                    regular_content += new_text[pos:idx_think]
                    yield {"type": "regular", "content": regular_content}
                pos = idx_think + len("<think>")
                think_mode = True
            else:
                # A closing </think> tag appears first
                if think_mode:
                    think_content += new_text[pos:idx_think_close]
                    yield {"type": "think", "content": think_content}
                    think_content = ""  # 清空 think 内容缓存
                else:
                    regular_content += new_text[pos:idx_think_close]
                    yield {"type": "regular", "content": regular_content}
                pos = idx_think_close + len("</think>")
                think_mode = False

    thread.join()  # make sure the generation thread has finished

    # Print the complete raw model output to the terminal
    print("Complete raw model output:")
    print(raw_output)

# Main app function
def main():
    st.title("Video Analysis with Omni DeepSeek")
    st.markdown("Upload a video and provide a prompt to analyze it.")
    
    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer()
    
    # Sidebar for inputs
    with st.sidebar:
        st.header("Upload and Settings")
        video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])
        # Prompt template selector: the dropdown offers a default template and the Omni-Matrix template
        template_option = st.selectbox("Select Prompt Template", options=["Default", "Omni-Matrix Template"])
        if template_option == "Default":
            prompt = st.text_area("Enter your prompt", value="Please describe this video", height=100)
        else:
            prompt = st.text_area("Enter your prompt", value=f"""
Analyze the video and determine whether the user requires assistance based on the video activity type and behavior. Generate the output in the following structured JSON format:

1. **help_needed**: A boolean value (true or false) indicating whether the user needs help based on the video content.
2. **video_description**: A brief description of the video content.
3. **video_type**: The type of activity in the video. Options include working, meeting, coding, gaming, watching, or other.
4. **function_call_name**: If help_needed is true, specify the name of the function to provide assistance. Options include draft_copy (drafting a copy), assist_coding (coding assistance), web_search (web search). If no help is needed, return an empty string.
5. **function_call_parameters**: If help is needed, provide the required parameters for the function call; otherwise, return an empty array. The parameters are defined as follows:
    - **draft_copy**: Two strings - the first one is the copy subject and the second one is the copy content.
        -- copy_subject(str): The subject of the copy
        -- copy_content(str): The content of the copy
    - **web_search**: 
        -- web_search_content(str): A single string containing the search query.
    - **assist_coding**:
        -- coding_subject(str): The subject of the code
        -- coding_content(str): The content of the code

**Input Requirements:**
The input is a video; analyze it to determine the user's behavior and generate a JSON response in the following format:

json
{{
    "help_needed": true/false,
    "video_description": "Brief description of the video content",
    "video_type": "working"/"meeting"/"coding"/"gaming"/"watching"/"other",
    "function_call_name": "draft_email/assist_coding/web_search",
    "function_call_parameters": {{
        "parameter1":"parameter1 content", 
        "parameter2":"parameter2 content"
    }}
}}

**Examples:**
1. If the video shows the user debugging code and repeatedly checking documentation:
json
{{
    "help_needed": true,
    "video_description": "The user is debugging code and may need assistance.",
    "video_type": "coding",
    "function_call_name": "assist_coding",
    "function_call_parameters": {{
        "coding_subject": "Help the user implement quicksort.",
        "coding_content": "
        def quicksort(arr):
            if len(arr) <= 1:
                return arr

            pivot = arr[len(arr) // 2]
            left = [x for x in arr if x < pivot]
            middle = [x for x in arr if x == pivot]
            right = [x for x in arr if x > pivot]
            return quicksort(left) + middle + quicksort(right)
            "
    }}
}}

2. If the video shows the user watching a movie and no assistance is required:
json
{{
    "help_needed": false,
    "video_description": "The user is watching a movie.",
    "video_type": "watching",
    "function_call_name": "",
    "function_call_parameters": []
}}

3. If the video shows the user writing an email and might need assistance drafting it:
json
{{
    "help_needed": true,
    "video_description": "The user is writing an email and may need assistance.",
    "video_type": "working",
    "function_call_name": "draft_copy",
    "function_call_parameters": {{
        "copy_subject": "Follow-up Meeting", 
        "copy_content": "Please confirm your availability for the next meeting."
    }}
}}

4. If the video shows the user searching for a specific topic online:
json
{{
    "help_needed": true,
    "video_description": "The user is searching for information online.",
    "video_type": "working",
    "function_call_name": "web_search",
    "function_call_parameters": {{
        "web_search_content": "latest AI research papers"
    }}
}}
""", height=400)
        run_button = st.button("Analyze Video", type="primary")
        
        st.markdown("---")
        st.markdown("### Model Information")
        st.info("Using AlphaTok/omni-deepseek-v0 model")
    
    # Main content area with two columns
    col1, col2 = st.columns([1, 1])
    
    with col1:
        st.header("Input")
        if video_file:
            st.video(video_file)
            st.text(f"Prompt: {prompt}")
    
    with col2:
        st.header("Output")
        # Expand the "Thinking Process" box by default
        thinking_container = st.expander("Thinking Process", expanded=True)
        output_container = st.container()
    
    if run_button and video_file and prompt:
        # Save the uploaded video
        video_path = save_uploaded_file(video_file)
        
        # Create a progress bar
        progress_bar = st.progress(0.0)
        
        # Placeholders for streaming output
        thinking_placeholder = thinking_container.empty()
        output_placeholder = output_container.empty()
        
        try:
            progress_step = 0
            # Cap the progress bar at 90% while streaming; it reaches 100% only after generation ends
            for result in process_video_and_run_inference(video_path, prompt, model, tokenizer):
                progress_step += 1
                progress_bar.progress(min(0.9, progress_step / 1024))
                if result["type"] == "think":
                    thinking_placeholder.markdown(f"""<div class="output-text">{result['content']}</div>""", unsafe_allow_html=True)
                elif result["type"] == "regular":
                    content = result["content"]
                    if re.search(r'```\s*json\s*\{', content):
                        json_content = re.search(r'```\s*json\s*(\{.*?\})\s*```', content, re.DOTALL)
                        if json_content:
                            output_placeholder.json(json_content.group(1))
                        else:
                            output_placeholder.markdown(f"""<div class="output-text">{content}</div>""", unsafe_allow_html=True)
                    else:
                        output_placeholder.markdown(f"""<div class="output-text">{content}</div>""", unsafe_allow_html=True)
            
            # 模型生成结束后完成进度条更新
            progress_bar.progress(1.0)
            time.sleep(0.5)
            progress_bar.empty()
            os.unlink(video_path)
            
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            if os.path.exists(video_path):
                os.unlink(video_path)

if __name__ == "__main__":
    main()