John Ho committed on
Commit
f10889a
·
1 Parent(s): f87fafd

make flash attention an input

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -24,7 +24,7 @@ subprocess.run(
24
 
25
  DTYPE = torch.bfloat16
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
- logger.info(f"Device: {device}, dtype: {dtype}")
28
 
29
 
30
  def get_fps_ffmpeg(video_path: str):
@@ -65,11 +65,13 @@ def load_model(
65
 
66
  @spaces.GPU(duration=120)
67
  def inference(
68
- video_path: str, prompt: str = "Describe the camera motion in this video."
 
 
69
  ):
70
  # default processor
71
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
72
- model = load_model(use_flash_attention=True)
73
  fps = get_fps_ffmpeg(video_path)
74
  logger.info(f"{os.path.basename(video_path)} FPS: {fps}")
75
  messages = [
@@ -122,6 +124,7 @@ demo = gr.Interface(
122
  inputs=[
123
  gr.Video(label="Input Video"),
124
  gr.Textbox(label="Prompt", value="Describe the camera motion in this video."),
 
125
  ],
126
  outputs=gr.JSON(label="Output JSON"),
127
  title="",
 
24
 
25
  DTYPE = torch.bfloat16
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+ logger.info(f"Device: {DEVICE}, dtype: {DTYPE}")
28
 
29
 
30
  def get_fps_ffmpeg(video_path: str):
 
65
 
66
  @spaces.GPU(duration=120)
67
  def inference(
68
+ video_path: str,
69
+ prompt: str = "Describe the camera motion in this video.",
70
+ use_flash_attention: bool = True,
71
  ):
72
  # default processor
73
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
74
+ model = load_model(use_flash_attention=use_flash_attention)
75
  fps = get_fps_ffmpeg(video_path)
76
  logger.info(f"{os.path.basename(video_path)} FPS: {fps}")
77
  messages = [
 
124
  inputs=[
125
  gr.Video(label="Input Video"),
126
  gr.Textbox(label="Prompt", value="Describe the camera motion in this video."),
127
+ gr.Checkbox(label="Use Flash Attention", value=True),
128
  ],
129
  outputs=gr.JSON(label="Output JSON"),
130
  title="",