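# Gradio demo: localize and play back the top video moments for a natural-language
# query, comparing Moment-DETR predictions against the lbhd highlight predictor ("Ours").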
import gradio as gr
from run_on_video.run import MomentDETRPredictor
import torch
from lbhd.infer import lbhd_predict
import os
import subprocess
from utils.export_utils import trim_video
DESCRIPTION = """
_This Space demonstrates the model from [QVHighlights: Detecting Moments and Highlights in Videos via Natural Language Queries](https://arxiv.org/abs/2107.09609) (NeurIPS 2021) by [Jie Lei](http://www.cs.unc.edu/~jielei/), [Tamara L. Berg](http://tamaraberg.com/), and [Mohit Bansal](http://www.cs.unc.edu/~mbansal/)._
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"
clip_model_name_or_path = "ViT-B/32"
moment_detr_predictor = MomentDETRPredictor(
ckpt_path=ckpt_path,
clip_model_name_or_path=clip_model_name_or_path,
device=device
)
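# Format one (start, end, score) prediction window as a Markdown line.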
def display_prediction(result):
return f'### Start time: {result[0]:.2f}, End time: {result[1]:.2f}, Score: {result[2]:.2f}'
with gr.Blocks(theme=gr.themes.Default()) as demo:
output_videos = gr.State(None)
output_lbhd_videos = gr.State(None)
moment_prediction = gr.State(None)
our_prediction = gr.State(None)
gr.HTML("""<h2 align="center"> ๐ŸŽž๏ธ Highlight Detection with MomentDETR </h2>""")
gr.Markdown(DESCRIPTION)
with gr.Column():
with gr.Row():
with gr.Blocks():
with gr.Column():
gr.HTML("""<h3 align="center"> Input Video </h3>""")
                    input_video = gr.Video(label="Upload an mp4 (or mov) video", height=400)
with gr.Blocks():
with gr.Column():
gr.HTML("""<h3 align="center"> MomentDETR Result </h3>""")
playable_video = gr.Video(height=400)
display_score = gr.Markdown("### Start time, End time, Score")
with gr.Blocks():
with gr.Column():
gr.HTML("""<h3 align="center"> Ours Result </h3>""")
our_result_video = gr.Video(height=400)
display_clip_score = gr.Markdown("### Start time, End time, Score")
with gr.Row():
with gr.Column():
retrieval_text = gr.Textbox(
label="Query text",
placeholder="What should be highlighted?",
visible=True
)
submit = gr.Button("Submit")
with gr.Column():
radio_button = gr.Radio(
choices=[i+1 for i in range(10)],
label="Top 10",
value=1
)
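    # Update both result players and their score readouts when the user picks a
    # different rank in the radio group.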
    def update_video_player(radio_value, output_videos, output_lbhd_videos, moment_prediction, our_prediction):
        if output_videos is None or moment_prediction is None:
            return [None, None, None, None]
        # Clamp the selected rank so shorter prediction lists never raise IndexError.
        moment_idx = min(radio_value - 1, len(output_videos) - 1)
        lbhd_idx = min(radio_value - 1, len(output_lbhd_videos) - 1)
        return {
            playable_video: output_videos[moment_idx],
            our_result_video: output_lbhd_videos[lbhd_idx],
            display_score: display_prediction(moment_prediction[moment_idx]),
            display_clip_score: display_prediction(our_prediction[lbhd_idx])
        }
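    # Full pipeline for one submission: transcode .mov input if needed, localize
    # moments for the query with Moment-DETR, trim the top windows into standalone
    # clips, and do the same for the lbhd predictions.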
    def submit_video(input_video, retrieval_text):
        if input_video is None:
            return [None, None, None, None, None, None, None, None, None, 1]
        if retrieval_text is None:
            retrieval_text = ''
        ext = os.path.splitext(input_video)[-1].lower()
        if ext == ".mov":
            # Transcode .mov uploads to a 320px-wide mp4; -y lets ffmpeg overwrite
            # an existing output from a previous run.
            output_file = os.path.splitext(input_video)[0] + ".mp4"
            subprocess.call(['ffmpeg', '-y', '-i', input_video, '-vf', 'scale=320:-2', output_file])
        print(f'== video path: {input_video}')
        print(f'== retrieval_text: {retrieval_text}')
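        # Moment-DETR returns, per query, a ranked list of (start, end, score) windows.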
predictions, video_frames = moment_detr_predictor.localize_moment(
video_path=input_video,
query_list=[retrieval_text]
)
predictions = predictions[0]['pred_relevant_windows']
print(f'== Moment prediction: {predictions}')
        # Cut each predicted window into its own playable clip; guard against the
        # model returning fewer than 10 windows.
        num_moments = min(10, len(predictions))
        output_files = [trim_video(
            video_path=output_file if ext == ".mov" else input_video,
            start=predictions[i][0],
            end=predictions[i][1],
            output_file=f'{i}.mp4'
        ) for i in range(num_moments)]
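        # Run the second ("Ours"/lbhd) highlight predictor on the same video.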
lbhd_predictions = lbhd_predict(input_video)
print(f'== lbhd_predictions: {lbhd_predictions}')
        output_files_lbhd = [trim_video(
            video_path=output_file if ext == ".mov" else input_video,
            start=lbhd_predictions[i][0],
            end=lbhd_predictions[i][1],
            output_file=f'{i}_lbhd.mp4'
        ) for i in range(min(10, len(lbhd_predictions)))]
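        # Return values must line up with the `outputs` list wired to submit.click below;
        # the trailing 1 resets the rank selector.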
return [
output_file if ext == ".mov" else input_video,
output_files,
output_files_lbhd,
predictions,
lbhd_predictions,
output_files[0],
output_files_lbhd[0],
display_prediction(predictions[0]),
display_prediction(lbhd_predictions[0]),
1
]
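    # Swap the displayed clips and scores whenever a different rank is selected.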
radio_button.change(
fn=update_video_player,
inputs=[radio_button, output_videos, output_lbhd_videos, moment_prediction, our_prediction],
outputs=[playable_video, our_result_video, display_score, display_clip_score]
)
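    # Run the full pipeline on submit.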
submit.click(
fn=submit_video,
inputs=[input_video, retrieval_text],
outputs=[input_video, output_videos, output_lbhd_videos, moment_prediction, our_prediction, playable_video, our_result_video, display_score, display_clip_score, radio_button]
)
demo.launch()