import torch
from run_on_video.data_utils import ClipFeatureExtractor
from run_on_video.model_utils import build_inference_model
from utils.tensor_utils import pad_sequences_1d
from moment_detr.span_utils import span_cxw_to_xx
from utils.basic_utils import l2_normalize_np_array
import torch.nn.functional as F
import numpy as np
import os
from PIL import Image

from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.video.io.VideoFileClip import VideoFileClip


class MomentDETRPredictor:
    """Wrap a CLIP feature extractor and a trained Moment-DETR model for inference.

    Videos are encoded as one CLIP feature per ``clip_len``-second clip, and each
    text query is scored against the whole video.
    """

    def __init__(self, ckpt_path, clip_model_name_or_path="ViT-B/32", device="cuda"):
        """
        Args:
            ckpt_path: str, path to the trained Moment-DETR checkpoint.
            clip_model_name_or_path: str, CLIP model name or local weights path.
            device: str, torch device to run both models on.
        """
        self.clip_len = 2  # seconds
        self.device = device
        print("Loading feature extractors...")
        self.feature_extractor = ClipFeatureExtractor(
            framerate=1 / self.clip_len, size=224, centercrop=True,
            model_name_or_path=clip_model_name_or_path, device=device
        )
        print("Loading trained Moment-DETR model...")
        self.model = build_inference_model(ckpt_path).to(self.device)
        self.model.eval()

    @torch.no_grad()
    def localize_moment(self, video_path, query_list):
        """Predict relevant moments and per-clip saliency scores for each query.

        Args:
            video_path: str, path to the video file.
            query_list: List[str], each str is a query for this video.

        Returns:
            A tuple ``(predictions, video_frames)``.

            ``predictions`` holds one dict per query, with these keys:

            - ``query``: the query text.
            - ``vid``: ``video_path``.
            - ``pred_relevant_windows``: a list of ``[start_sec, end_sec, score]``,
              sorted by score in descending order.
            - ``pred_saliency_scores``: one score per ``clip_len``-second clip.

            ``video_frames`` is the second value returned by the feature
            extractor's ``encode_video``.

        Raises:
            AssertionError: if the video has more than 75 clips, the limit of the
                pretrained positional embedding.
        """
        n_query = len(query_list)
        video_feats, video_frames = self.feature_extractor.encode_video(video_path)
        video_feats = F.normalize(video_feats, dim=-1, eps=1e-5)
        n_frames = len(video_feats)
        # Fail fast, before building any further model inputs.
        assert n_frames <= 75, "The positional embedding of this pretrained MomentDETR only support video up " \
                               "to 150 secs (i.e., 75 2-sec clips) in length"

        # Add temporal endpoint features (TEF): normalized [start, end] of each clip.
        tef_st = torch.arange(0, n_frames, 1.0) / n_frames
        tef_ed = tef_st + 1.0 / n_frames
        tef = torch.stack([tef_st, tef_ed], dim=1).to(self.device)  # (n_frames, 2)
        video_feats = torch.cat([video_feats, tef], dim=1)

        # The same video is paired with every query in the batch.
        video_feats = video_feats.unsqueeze(0).repeat(n_query, 1, 1)  # (#text, T, d)
        video_mask = torch.ones(n_query, n_frames).to(self.device)
        query_feats = self.feature_extractor.encode_text(query_list)  # #text * (L, d)
        query_feats, query_mask = pad_sequences_1d(
            query_feats, dtype=torch.float32, device=self.device, fixed_length=None)
        query_feats = F.normalize(query_feats, dim=-1, eps=1e-5)
        model_inputs = dict(
            src_vid=video_feats,
            src_vid_mask=video_mask,
            src_txt=query_feats,
            src_txt_mask=query_mask
        )

        # Decode outputs. "#moment_queries" refers to the positional embeddings in
        # MomentDETR's decoder, not to the input text queries.
        outputs = self.model(**model_inputs)
        prob = F.softmax(outputs["pred_logits"], -1)  # (batch_size, #moment_queries=10, #classes=2)
        scores = prob[..., 0]  # (batch_size, #moment_queries); foreground label is 0
        pred_spans = outputs["pred_spans"]  # (bsz, #moment_queries, 2), (center, width) format
        _saliency_scores = outputs["saliency_scores"].half()  # (bsz, L)
        saliency_scores = []
        valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
        for j, valid_len in enumerate(valid_vid_lengths):
            _score = _saliency_scores[j, :int(valid_len)].tolist()
            saliency_scores.append([round(e, 4) for e in _score])

        # Compose predictions.
        predictions = []
        video_duration = n_frames * self.clip_len
        for idx, (spans, score) in enumerate(zip(pred_spans.cpu(), scores.cpu())):
            spans = span_cxw_to_xx(spans) * video_duration
            # (#moment_queries, 3), rows are [st(float), ed(float), score(float)]
            cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
            cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
            cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
            cur_query_pred = dict(
                query=query_list[idx],  # str
                vid=video_path,
                pred_relevant_windows=cur_ranked_preds,  # List([st(float), ed(float), score(float)])
                pred_saliency_scores=saliency_scores[idx]  # List(float), len==n_frames, scores for each frame
            )
            predictions.append(cur_query_pred)

        return predictions, video_frames


def _save_moments(video_path, windows, output_dir):
    """Cut each predicted window into ``moment_{i}.mp4`` and list them in ``moment_scores.txt``."""
    for i, (start_time, end_time, score) in enumerate(windows):
        print(start_time, end_time, score)
        # Predicted spans are not clamped by the model and may start before 0 s.
        clip_start = max(0.0, start_time)
        if end_time <= clip_start:
            continue
        ffmpeg_extract_subclip(video_path, clip_start, end_time,
                               targetname=os.path.join(output_dir, f'moment_{i}.mp4'))
    # Store the sorted windows with their (unclamped) times and scores.
    with open(os.path.join(output_dir, 'moment_scores.txt'), 'w') as f:
        for i, (start_time, end_time, score) in enumerate(windows):
            f.write(str(i) + '. ' + str(start_time) + ' ' + str(end_time) + ' ' + str(score) + '\n')


def _save_ranked_frames(saliency_scores, video_frames, output_dir):
    """Save frames as ``{rank}.jpg`` ordered by descending saliency, plus ``scores.txt``."""
    # Sort on the score only: ties are common after rounding, and falling back to
    # comparing the frames themselves (tensors) would raise.
    ranked = sorted(zip(saliency_scores, video_frames), key=lambda pair: pair[0], reverse=True)
    sorted_frames = [frame for _, frame in ranked]
    sorted_scores = sorted(saliency_scores, reverse=True)
    print(sorted_scores)
    for i, frame in enumerate(sorted_frames):
        # NOTE(review): assumes each frame is a CHW tensor with 0-255 values — confirm
        # against ClipFeatureExtractor.encode_video.
        image = Image.fromarray(frame.permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        image.save(os.path.join(output_dir, str(i) + '.jpg'))
    with open(os.path.join(output_dir, 'scores.txt'), 'w') as f:
        for i, score in enumerate(sorted_scores):
            f.write(str(i) + '. ' + str(score) + '\n')


def run_example():
    """Run the predictor on every video of the example directory and save its outputs."""
    # load example data
    from utils.basic_utils import load_jsonl
    video_dir = "run_on_video/example/testing_videos/dogs"
    query_path = "run_on_video/example/queries_highlight.jsonl"
    queries = load_jsonl(query_path)
    query_text_list = [e["query"] for e in queries]
    ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"

    # run predictions
    print("Build models...")
    clip_model_name_or_path = "ViT-B/32"
    # Alternatively, point to local CLIP weights, e.g. "tmp/ViT-B-32.pt".
    moment_detr_predictor = MomentDETRPredictor(
        ckpt_path=ckpt_path,
        clip_model_name_or_path=clip_model_name_or_path,
        device="cuda"
    )
    print("Run prediction...")
    # Sorted for a deterministic processing order; sub-directories are skipped.
    video_paths = [
        os.path.join(video_dir, name) for name in sorted(os.listdir(video_dir))
        if os.path.isfile(os.path.join(video_dir, name))
    ]
    for video_path in video_paths:
        video_output_dir = os.path.join("run_on_video/example/output/dog/empty_str",
                                        os.path.basename(video_path))
        predictions, video_frames = moment_detr_predictor.localize_moment(
            video_path=video_path, query_list=query_text_list)
        os.makedirs(video_output_dir, exist_ok=True)

        for idx, query_data in enumerate(queries):
            print("-" * 30 + f"idx{idx}")
            print(f">> query: {query_data['query']}")
            print(f">> video_path: {video_path}")
            print(f">> Predicted moments ([start_in_seconds, end_in_seconds, score]): "
                  f"{predictions[idx]['pred_relevant_windows']}")
            print(f">> Predicted saliency scores (for all 2-sec clip): "
                  f"{predictions[idx]['pred_saliency_scores']}")

            # With several queries, give each its own sub-directory so that later
            # queries do not overwrite the files written for earlier ones.
            if len(queries) == 1:
                output_dir = video_output_dir
            else:
                output_dir = os.path.join(video_output_dir, f"query_{idx}")
                os.makedirs(output_dir, exist_ok=True)

            # Output the retrieved moments, best score first.
            predictions[idx]['pred_relevant_windows'] = sorted(
                predictions[idx]['pred_relevant_windows'], key=lambda x: x[2], reverse=True)
            _save_moments(video_path, predictions[idx]['pred_relevant_windows'], output_dir)
            _save_ranked_frames(predictions[idx]['pred_saliency_scores'], video_frames, output_dir)


if __name__ == "__main__":
    run_example()