import argparse
import os

import torch
from moviepy.editor import VideoFileClip
from PIL import Image

from lbhd.i2v import i2v_transform, load_weight, EMBED_DIM
from lbhd.batch_image_transforms import batch_transform_val

device = 'cuda' if torch.cuda.is_available() else 'cpu'
CWD = os.path.dirname(__file__)

# Image-to-vector (i2v) embedding model.
WEIGHT_i2v = os.path.join(CWD, '..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')
i2v = load_weight(WEIGHT_i2v)
i2v.to(device)
i2v_transform.to(device)

# Clip-scoring model, restored from a full-model checkpoint.
WEIGHT_scoring = os.path.join(CWD, '..', 'weight', 'ckpt_epoch_59_loss_0.3066582295561343.ckpt')
checkpoint = torch.load(WEIGHT_scoring, map_location=device)
scoringModel = checkpoint['model']


def sample_clips(x, frames_per_clip: int):
    """Slide a window of `frames_per_clip` frames (stride 1) over the frame
    tensor, producing one overlapping clip per starting position."""
    print('x shape', x.shape, 'frames_per_clip', frames_per_clip)
    return torch.stack([
        x[i:i + frames_per_clip]
        for i in range(len(x) - frames_per_clip + 1)
    ])


def sample_frames(file_name, fps):
    """Decode a video with OpenCV and keep every `fps`-th frame."""
    import cv2

    cap = cv2.VideoCapture(file_name)
    i = 0
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if i % fps == 0:
            # OpenCV decodes frames as BGR; convert to RGB for PIL.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        i += 1
    cap.release()
    return frames


def videofile_to_frames(filename, sample_every_second=True):
    """Decode a video with moviepy and keep roughly one frame per second."""
    clip = VideoFileClip(filename)
    frames = [Image.fromarray(frame)
              for index, frame in enumerate(clip.iter_frames())
              if index % int(clip.fps) == 0]
    print('clip.fps', clip.fps, 'Number of frames in video is:', len(frames))
    return frames


def frames_to_vectors(frames, frames_per_clip, use_frame_diff: bool = False,
                      transform=batch_transform_val, i2v=i2v,
                      i2v_transform=i2v_transform):
    """Embed frames with the i2v model and pack them into per-clip vectors."""
    x = torch.stack(transform(frames))
    x = sample_clips(x, frames_per_clip).to(device)  # (n_clips, frames_per_clip, 3, 224, 224)
    n_clips = x.size(0)
    x = x.view(-1, 3, 224, 224)  # (n_clips * frames_per_clip, 3, 224, 224)
    with torch.no_grad():
        x, _ = i2v(i2v_transform(x))  # (n_clips * frames_per_clip, EMBED_DIM)
    x = x.view(n_clips, -1, EMBED_DIM)  # (n_clips, frames_per_clip, EMBED_DIM)
    if use_frame_diff:
        # Replace every embedding after the first with its delta from the
        # previous frame; the first frame's embedding stays as the anchor.
        x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
    return x.view(n_clips, -1)  # (n_clips, frames_per_clip * EMBED_DIM)


def videofile_to_scores(videofile, model):
    frames = videofile_to_frames(videofile)
    x = frames_to_vectors(frames, model.frames_per_clip)
    with torch.no_grad():
        return torch.sigmoid(model(x)).squeeze().cpu().numpy()


def frames_to_scores(frames, model, use_frame_diff):
    x = frames_to_vectors(frames, model.frames_per_clip, use_frame_diff)
    with torch.no_grad():
        return torch.sigmoid(model(x)).squeeze().cpu().numpy()


def prepare_output(scores):
    # Clip i spans seconds [i, i + 2]; the hardcoded +2 assumes 3-frame clips
    # sampled at ~1 fps. Returns [start_sec, end_sec, score], best score first.
    output = [[clip_id, clip_id + 2, s] for (clip_id, s) in enumerate(scores)]
    return sorted(output, key=lambda result: result[2], reverse=True)


def lbhd_predict(video_file):
    scores = videofile_to_scores(video_file, scoringModel)
    return prepare_output(scores)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str,
                        default='../weight/ckpt_epoch_59_loss_0.3066582295561343.ckpt')
    parser.add_argument('--videofile', type=str, default=None)
    args = parser.parse_args()
    print('Arguments:', args)

    # Load onto the same device the inputs are moved to in frames_to_vectors.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    scoringModel = checkpoint['model']
    frames_per_clip = scoringModel.frames_per_clip

    scores = videofile_to_scores(args.videofile, scoringModel)
    print(scores)
    print(prepare_output(scores))
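

# ---------------------------------------------------------------------------
# Shape sanity sketch (illustrative only, not part of the pipeline; safe to
# delete). It shows the sliding-window behaviour of sample_clips on dummy
# data: 10 frames with frames_per_clip=3 yield 10 - 3 + 1 = 8 overlapping
# clips. The 224x224 tensor size is an assumption matching the inputs used
# in frames_to_vectors.
# ---------------------------------------------------------------------------
def _demo_sample_clips():
    fake_frames = torch.zeros(10, 3, 224, 224)  # e.g. a 10-second video at 1 fps
    clips = sample_clips(fake_frames, 3)
    assert clips.shape == (8, 3, 3, 224, 224)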
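

# ---------------------------------------------------------------------------
# Usage sketch (paths and script name are hypothetical placeholders; the
# weight files loaded at import time must exist relative to this script):
#
#   python this_script.py --videofile /path/to/match.mp4
#
# or programmatically:
#
#   results = lbhd_predict('/path/to/match.mp4')
#   start_sec, end_sec, score = results[0]   # best-scoring clip comes first
# ---------------------------------------------------------------------------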