John Ho committed on
Commit d9d1598 · 1 Parent(s): c7e712e

trying model quantization

Files changed (2)
  1. app.py +24 -4
  2. pyproject.toml +1 -0
app.py CHANGED
@@ -1,6 +1,11 @@
+from statistics import quantiles
 import spaces, ffmpeg, os, sys, torch
 import gradio as gr
-from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from transformers import (
+    Qwen2_5_VLForConditionalGeneration,
+    AutoProcessor,
+    BitsAndBytesConfig,
+)
 from qwen_vl_utils import process_vision_info
 from loguru import logger
 
@@ -50,8 +55,17 @@ def get_fps_ffmpeg(video_path: str):
 def load_model(
     model_name: str = "chancharikm/qwen2.5-vl-7b-cam-motion-preview",
     use_flash_attention: bool = True,
+    apply_quantization: bool = True,
 ):
-    # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+    # We recommend enabling flash_attention_2 for better acceleration and memory saving,
+    # especially in multi-image and video scenarios.
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,  # Load model weights in 4-bit
+        bnb_4bit_quant_type="nf4",  # Use NF4 quantization (or "fp4")
+        bnb_4bit_compute_dtype=DTYPE,  # Perform computations in bfloat16/float16
+        bnb_4bit_use_double_quant=True,  # Optional: further quantization for slightly more memory saving
+    )
+
     model = (
         Qwen2_5_VLForConditionalGeneration.from_pretrained(
             model_name,
@@ -59,6 +73,7 @@ def load_model(
             attn_implementation="flash_attention_2",
             device_map=DEVICE,
             low_cpu_mem_usage=True,
+            quantization_config=bnb_config if apply_quantization else None,
         )
         if use_flash_attention
         else Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -66,6 +81,7 @@ def load_model(
             torch_dtype=DTYPE,
             device_map=DEVICE,
             low_cpu_mem_usage=True,
+            quantization_config=bnb_config if apply_quantization else None,
         )
     )
     # Set model to evaluation mode for inference (disables dropout, etc.)
@@ -87,10 +103,13 @@ def inference(
     video_path: str,
     prompt: str = "Describe the camera motion in this video.",
     use_flash_attention: bool = True,
+    apply_quantization: bool = True,
 ):
     # default processor
     processor = load_processor()
-    model = load_model(use_flash_attention=use_flash_attention)
+    model = load_model(
+        use_flash_attention=use_flash_attention, apply_quantization=apply_quantization
+    )
 
     # The model is trained on 8.0 FPS which we recommend for optimal inference
     fps = get_fps_ffmpeg(video_path)
@@ -149,7 +168,8 @@ demo = gr.Interface(
     inputs=[
         gr.Video(label="Input Video"),
         gr.Textbox(label="Prompt", value="Describe the camera motion in this video."),
-        gr.Checkbox(label="Use Flash Attention", value=True),
+        gr.Checkbox(label="Use Flash Attention", value=False),
+        gr.Checkbox(label="Apply Quantization", value=True),
     ],
     outputs=gr.JSON(label="Output JSON"),
     title="",
pyproject.toml CHANGED
@@ -13,4 +13,5 @@ dependencies = [
     "torchvision==0.19.0",
     "ffmpeg-python>=0.2.0",
     "accelerate==0.32.1",
+    "bitsandbytes==0.41.1",
 ]
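
bitsandbytes at this version supports CUDA GPUs only, so the 4-bit path assumes the Space runs on GPU hardware. A quick local sanity check (a sketch, not part of the commit):

import bitsandbytes, torch

print(bitsandbytes.__version__)   # expect 0.41.1, matching the pin above
print(torch.cuda.is_available())  # must be True for the 4-bit load to work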