import torch
import argparse
import os
from moviepy.editor import VideoFileClip
from PIL import Image
from lbhd.i2v import i2v_transform, load_weight, EMBED_DIM
from lbhd.batch_image_transforms import batch_transform_val

device = 'cuda' if torch.cuda.is_available() else 'cpu'

CWD = os.path.dirname(__file__)
WEIGHT_i2v = os.path.join(CWD, '..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')
i2v = load_weight(WEIGHT_i2v)
i2v.to(device)
i2v_transform.to(device)
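# i2v above maps each 224x224 frame to an EMBED_DIM-dimensional embedding; the
# checkpoint name suggests a 24-head attention encoder trained for 30 epochs.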

WEIGHT_scoring = os.path.join(CWD, '..', 'weight', 'ckpt_epoch_59_loss_0.3066582295561343.ckpt')
# The checkpoint stores the whole model object under 'model', so the model's
# class definition must be importable when torch.load unpickles it.
checkpoint = torch.load(WEIGHT_scoring, map_location=device)
scoringModel = checkpoint['model']
scoringModel.eval()  # freeze dropout/batch-norm so scoring is deterministic


def sample_clips(x, frames_per_clip: int):
    """Stack every window of `frames_per_clip` consecutive frames (stride 1)."""
    print("x shape", x.shape, "frames_per_clip", frames_per_clip)
    x = torch.stack([x[i:i + frames_per_clip] for i in range(len(x) - frames_per_clip + 1)])
    return x
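# For illustration: 10 sampled frames with frames_per_clip=3 yield
# 10 - 3 + 1 = 8 overlapping clips, i.e. a tensor of shape (8, 3, C, H, W).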

def sample_frames(file_name, fps):
    """OpenCV-based alternative to videofile_to_frames: keep one frame every `fps` frames."""
    import cv2  # local import so OpenCV is only required when this helper is used
    cap = cv2.VideoCapture(file_name)
    i = 0
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if i % int(fps) == 0:  # int() so a fractional fps (e.g. 29.97) still matches
            # OpenCV decodes to BGR; convert to RGB before building the PIL image
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        i += 1
    cap.release()
    return frames

def videofile_to_frames(filename, sample_every_second=True):
    """Decode `filename` with moviepy, keeping one frame per second (or every frame)."""
    clip = VideoFileClip(filename)
    step = int(clip.fps) if sample_every_second else 1
    frames = [Image.fromarray(frame) for index, frame in enumerate(clip.iter_frames()) if index % step == 0]
    print('clip.fps', clip.fps, 'Number of frames in video is:', len(frames))
    return frames
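# Since one frame is kept per int(clip.fps) frames, frame index i corresponds
# roughly to second i of the video; prepare_output relies on this when it
# reports clip boundaries in seconds.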


def frames_to_vectors(frames, frames_per_clip, use_frame_diff: bool = False, transform=batch_transform_val, i2v=i2v, i2v_transform=i2v_transform):
    x = torch.stack(transform(frames))
    x = sample_clips(x, frames_per_clip).to(device)             # x.size = (n_clips, frames_per_clip, 3, 224, 224)
    n_clips = x.size(0)
    x = x.view(-1, 3, 224, 224)                                 # (n_clips * frames_per_clip, 3, 224, 224)
    with torch.no_grad():
        x, _ = i2v(i2v_transform(x))                            # (n_clips * frames_per_clip, EMBED_DIM)
        x = x.view(n_clips, -1, EMBED_DIM)                      # (n_clips, frames_per_clip, EMBED_DIM)
        if use_frame_diff:
            x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]            # replace frames 1..k-1 with successive deltas
        x = x.view(n_clips, -1)                                 # (n_clips, frames_per_clip * EMBED_DIM)
    return x
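# With use_frame_diff=True, each clip vector keeps one absolute embedding
# (frame 0) followed by successive deltas, so the scorer sees motion in
# embedding space rather than raw per-frame features.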

def videofile_to_scores(videofile, model):
    frames = videofile_to_frames(videofile)
    x = frames_to_vectors(frames, model.frames_per_clip)
    with torch.no_grad():
        return torch.sigmoid(model(x)).squeeze().cpu().numpy()

def frames_to_scores(frames, model, use_frame_diff):
    x = frames_to_vectors(frames, model.frames_per_clip, use_frame_diff)
    with torch.no_grad():
        return torch.sigmoid(model(x)).squeeze().cpu().numpy()

def prepare_output(scores):
    # Each row is [start_second, end_second, score]; the hard-coded "+2" matches
    # three-frame clips sampled at one frame per second (adjust if the model
    # uses a different frames_per_clip).
    output = [[clip_id, clip_id + 2, s] for (clip_id, s) in enumerate(scores)]
    return sorted(output, key=lambda result: result[2], reverse=True)
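# Illustrative values: scores [0.2, 0.9, 0.5] become
# [[1, 3, 0.9], [2, 4, 0.5], [0, 2, 0.2]] after ranking.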

def lbhd_predict(video_file):
    scores = videofile_to_scores(video_file, scoringModel)
    return prepare_output(scores)
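# Minimal usage sketch ('demo.mp4' is a placeholder path):
#
#   results = lbhd_predict('demo.mp4')
#   start_sec, end_sec, score = results[0]  # highest-scoring segment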

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='../weight/ckpt_epoch_59_loss_0.3066582295561343.ckpt')
    parser.add_argument('--videofile', type=str, required=True)
    args = parser.parse_args()
    print('Arguments:', args)

    checkpoint = torch.load(args.checkpoint, map_location=device)
    scoringModel = checkpoint['model']
    scoringModel.eval()  # match the module-level setup above
    scores = videofile_to_scores(args.videofile, scoringModel)
    print(scores)
    print(prepare_output(scores))