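"""Predict highlight clips for a video.

Pipeline: sample roughly one frame per second, embed each frame with the i2v
model, group consecutive frames into overlapping clips, score every clip with
the trained scoring model, and return the clips sorted by score.
"""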
import argparse
import os

import torch
from moviepy.editor import VideoFileClip
from PIL import Image

from lbhd.i2v import i2v_transform, load_weight, EMBED_DIM
from lbhd.batch_image_transforms import batch_transform_val
device = 'cuda' if torch.cuda.is_available() else 'cpu'
CWD = os.path.dirname(__file__)

# Image-to-vector (i2v) embedding model, loaded once at import time.
WEIGHT_i2v = os.path.join(CWD, '..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')
i2v = load_weight(WEIGHT_i2v)
i2v.to(device)
i2v_transform.to(device)

# Clip-scoring model restored from its training checkpoint.
WEIGHT_scoring = os.path.join(CWD, '..', 'weight', 'ckpt_epoch_59_loss_0.3066582295561343.ckpt')
checkpoint = torch.load(WEIGHT_scoring, map_location=device)
scoringModel = checkpoint['model']
def sample_clips(x, frames_per_clip: int):
    """Split a (n_frames, ...) tensor into overlapping sliding windows of length frames_per_clip."""
    if len(x) < frames_per_clip:
        # Guard for short videos: torch.stack() on an empty list raises an opaque error.
        raise ValueError(f'Video too short: {len(x)} sampled frames, need at least {frames_per_clip}')
    return torch.stack([x[i:i + frames_per_clip] for i in range(len(x) - frames_per_clip + 1)])
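# Illustrative example: 5 sampled frames with frames_per_clip=3 produce 3 clips,
# covering frame indices [0:3], [1:4], and [2:5].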
def sample_frames(file_name, fps):
    """Decode a video with OpenCV, keeping one frame every `fps` frames (~1 frame per second)."""
    import cv2  # imported lazily; OpenCV is only needed by this helper

    cap = cv2.VideoCapture(file_name)
    i = 0
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if i % fps == 0:
            # OpenCV decodes frames as BGR; convert to RGB before wrapping in a PIL image.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        i += 1
    cap.release()
    return frames
def videofile_to_frames(filename, sample_every_second=True):
    """Decode a video with moviepy into PIL images, keeping ~1 frame per second by default."""
    clip = VideoFileClip(filename)
    step = int(clip.fps) if sample_every_second else 1  # keep every frame when sampling is disabled
    frames = [Image.fromarray(frame) for index, frame in enumerate(clip.iter_frames()) if index % step == 0]
    print('clip.fps', clip.fps, 'Number of frames sampled:', len(frames))
    return frames
def frames_to_vectors(frames, frames_per_clip, use_frame_diff: bool = False, transform=batch_transform_val, i2v=i2v, i2v_transform=i2v_transform):
    """Embed frames with i2v and flatten them into one feature vector per clip."""
    x = torch.stack(transform(frames))
    x = sample_clips(x, frames_per_clip).to(device)  # (n_clips, frames_per_clip, 3, 224, 224)
    n_clips = x.size(0)
    x = x.view(-1, 3, 224, 224)  # (n_clips * frames_per_clip, 3, 224, 224)
    with torch.no_grad():
        x, _ = i2v(i2v_transform(x))  # (n_clips * frames_per_clip, EMBED_DIM)
    x = x.view(n_clips, -1, EMBED_DIM)  # (n_clips, frames_per_clip, EMBED_DIM)
    if use_frame_diff:
        # Replace every embedding after the first with its difference from the previous frame.
        x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
    return x.view(n_clips, -1)  # (n_clips, frames_per_clip * EMBED_DIM)
def videofile_to_scores(videofile, model):
    frames = videofile_to_frames(videofile)
    x = frames_to_vectors(frames, model.frames_per_clip)
    with torch.no_grad():
        # reshape(-1) keeps the result 1-D even for a single-clip video,
        # where squeeze() would yield a 0-d array and break iteration downstream.
        return torch.sigmoid(model(x)).cpu().numpy().reshape(-1)

def frames_to_scores(frames, model, use_frame_diff):
    x = frames_to_vectors(frames, model.frames_per_clip, use_frame_diff)
    with torch.no_grad():
        return torch.sigmoid(model(x)).cpu().numpy().reshape(-1)
def prepare_output(scores):
    # Each entry is [clip_start, clip_end, score]; the hardcoded +2 presumably
    # reflects 3-frame clips sampled at ~1 frame per second.
    output = [[clip_id, clip_id + 2, s] for clip_id, s in enumerate(scores)]
    return sorted(output, key=lambda result: result[2], reverse=True)
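# Illustrative example: prepare_output([0.1, 0.9, 0.4])
# -> [[1, 3, 0.9], [2, 4, 0.4], [0, 2, 0.1]]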
def lbhd_predict(video_file):
    """Score a video file and return its clips sorted by predicted highlight score."""
    scores = videofile_to_scores(video_file, scoringModel)
    return prepare_output(scores)
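# Typical usage (hypothetical path):
#   lbhd_predict('path/to/video.mp4')  # -> [[start, end, score], ...], best clips first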
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='../weight/ckpt_epoch_59_loss_0.3066582295561343.ckpt')
    parser.add_argument('--videofile', type=str, required=True)
    args = parser.parse_args()
    print('Arguments:', args)

    # Load the checkpoint onto the same device the feature tensors are moved to,
    # so the model and its inputs do not end up on mismatched devices.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    scoringModel = checkpoint['model']
    scores = videofile_to_scores(args.videofile, scoringModel)
    print(scores)
    print(prepare_output(scores))
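# Example invocation (hypothetical script/file names):
#   python predict.py --videofile my_video.mp4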