import argparse
import os

import torch
from moviepy.editor import VideoFileClip
from PIL import Image

from lbhd.i2v import i2v_transform, load_weight, EMBED_DIM
from lbhd.batch_image_transforms import batch_transform_val

device = 'cuda' if torch.cuda.is_available() else 'cpu'
CWD = os.path.dirname(__file__)

# Image-to-vector (i2v) embedding model used to turn frames into feature vectors.
WEIGHT_i2v = os.path.join(CWD, '..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')
i2v = load_weight(WEIGHT_i2v)
i2v.to(device)
i2v_transform.to(device)

# Clip-scoring model; the checkpoint stores the full model object under the 'model' key.
WEIGHT_scoring = os.path.join(CWD, '..', 'weight', 'ckpt_epoch_59_loss_0.3066582295561343.ckpt')
checkpoint = torch.load(WEIGHT_scoring, map_location=device)
scoringModel = checkpoint['model']

def sample_clips(x, frames_per_clip: int):
    """Split a (n_frames, ...) tensor into overlapping clips with a stride of one frame."""
    print("x shape", x.shape, "frames_per_clip", frames_per_clip)
    x = torch.stack([x[i:i + frames_per_clip] for i in range(len(x) - frames_per_clip + 1)])
    return x

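# Worked example for sample_clips (illustrative only): 5 frames with frames_per_clip=3
# yield the 3 overlapping windows [0:3], [1:4], [2:5], stacked on a new leading dimension.
#   dummy = torch.zeros(5, 3, 224, 224)
#   sample_clips(dummy, 3).shape  # -> torch.Size([3, 3, 3, 224, 224])
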
def sample_frames(file_name, fps):
    """Read a video with OpenCV, keeping one frame out of every `fps` frames."""
    import cv2

    cap = cv2.VideoCapture(file_name)
    i = 0
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if i % fps == 0:
            # OpenCV decodes frames as BGR; convert to RGB before building a PIL image.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        i += 1
    cap.release()
    return frames

def videofile_to_frames(filename, sample_every_second=True):
    """Decode a video with moviepy, keeping one frame per second when `sample_every_second` is set."""
    clip = VideoFileClip(filename)
    step = int(clip.fps) if sample_every_second else 1
    frames = [Image.fromarray(frame) for index, frame in enumerate(clip.iter_frames()) if index % step == 0]
    print('clip.fps', clip.fps, 'Number of frames in video is:', len(frames))
    return frames

def frames_to_vectors(frames, frames_per_clip, use_frame_diff: bool = False,
                      transform=batch_transform_val, i2v=i2v, i2v_transform=i2v_transform):
    """Embed the sampled frames with the i2v model and flatten each clip into one feature vector."""
    x = torch.stack(transform(frames))
    x = sample_clips(x, frames_per_clip).to(device)  # (n_clips, frames_per_clip, 3, 224, 224)
    n_clips = x.size(0)
    x = x.view(-1, 3, 224, 224)                      # (n_clips * frames_per_clip, 3, 224, 224)
    with torch.no_grad():
        x, _ = i2v(i2v_transform(x))                 # (n_clips * frames_per_clip, EMBED_DIM)
    x = x.view(n_clips, -1, EMBED_DIM)               # (n_clips, frames_per_clip, EMBED_DIM)
    if use_frame_diff:
        # Replace every embedding after the first with its difference from the previous frame.
        x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
    x = x.view(n_clips, -1)                          # (n_clips, frames_per_clip * EMBED_DIM)
    return x

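# Sketch of the shape pipeline in frames_to_vectors, assuming 10 sampled frames,
# frames_per_clip=3, and (3, 224, 224) tensors from batch_transform_val as the view
# calls imply (the numbers are illustrative, not taken from the repository):
#   torch.stack(transform(frames))  -> (10, 3, 224, 224)
#   sample_clips(...)               -> (8, 3, 3, 224, 224)
#   i2v(i2v_transform(...))         -> (24, EMBED_DIM)
#   final feature matrix            -> (8, 3 * EMBED_DIM), one row per candidate clip
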
def videofile_to_scores(videofile, model):
    frames = videofile_to_frames(videofile)
    x = frames_to_vectors(frames, model.frames_per_clip)
    with torch.no_grad():
        return torch.sigmoid(model(x)).squeeze().cpu().numpy()


def frames_to_scores(frames, model, use_frame_diff):
    x = frames_to_vectors(frames, model.frames_per_clip, use_frame_diff)
    with torch.no_grad():
        return torch.sigmoid(model(x)).squeeze().cpu().numpy()

def prepare_output(scores):
    # Clip clip_id starts at second clip_id; the hard-coded end offset of 2 seconds
    # presumably matches 3-frame clips sampled at one frame per second.
    output = [[clip_id, clip_id + 2, s] for (clip_id, s) in enumerate(scores)]
    return sorted(output, key=lambda result: result[2], reverse=True)

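# Example of the structure prepare_output returns (hypothetical scores): three clips
# scored [0.1, 0.9, 0.4] come back sorted best-first as
#   [[1, 3, 0.9], [2, 4, 0.4], [0, 2, 0.1]]
# i.e. [start_second, end_second, score] per clip.
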
def lbhd_predict(video_file):
    scores = videofile_to_scores(video_file, scoringModel)
    return prepare_output(scores)

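# Minimal usage sketch for lbhd_predict (the file name is a placeholder, not shipped
# with the repository):
#   results = lbhd_predict('example.mp4')
#   start_sec, end_sec, score = results[0]   # highest-scoring highlight candidate
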
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='../weight/ckpt_epoch_59_loss_0.3066582295561343.ckpt')
    parser.add_argument('--videofile', type=str, default=None)
    args = parser.parse_args()
    print('Arguments:', args)

    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    scoringModel = checkpoint['model']
    frames_per_clip = scoringModel.frames_per_clip

    scores = videofile_to_scores(args.videofile, scoringModel)
    print(scores)
    print(prepare_output(scores))