import torch
from run_on_video.data_utils import ClipFeatureExtractor
from run_on_video.model_utils import build_inference_model
from utils.tensor_utils import pad_sequences_1d
from moment_detr.span_utils import span_cxw_to_xx
from utils.basic_utils import l2_normalize_np_array
import torch.nn.functional as F
import numpy as np
import os
from PIL import Image
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.video.io.VideoFileClip import VideoFileClip


class MomentDETRPredictor:
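    """Thin wrapper around CLIP feature extraction and a trained Moment-DETR
    checkpoint: given one video and a list of text queries, it predicts
    relevant moment windows and per-clip saliency (highlight) scores."""
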
    def __init__(self, ckpt_path, clip_model_name_or_path="ViT-B/32", device="cuda"):
        self.clip_len = 2  # seconds
        self.device = device
        print("Loading feature extractors...")
        self.feature_extractor = ClipFeatureExtractor(
            framerate=1 / self.clip_len, size=224, centercrop=True,
            model_name_or_path=clip_model_name_or_path, device=device
        )
        print("Loading trained Moment-DETR model...")
        self.model = build_inference_model(ckpt_path).to(self.device)

    @torch.no_grad()  # inference only, no gradients needed
    def localize_moment(self, video_path, query_list):
        """
        Args:
            video_path: str, path to the video file
            query_list: List[str], each str is a query for this video
        Returns:
            predictions: List[dict], one entry per query with the predicted moment
                windows and per-clip saliency scores
            video_frames: the preprocessed frames returned by the feature extractor
        """
        # construct model inputs
        n_query = len(query_list)
        video_feats, video_frames = self.feature_extractor.encode_video(video_path)
        video_feats = F.normalize(video_feats, dim=-1, eps=1e-5)
        n_frames = len(video_feats)
        # add tef
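        # tef = temporal endpoint features: for every 2-sec clip, its normalized
        # [start, end] position in the video, concatenated onto the CLIP features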
        tef_st = torch.arange(0, n_frames, 1.0) / n_frames
        tef_ed = tef_st + 1.0 / n_frames
        tef = torch.stack([tef_st, tef_ed], dim=1).to(self.device)  # (n_frames, 2)
        video_feats = torch.cat([video_feats, tef], dim=1)
        assert n_frames <= 75, "The positional embedding of this pretrained MomentDETR only supports videos up " \
                               "to 150 secs (i.e., 75 2-sec clips) in length"
        video_feats = video_feats.unsqueeze(0).repeat(n_query, 1, 1)  # (#text, T, d)
        video_mask = torch.ones(n_query, n_frames).to(self.device)
        query_feats = self.feature_extractor.encode_text(query_list)  # #text * (L, d)
        query_feats, query_mask = pad_sequences_1d(
            query_feats, dtype=torch.float32, device=self.device, fixed_length=None)
        query_feats = F.normalize(query_feats, dim=-1, eps=1e-5)
        model_inputs = dict(
            src_vid=video_feats,
            src_vid_mask=video_mask,
            src_txt=query_feats,
            src_txt_mask=query_mask
        )
        # decode outputs
        outputs = self.model(**model_inputs)
        # #moment_queries refers to the positional embeddings in MomentDETR's decoder, not the input text query
        prob = F.softmax(outputs["pred_logits"], -1)  # (batch_size, #moment_queries=10, #classes=2)
        scores = prob[..., 0]  # (batch_size, #moment_queries), foreground label is 0, we directly take it
        pred_spans = outputs["pred_spans"]  # (bsz, #moment_queries, 2)
        print(pred_spans)  # debug: raw normalized (center, width) spans
        _saliency_scores = outputs["saliency_scores"].half()  # (bsz, L)
        saliency_scores = []
        valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
        for j in range(len(valid_vid_lengths)):
            _score = _saliency_scores[j, :int(valid_vid_lengths[j])].tolist()
            _score = [round(e, 4) for e in _score]
            saliency_scores.append(_score)
        # compose predictions
        predictions = []
        video_duration = n_frames * self.clip_len
        for idx, (spans, score) in enumerate(zip(pred_spans.cpu(), scores.cpu())):
            spans = span_cxw_to_xx(spans) * video_duration
            # (#moment_queries, 3), [st(float), ed(float), score(float)]
            cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
            cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
            cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
            cur_query_pred = dict(
                query=query_list[idx],  # str
                vid=video_path,
                pred_relevant_windows=cur_ranked_preds,  # List([st(float), ed(float), score(float)])
                pred_saliency_scores=saliency_scores[idx]  # List(float), one score per 2-sec clip
            )
            predictions.append(cur_query_pred)
        return predictions, video_frames


def run_example():
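    """Example driver: runs Moment-DETR on every video in an example directory,
    then writes the retrieved moment clips, per-moment scores, and the frames
    ranked by predicted saliency to an output directory."""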
    # load example data
    from utils.basic_utils import load_jsonl
    video_dir = "run_on_video/example/testing_videos/dogs"
    # video_path = "run_on_video/example/testing_videos/"
    query_path = "run_on_video/example/queries_highlight.jsonl"
    queries = load_jsonl(query_path)
    query_text_list = [e["query"] for e in queries]
    ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"
    # run predictions
    print("Build models...")
    clip_model_name_or_path = "ViT-B/32"
    # clip_model_name_or_path = "tmp/ViT-B-32.pt"
    moment_detr_predictor = MomentDETRPredictor(
        ckpt_path=ckpt_path,
        clip_model_name_or_path=clip_model_name_or_path,
        device="cuda"
    )
    print("Run prediction...")
    video_paths = [os.path.join(video_dir, e) for e in os.listdir(video_dir)]
    # video_paths = ["run_on_video/example/testing_videos/celebration_18s.mov"]
    for video_path in video_paths:
        output_dir = os.path.join("run_on_video/example/output/dog/empty_str", os.path.basename(video_path))
        predictions, video_frames = moment_detr_predictor.localize_moment(
            video_path=video_path, query_list=query_text_list)
        # make sure the output directory exists
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # print data
        for idx, query_data in enumerate(queries):
            print("-" * 30 + f" idx{idx}")
            print(f">> query: {query_data['query']}")
            print(f">> video_path: {video_path}")
            # print(f">> GT moments: {query_data['relevant_windows']}")
            print(f">> Predicted moments ([start_in_seconds, end_in_seconds, score]): "
                  f"{predictions[idx]['pred_relevant_windows']}")
            # print(f">> GT saliency scores (only localized 2-sec clips): {query_data['saliency_scores']}")
            print(f">> Predicted saliency scores (for all 2-sec clips): "
                  f"{predictions[idx]['pred_saliency_scores']}")
            # output the retrieved moments
            # sort the moments by their confidence score (third element in each list)
            predictions[idx]['pred_relevant_windows'] = sorted(
                predictions[idx]['pred_relevant_windows'], key=lambda x: x[2], reverse=True)
            for i, (start_time, end_time, score) in enumerate(predictions[idx]['pred_relevant_windows']):
                print(start_time, end_time, score)
                ffmpeg_extract_subclip(video_path, start_time, end_time,
                                       targetname=os.path.join(output_dir, f'moment_{i}.mp4'))
            # store the sorted pred_relevant_windows times and scores
            with open(os.path.join(output_dir, 'moment_scores.txt'), 'w') as f:
                for i, (start_time, end_time, score) in enumerate(predictions[idx]['pred_relevant_windows']):
                    f.write(f"{i}. {start_time} {end_time} {score}\n")
            # save the video frames sorted by pred_saliency_scores (highest saliency first);
            # sort on the score only, so equal scores never trigger a tensor comparison
            score_frame_pairs = sorted(
                zip(predictions[idx]['pred_saliency_scores'], video_frames),
                key=lambda pair: pair[0], reverse=True)
            sorted_frames = [frame for _, frame in score_frame_pairs]
            # keep the sorted scores aligned with the sorted frames
            sorted_scores = [score for score, _ in score_frame_pairs]
            print(sorted_scores)
            # save frames to the output directory
            for i, frame in enumerate(sorted_frames):
                # convert the frame tensor (C, H, W) to a PIL image
                frame = frame.permute(1, 2, 0).cpu().numpy()
                frame = frame.astype(np.uint8)
                frame = Image.fromarray(frame)
                frame.save(os.path.join(output_dir, str(i) + '.jpg'))
            # save the sorted scores to the output directory
            with open(os.path.join(output_dir, 'scores.txt'), 'w') as f:
                for i, score in enumerate(sorted_scores):
                    f.write(f"{i}. {score}\n")


if __name__ == "__main__":
    run_example()