# File-view header captured with this source (not code):
#   author: AnsenH
#   commit 79d80e3: "feat: add pipeline to predict"
#   size: 8.46 kB
import torch
from run_on_video.data_utils import ClipFeatureExtractor
from run_on_video.model_utils import build_inference_model
from utils.tensor_utils import pad_sequences_1d
from moment_detr.span_utils import span_cxw_to_xx
from utils.basic_utils import l2_normalize_np_array
import torch.nn.functional as F
import numpy as np
import os
from PIL import Image
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.video.io.VideoFileClip import VideoFileClip
class MomentDETRPredictor:
    """Localize moments in a video that match free-text queries with Moment-DETR.

    Video frames are sampled at one frame every ``clip_len`` seconds and
    encoded with a CLIP feature extractor, queries are encoded with the same
    extractor's text encoder, and both are fed to a trained Moment-DETR model.
    """

    def __init__(self, ckpt_path, clip_model_name_or_path="ViT-B/32", device="cuda"):
        """Load the CLIP feature extractor and the Moment-DETR checkpoint.

        Args:
            ckpt_path: path to the trained Moment-DETR checkpoint.
            clip_model_name_or_path: CLIP model name or local weights path,
                forwarded to ``ClipFeatureExtractor``.
            device: torch device string the extractor and model run on.
        """
        self.clip_len = 2  # seconds of video covered by one feature vector
        self.device = device
        print("Loading feature extractors...")
        self.feature_extractor = ClipFeatureExtractor(
            framerate=1 / self.clip_len, size=224, centercrop=True,
            model_name_or_path=clip_model_name_or_path, device=device,
        )
        print("Loading trained Moment-DETR model...")
        self.model = build_inference_model(ckpt_path).to(self.device)

    @torch.no_grad()
    def localize_moment(self, video_path, query_list):
        """Predict relevant windows and per-clip saliency for every query.

        Args:
            video_path: str, path to the video file.
            query_list: List[str], each str is a query for this video.

        Returns:
            Tuple ``(predictions, video_frames)``. ``predictions`` holds one
            dict per query with keys ``query``, ``vid``,
            ``pred_relevant_windows`` (rows ``[start_sec, end_sec, score]``
            sorted by score, highest first, values rounded to 4 decimals) and
            ``pred_saliency_scores`` (one rounded score per clip).
            ``video_frames`` is returned unchanged from
            ``feature_extractor.encode_video``.

        Raises:
            AssertionError: if the video yields more than 75 clips, the
                length supported by the pretrained positional embedding.
        """
        n_query = len(query_list)
        video_feats, video_frames = self.feature_extractor.encode_video(video_path)
        n_frames = len(video_feats)
        # Reject over-long videos before doing any further tensor work.
        assert n_frames <= 75, "The positional embedding of this pretrained MomentDETR only support video up " \
                               "to 150 secs (i.e., 75 2-sec clips) in length"
        video_feats = F.normalize(video_feats, dim=-1, eps=1e-5)

        # tef: each clip's normalized [start, end) position within the video.
        tef_st = torch.arange(0, n_frames, 1.0) / n_frames
        tef_ed = tef_st + 1.0 / n_frames
        tef = torch.stack([tef_st, tef_ed], dim=1).to(self.device)  # (n_frames, 2)
        video_feats = torch.cat([video_feats, tef], dim=1)

        # Every query is scored against the same video, so replicate it per query.
        video_feats = video_feats.unsqueeze(0).repeat(n_query, 1, 1)  # (#text, T, d)
        video_mask = torch.ones(n_query, n_frames, device=self.device)

        query_feats = self.feature_extractor.encode_text(query_list)  # #text * (L, d)
        query_feats, query_mask = pad_sequences_1d(
            query_feats, dtype=torch.float32, device=self.device, fixed_length=None)
        query_feats = F.normalize(query_feats, dim=-1, eps=1e-5)
        model_inputs = dict(
            src_vid=video_feats,
            src_vid_mask=video_mask,
            src_txt=query_feats,
            src_txt_mask=query_mask,
        )

        outputs = self.model(**model_inputs)
        # #moment_queries refers to the positional embeddings in MomentDETR's decoder, not the input text query
        prob = F.softmax(outputs["pred_logits"], -1)  # (batch_size, #moment_queries=10, #classes=2)
        scores = prob[..., 0]  # (batch_size, #moment_queries); foreground label is 0, we directly take it
        pred_spans = outputs["pred_spans"]  # (bsz, #moment_queries, 2)

        # Saliency is cast to half precision, then trimmed to each row's valid
        # (unmasked) length and rounded to 4 decimals.
        half_saliency = outputs["saliency_scores"].half()  # (bsz, L)
        valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
        saliency_scores = [
            [round(e, 4) for e in half_saliency[j, :int(length)].tolist()]
            for j, length in enumerate(valid_vid_lengths)
        ]

        predictions = []
        video_duration = n_frames * self.clip_len
        for idx, (spans, score) in enumerate(zip(pred_spans.cpu(), scores.cpu())):
            # span_cxw_to_xx presumably maps normalized (center, width) to
            # (start, end); scaling by the duration yields seconds.
            spans = span_cxw_to_xx(spans) * video_duration
            # (#moment_queries, 3) rows of [st(float), ed(float), score(float)]
            cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
            cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
            cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
            predictions.append(dict(
                query=query_list[idx],  # str
                vid=video_path,
                pred_relevant_windows=cur_ranked_preds,  # List([st(float), ed(float), score(float)])
                pred_saliency_scores=saliency_scores[idx],  # List(float), one score per clip
            ))
        return predictions, video_frames
def run_example():
    """Run the predictor on every video in the example folder and save results.

    For each video and each query in ``queries_highlight.jsonl`` this prints
    the predictions, cuts the predicted moments into ``moment_{i}.mp4`` clips
    listed in ``moment_scores.txt``, and saves the frames ranked by predicted
    saliency as ``{rank}.jpg`` with their scores in ``scores.txt``. With more
    than one query, each query's files go to a ``query_{idx}`` subdirectory so
    they do not overwrite each other.
    """
    from utils.basic_utils import load_jsonl

    video_dir = "run_on_video/example/testing_videos/dogs"
    query_path = "run_on_video/example/queries_highlight.jsonl"
    queries = load_jsonl(query_path)
    query_text_list = [e["query"] for e in queries]
    ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"

    print("Build models...")
    # To load local CLIP weights instead, use e.g. "tmp/ViT-B-32.pt".
    clip_model_name_or_path = "ViT-B/32"
    moment_detr_predictor = MomentDETRPredictor(
        ckpt_path=ckpt_path,
        clip_model_name_or_path=clip_model_name_or_path,
        device="cuda",
    )

    print("Run prediction...")
    video_paths = [os.path.join(video_dir, e) for e in os.listdir(video_dir)]
    for video_path in video_paths:
        output_dir = os.path.join("run_on_video/example/output/dog/empty_str", os.path.basename(video_path))
        predictions, video_frames = moment_detr_predictor.localize_moment(
            video_path=video_path, query_list=query_text_list)
        os.makedirs(output_dir, exist_ok=True)

        for idx, query_data in enumerate(queries):
            # Separate directories per query; otherwise every query would
            # overwrite the previous query's identically named files.
            query_dir = output_dir if len(queries) == 1 else os.path.join(output_dir, f"query_{idx}")
            os.makedirs(query_dir, exist_ok=True)

            prediction = predictions[idx]
            print("-" * 30 + f"idx{idx}")
            print(f">> query: {query_data['query']}")
            print(f">> video_path: {video_path}")
            print(f">> Predicted moments ([start_in_seconds, end_in_seconds, score]): "
                  f"{prediction['pred_relevant_windows']}")
            print(f">> Predicted saliency scores (for all 2-sec clip): "
                  f"{prediction['pred_saliency_scores']}")

            windows = sorted(prediction['pred_relevant_windows'], key=lambda x: x[2], reverse=True)
            prediction['pred_relevant_windows'] = windows
            _export_moments(video_path, windows, query_dir)
            _export_frames_by_saliency(prediction['pred_saliency_scores'], video_frames, query_dir)


def _export_moments(video_path, windows, out_dir):
    """Cut each [start, end, score] window into moment_{i}.mp4 and list them in moment_scores.txt."""
    for i, (start_time, end_time, score) in enumerate(windows):
        print(start_time, end_time, score)
        ffmpeg_extract_subclip(video_path, start_time, end_time,
                               targetname=os.path.join(out_dir, f"moment_{i}.mp4"))
    with open(os.path.join(out_dir, "moment_scores.txt"), "w") as f:
        for i, (start_time, end_time, score) in enumerate(windows):
            f.write(f"{i}. {start_time} {end_time} {score}\n")


def _export_frames_by_saliency(saliency_scores, video_frames, out_dir):
    """Save frames as {rank}.jpg in descending saliency order and the sorted scores to scores.txt."""
    # Sort on the score alone: sorting (score, frame) tuples falls back to
    # comparing frame tensors on tied scores, which raises for multi-element tensors.
    ranked = sorted(zip(saliency_scores, video_frames), key=lambda pair: pair[0], reverse=True)
    sorted_scores = sorted(saliency_scores, reverse=True)
    print(sorted_scores)
    for rank, (_, frame) in enumerate(ranked):
        # Assumes frames are channel-first (C, H, W) tensors with 0-255 values — TODO confirm
        pixels = frame.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        Image.fromarray(pixels).save(os.path.join(out_dir, f"{rank}.jpg"))
    with open(os.path.join(out_dir, "scores.txt"), "w") as f:
        for i, score in enumerate(sorted_scores):
            f.write(f"{i}. {score}\n")
if __name__ == '__main__':
    # Executing this module as a script runs the bundled example end to end.
    run_example()