import gradio as gr
import torch
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip

from run_on_video.run import MomentDETRPredictor

DESCRIPTION = """
_This Space demonstrates the model from [QVHighlights: Detecting Moments and Highlights in Videos via Natural Language Queries](https://arxiv.org/abs/2107.09609) (NeurIPS 2021), by [Jie Lei](http://www.cs.unc.edu/~jielei/), [Tamara L. Berg](http://tamaraberg.com/), and [Mohit Bansal](http://www.cs.unc.edu/~mbansal/)._
"""

ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"
clip_model_name_or_path = "ViT-B/32"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Moment-DETR checkpoint once at startup; CLIP ViT-B/32 is used
# for the video/text features.
moment_detr_predictor = MomentDETRPredictor(
    ckpt_path=ckpt_path,
    clip_model_name_or_path=clip_model_name_or_path,
    device=device,
)


def trim_video(video_path, start, end, output_file='result.mp4'):
    """Cut the [start, end] span (in seconds) out of the input video with ffmpeg."""
    ffmpeg_extract_subclip(video_path, start, end, targetname=output_file)
    return output_file


def display_prediction(result):
    """Format a single [start, end, score] prediction for display."""
    return f'Moment({result[0]} ~ {result[1]}), Score: {result[2]}'


with gr.Blocks(theme=gr.themes.Default()) as demo:
    # Per-session state: the trimmed highlight clips and the raw model predictions.
    output_videos = gr.State([])
    moment_prediction = gr.State([])
    gr.HTML("""
        <div style="text-align: center;">
            <h1>🎞️ Highlight Detection with MomentDETR</h1>
        </div>
    """)
""") gr.Markdown(DESCRIPTION) with gr.Column(): with gr.Row(): with gr.Blocks(): with gr.Column(): gr.HTML("""

Input Video

""") input_video = gr.Video(label="Please input mp4", height=400) with gr.Blocks(): with gr.Column(): gr.HTML("""

Highlight Videos

""") playable_video = gr.Video(height=400) with gr.Row(): with gr.Column(): retrieval_text = gr.Textbox( label="Query text", placeholder="What should be highlighted?", visible=True ) submit =gr.Button("Submit") with gr.Column(): display_score = gr.Markdown("### Moment Score: ") radio_button = gr.Radio( choices=[i+1 for i in range(10)], label="Moments", value=1 ) def update_video_player(radio_value, output_videos, moment_prediction): return { playable_video: output_videos[radio_value-1], display_score: display_prediction(moment_prediction[radio_value-1]) } def submit_video(input_video, retrieval_text): print(f'== video path: {input_video}') print(f'== retrieval_text: {retrieval_text}') if retrieval_text is None: retrieval_text = '' predictions, video_frames = moment_detr_predictor.localize_moment( video_path=input_video, query_list=[retrieval_text] ) predictions = predictions[0]['pred_relevant_windows'] pred_windows = [[pred[0], pred[1]]for pred in predictions] output_files = [ trim_video( video_path=input_video, start=pred_windows[i][0], end=pred_windows[i][1], output_file=f'{i}.mp4' ) for i in range(10)] return { output_videos: output_files, moment_prediction: predictions, playable_video: output_files[0], display_score: display_prediction(predictions[0]) } radio_button.change( fn=update_video_player, inputs=[radio_button, output_videos, moment_prediction], outputs=[playable_video, display_score] ) submit.click( fn=submit_video, inputs=[input_video, retrieval_text], outputs=[output_videos, moment_prediction, playable_video, display_score] ) demo.launch()