PraneshJs committed
Commit c715f39 · verified · 1 Parent(s): e2d9049

Update inference_2.py

Files changed (1):
  1. inference_2.py +85 -179
inference_2.py CHANGED
@@ -1,23 +1,28 @@
  import os
  import cv2
- import onnx
  import torch
- import argparse
  import numpy as np
- import torch.nn as nn
  from models.TMC import ETMC
  from models import image

- from onnx2pytorch import ConvertModel
-
- onnx_model = onnx.load('checkpoints/efficientnet.onnx')
- pytorch_model = ConvertModel(onnx_model)
-
- #Set random seed for reproducibility.
  torch.manual_seed(42)

-
- # Define the audio_args dictionary
  audio_args = {
      'nb_samp': 64600,
      'first_conv': 1024,
@@ -30,187 +35,88 @@ audio_args = {
      'nb_classes': 2
  }

-
- def get_args(parser):
-     parser.add_argument("--batch_size", type=int, default=8)
-     parser.add_argument("--data_dir", type=str, default="datasets/train/fakeavceleb*")
-     parser.add_argument("--LOAD_SIZE", type=int, default=256)
-     parser.add_argument("--FINE_SIZE", type=int, default=224)
-     parser.add_argument("--dropout", type=float, default=0.2)
-     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-     parser.add_argument("--hidden", nargs="*", type=int, default=[])
-     parser.add_argument("--hidden_sz", type=int, default=768)
-     parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"])
-     parser.add_argument("--img_hidden_sz", type=int, default=1024)
-     parser.add_argument("--include_bn", type=int, default=True)
-     parser.add_argument("--lr", type=float, default=1e-4)
-     parser.add_argument("--lr_factor", type=float, default=0.3)
-     parser.add_argument("--lr_patience", type=int, default=10)
-     parser.add_argument("--max_epochs", type=int, default=500)
-     parser.add_argument("--n_workers", type=int, default=12)
-     parser.add_argument("--name", type=str, default="MMDF")
-     parser.add_argument("--num_image_embeds", type=int, default=1)
-     parser.add_argument("--patience", type=int, default=20)
-     parser.add_argument("--savedir", type=str, default="./savepath/")
-     parser.add_argument("--seed", type=int, default=1)
-     parser.add_argument("--n_classes", type=int, default=2)
-     parser.add_argument("--annealing_epoch", type=int, default=10)
-     parser.add_argument("--device", type=str, default='cpu')
-     parser.add_argument("--pretrained_image_encoder", type=bool, default = False)
-     parser.add_argument("--freeze_image_encoder", type=bool, default = False)
-     parser.add_argument("--pretrained_audio_encoder", type = bool, default=False)
-     parser.add_argument("--freeze_audio_encoder", type = bool, default = False)
-     parser.add_argument("--augment_dataset", type = bool, default = True)
-
-     for key, value in audio_args.items():
-         parser.add_argument(f"--{key}", type=type(value), default=value)
-
- def model_summary(args):
-     '''Prints the model summary.'''
-     model = ETMC(args)
-
-     for name, layer in model.named_modules():
-         print(name, layer)
-
- def load_multimodal_model(args):
-     '''Load multimodal model'''
-     model = ETMC(args)
-     ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
-     model.load_state_dict(ckpt, strict = True)
-     model.eval()
-     return model
-
- def load_img_modality_model(args):
-     '''Loads image modality model.'''
-     rgb_encoder = pytorch_model
-
-     ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
-     rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
-     rgb_encoder.eval()
-     return rgb_encoder
-
- def load_spec_modality_model(args):
-     spec_encoder = image.RawNet(args)
-     ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
-     spec_encoder.load_state_dict(ckpt['spec_encoder'], strict = True)
-     spec_encoder.eval()
-     return spec_encoder
-
-
- #Load models.
- parser = argparse.ArgumentParser(description="Inference models")
- get_args(parser)
- args, remaining_args = parser.parse_known_args()
- assert remaining_args == [], remaining_args
-
- spec_model = load_spec_modality_model(args)
-
- img_model = load_img_modality_model(args)
-
-
  def preprocess_img(face):
-     face = face / 255
      face = cv2.resize(face, (256, 256))
-     # face = face.transpose(2, 0, 1) #(W, H, C) -> (C, W, H)
-     face_pt = torch.unsqueeze(torch.Tensor(face), dim = 0)
-     return face_pt

  def preprocess_audio(audio_file):
-     audio_pt = torch.unsqueeze(torch.Tensor(audio_file), dim = 0)
-     return audio_pt
-
- def deepfakes_spec_predict(input_audio):
-     x, _ = input_audio
-     audio = preprocess_audio(x)
-     spec_grads = spec_model.forward(audio)
-     spec_grads_inv = np.exp(spec_grads.cpu().detach().numpy().squeeze())
-
-     # multimodal_grads = multimodal.spec_depth[0].forward(spec_grads)
-
-     # out = nn.Softmax()(multimodal_grads)
-     # max = torch.argmax(out, dim = -1) #Index of the max value in the tensor.
-     # max_value = out[max] #Actual value of the tensor.
-     max_value = np.argmax(spec_grads_inv)
-
-     if max_value > 0.5:
-         preds = round(100 - (max_value*100), 3)
-         text2 = f"The audio is REAL."
-
-     else:
-         preds = round(max_value*100, 3)
-         text2 = f"The audio is FAKE."
-
-     return text2
-
- def deepfakes_image_predict(input_image):
-     face = preprocess_img(input_image)
-     print(f"Face shape is: {face.shape}")
-     img_grads = img_model.forward(face)
-     img_grads = img_grads.cpu().detach().numpy()
-     img_grads_np = np.squeeze(img_grads)
-
-     if img_grads_np[0] > 0.5:
-         preds = round(img_grads_np[0] * 100, 3)
-         text2 = f"The image is REAL. \nConfidence score is: {preds}"
-
-     else:
-         preds = round(img_grads_np[1] * 100, 3)
-         text2 = f"The image is FAKE. \nConfidence score is: {preds}"
-
-     return text2
-
-
- def preprocess_video(input_video, n_frames = 3):
-     v_cap = cv2.VideoCapture(input_video)
-     v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-     # Pick 'n_frames' evenly spaced frames to sample
-     if n_frames is None:
-         sample = np.arange(0, v_len)
-     else:
-         sample = np.linspace(0, v_len - 1, n_frames).astype(int)
-
-     #Loop through frames.
      frames = []
-     for j in range(v_len):
-         success = v_cap.grab()
-         if j in sample:
-             # Load frame
-             success, frame = v_cap.retrieve()
              if not success:
                  continue
              frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
              frame = preprocess_img(frame)
              frames.append(frame)
-     v_cap.release()
      return frames


  def deepfakes_video_predict(input_video):
-     '''Perform inference on a video.'''
-     video_frames = preprocess_video(input_video)
-     real_faces_list = []
-     fake_faces_list = []
-
-     for face in video_frames:
-         # face = preprocess_img(face)
-
-         img_grads = img_model.forward(face)
-         img_grads = img_grads.cpu().detach().numpy()
-         img_grads_np = np.squeeze(img_grads)
-         real_faces_list.append(img_grads_np[0])
-         fake_faces_list.append(img_grads_np[1])
-
-     real_faces_mean = np.mean(real_faces_list)
-     fake_faces_mean = np.mean(fake_faces_list)
-
-     if real_faces_mean > 0.5:
-         preds = round(real_faces_mean * 100, 3)
-         text2 = f"The video is REAL. \nConfidence score is: {preds}%"
-
      else:
-         preds = round(fake_faces_mean * 100, 3)
-         text2 = f"The video is FAKE. \nConfidence score is: {preds}%"
-
-     return text2

  import os
  import cv2
  import torch
  import numpy as np
+ from onnx2pytorch import ConvertModel
  from models.TMC import ETMC
  from models import image
+ import onnx

+ # -------------------
+ # Load ONNX -> PyTorch model for image modality
+ # -------------------
+ onnx_model_path = 'checkpoints/efficientnet.onnx'
+ onnx_model = onnx.load(onnx_model_path)
+ img_model = ConvertModel(onnx_model)
+ img_model.eval()
+
+ # -------------------
+ # Set random seed for reproducibility
+ # -------------------
  torch.manual_seed(42)

+ # -------------------
+ # Audio model configuration
+ # -------------------
  audio_args = {
      'nb_samp': 64600,
      'first_conv': 1024,

      'nb_classes': 2
  }

+ # -------------------
+ # Load Audio Model
+ # -------------------
+ def load_audio_model():
+     spec_model = image.RawNet(audio_args)
+     ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
+     spec_model.load_state_dict(ckpt['spec_encoder'], strict=True)
+     spec_model.eval()
+     return spec_model
+
+ spec_model = load_audio_model()
+
+ # -------------------
+ # Preprocessing Functions
+ # -------------------
  def preprocess_img(face):
+     face = face / 255.0
      face = cv2.resize(face, (256, 256))
+     face_tensor = torch.unsqueeze(torch.Tensor(face), dim=0)
+     return face_tensor

  def preprocess_audio(audio_file):
+     audio_tensor = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
+     return audio_tensor
+
+ def preprocess_video(input_video, n_frames=3):
+     cap = cv2.VideoCapture(input_video)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     sample = np.linspace(0, total_frames-1, n_frames).astype(int)
+
      frames = []
+     for i in range(total_frames):
+         success = cap.grab()
+         if i in sample:
+             success, frame = cap.retrieve()
              if not success:
                  continue
              frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
              frame = preprocess_img(frame)
              frames.append(frame)
+     cap.release()
      return frames

+ # -------------------
+ # Prediction Functions
+ # -------------------
+ def deepfakes_image_predict(input_image):
+     face = preprocess_img(input_image)
+     with torch.no_grad():
+         preds = img_model.forward(face).cpu().numpy().squeeze()
+
+     if preds[0] > 0.5:
+         score = round(preds[0] * 100, 3)
+         return f"The image is REAL. Confidence score: {score}%"
+     else:
+         score = round(preds[1] * 100, 3)
+         return f"The image is FAKE. Confidence score: {score}%"

  def deepfakes_video_predict(input_video):
+     frames = preprocess_video(input_video)
+     real_scores, fake_scores = [], []
+
+     with torch.no_grad():
+         for frame in frames:
+             preds = img_model.forward(frame).cpu().numpy().squeeze()
+             real_scores.append(preds[0])
+             fake_scores.append(preds[1])
+
+     real_mean = np.mean(real_scores)
+     fake_mean = np.mean(fake_scores)
+
+     if real_mean > 0.5:
+         return f"The video is REAL. Confidence score: {round(real_mean*100, 3)}%"
      else:
+         return f"The video is FAKE. Confidence score: {round(fake_mean*100, 3)}%"

+ def deepfakes_spec_predict(input_audio):
+     audio_tensor = preprocess_audio(input_audio)
+     with torch.no_grad():
+         preds = spec_model.forward(audio_tensor).cpu().numpy().squeeze()
+
+     if preds[0] > 0.5:
+         return "The audio is REAL."
+     else:
+         return "The audio is FAKE."