PraneshJs committed
Commit 064f97d · verified · 1 Parent(s): 996bc38

Update inference_2.py

Files changed (1)
  1. inference_2.py +167 -63
inference_2.py CHANGED
@@ -1,15 +1,23 @@
 import os
 import cv2
+import onnx
 import torch
+import argparse
 import numpy as np
-from onnx import load as onnx_load
+import torch.nn as nn
+from models.TMC import ETMC
+from models import image
+
 from onnx2pytorch import ConvertModel
-from models import image # Your RawNet audio model
 
-# Set seed for reproducibility
+onnx_model = onnx.load('checkpoints/efficientnet.onnx')
+pytorch_model = ConvertModel(onnx_model)
+
+#Set random seed for reproducibility.
 torch.manual_seed(42)
 
-# Audio args for RawNet
+
+# Define the audio_args dictionary
 audio_args = {
     'nb_samp': 64600,
     'first_conv': 1024,
@@ -19,48 +27,155 @@ audio_args = {
     'nb_fc_node': 1024,
     'gru_node': 1024,
     'nb_gru_layer': 3,
-    'nb_classes': 2,
-    'device': 'cpu',
-    'pretrained_audio_encoder': False
+    'nb_classes': 2
 }
 
-# Convert audio_args dict to a namespace object
-from types import SimpleNamespace
-audio_args_obj = SimpleNamespace(**audio_args)
 
-# Load ONNX → PyTorch model for images
-onnx_model = onnx_load("checkpoints/efficientnet.onnx")
-img_model = ConvertModel(onnx_model) # do NOT use strict=True (not supported)
+def get_args(parser):
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--data_dir", type=str, default="datasets/train/fakeavceleb*")
+    parser.add_argument("--LOAD_SIZE", type=int, default=256)
+    parser.add_argument("--FINE_SIZE", type=int, default=224)
+    parser.add_argument("--dropout", type=float, default=0.2)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+    parser.add_argument("--hidden", nargs="*", type=int, default=[])
+    parser.add_argument("--hidden_sz", type=int, default=768)
+    parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"])
+    parser.add_argument("--img_hidden_sz", type=int, default=1024)
+    parser.add_argument("--include_bn", type=int, default=True)
+    parser.add_argument("--lr", type=float, default=1e-4)
+    parser.add_argument("--lr_factor", type=float, default=0.3)
+    parser.add_argument("--lr_patience", type=int, default=10)
+    parser.add_argument("--max_epochs", type=int, default=500)
+    parser.add_argument("--n_workers", type=int, default=12)
+    parser.add_argument("--name", type=str, default="MMDF")
+    parser.add_argument("--num_image_embeds", type=int, default=1)
+    parser.add_argument("--patience", type=int, default=20)
+    parser.add_argument("--savedir", type=str, default="./savepath/")
+    parser.add_argument("--seed", type=int, default=1)
+    parser.add_argument("--n_classes", type=int, default=2)
+    parser.add_argument("--annealing_epoch", type=int, default=10)
+    parser.add_argument("--device", type=str, default='cpu')
+    parser.add_argument("--pretrained_image_encoder", type=bool, default=False)
+    parser.add_argument("--freeze_image_encoder", type=bool, default=False)
+    parser.add_argument("--pretrained_audio_encoder", type=bool, default=False)
+    parser.add_argument("--freeze_audio_encoder", type=bool, default=False)
+    parser.add_argument("--augment_dataset", type=bool, default=True)
+
+    for key, value in audio_args.items():
+        parser.add_argument(f"--{key}", type=type(value), default=value)
+
+def model_summary(args):
+    '''Prints the model summary.'''
+    model = ETMC(args)
+
+    for name, layer in model.named_modules():
+        print(name, layer)
+
+def load_multimodal_model(args):
+    '''Load multimodal model'''
+    model = ETMC(args)
+    ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
+    model.load_state_dict(ckpt, strict=True)
+    model.eval()
+    return model
+
+def load_img_modality_model(args):
+    '''Loads image modality model.'''
+    rgb_encoder = pytorch_model
+
+    ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
+    rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict=True)
+    rgb_encoder.eval()
+    return rgb_encoder
+
+def load_spec_modality_model(args):
+    spec_encoder = image.RawNet(args)
+    ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
+    spec_encoder.load_state_dict(ckpt['spec_encoder'], strict=True)
+    spec_encoder.eval()
+    return spec_encoder
+
+
+#Load models.
+parser = argparse.ArgumentParser(description="Inference models")
+get_args(parser)
+args, remaining_args = parser.parse_known_args()
+assert remaining_args == [], remaining_args
+
+spec_model = load_spec_modality_model(args)
+
+img_model = load_img_modality_model(args)
 
-# Load Audio model
-spec_model = image.RawNet(audio_args_obj)
 
-# Ensure models are in eval mode
-img_model.eval()
-spec_model.eval()
-
-# -------------------------
-# Preprocessing functions
-# -------------------------
 def preprocess_img(face):
-    face = face / 255.0
+    face = face / 255
     face = cv2.resize(face, (256, 256))
-    face_tensor = torch.unsqueeze(torch.Tensor(face), dim=0)
-    return face_tensor
+    # face = face.transpose(2, 0, 1) #(W, H, C) -> (C, W, H)
+    face_pt = torch.unsqueeze(torch.Tensor(face), dim = 0)
+    return face_pt
 
 def preprocess_audio(audio_file):
-    audio_tensor = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
-    return audio_tensor
+    audio_pt = torch.unsqueeze(torch.Tensor(audio_file), dim = 0)
+    return audio_pt
+
+def deepfakes_spec_predict(input_audio):
+    x, _ = input_audio
+    audio = preprocess_audio(x)
+    spec_grads = spec_model.forward(audio)
+    spec_grads_inv = np.exp(spec_grads.cpu().detach().numpy().squeeze())
+
+    # multimodal_grads = multimodal.spec_depth[0].forward(spec_grads)
+
+    # out = nn.Softmax()(multimodal_grads)
+    # max = torch.argmax(out, dim = -1) #Index of the max value in the tensor.
+    # max_value = out[max] #Actual value of the tensor.
+    max_value = np.argmax(spec_grads_inv)
+
+    if max_value > 0.5:
+        preds = round(100 - (max_value*100), 3)
+        text2 = f"The audio is REAL."
+
+    else:
+        preds = round(max_value*100, 3)
+        text2 = f"The audio is FAKE."
+
+    return text2
+
+def deepfakes_image_predict(input_image):
+    face = preprocess_img(input_image)
+    print(f"Face shape is: {face.shape}")
+    img_grads = img_model.forward(face)
+    img_grads = img_grads.cpu().detach().numpy()
+    img_grads_np = np.squeeze(img_grads)
+
+    if img_grads_np[0] > 0.5:
+        preds = round(img_grads_np[0] * 100, 3)
+        text2 = f"The image is REAL. \nConfidence score is: {preds}"
+
+    else:
+        preds = round(img_grads_np[1] * 100, 3)
+        text2 = f"The image is FAKE. \nConfidence score is: {preds}"
+
+    return text2
+
 
-def preprocess_video(input_video, n_frames=3):
+def preprocess_video(input_video, n_frames = 3):
     v_cap = cv2.VideoCapture(input_video)
     v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    sample = np.linspace(0, v_len - 1, n_frames).astype(int)
-    frames = []
 
+    # Pick 'n_frames' evenly spaced frames to sample
+    if n_frames is None:
+        sample = np.arange(0, v_len)
+    else:
+        sample = np.linspace(0, v_len - 1, n_frames).astype(int)
+
+    #Loop through frames.
+    frames = []
     for j in range(v_len):
         success = v_cap.grab()
         if j in sample:
+            # Load frame
             success, frame = v_cap.retrieve()
             if not success:
                 continue
@@ -70,43 +185,32 @@ def preprocess_video(input_video, n_frames=3):
     v_cap.release()
     return frames
 
-# -------------------------
-# Prediction functions
-# -------------------------
-def deepfakes_spec_predict(input_audio):
-    audio_tensor = preprocess_audio(input_audio)
-    spec_grads = spec_model.forward(audio_tensor)
-    spec_grads_np = np.squeeze(spec_grads.cpu().detach().numpy())
 
-    if spec_grads_np[0] > 0.5:
-        return "The audio is REAL."
-    else:
-        return "The audio is FAKE."
-
-def deepfakes_image_predict(input_image):
-    face_tensor = preprocess_img(input_image)
-    img_grads = img_model.forward(face_tensor)
-    img_grads_np = np.squeeze(img_grads.cpu().detach().numpy())
+def deepfakes_video_predict(input_video):
+    '''Perform inference on a video.'''
+    video_frames = preprocess_video(input_video)
+    real_faces_list = []
+    fake_faces_list = []
 
-    if img_grads_np[0] > 0.5:
-        return f"The image is REAL. Confidence score: {round(img_grads_np[0]*100,2)}%"
-    else:
-        return f"The image is FAKE. Confidence score: {round(img_grads_np[1]*100,2)}%"
+    for face in video_frames:
+        # face = preprocess_img(face)
 
-def deepfakes_video_predict(input_video):
-    frames = preprocess_video(input_video)
-    real_list, fake_list = [], []
+        img_grads = img_model.forward(face)
+        img_grads = img_grads.cpu().detach().numpy()
+        img_grads_np = np.squeeze(img_grads)
+        real_faces_list.append(img_grads_np[0])
+        fake_faces_list.append(img_grads_np[1])
 
-    for frame in frames:
-        img_grads = img_model.forward(frame)
-        img_grads_np = np.squeeze(img_grads.cpu().detach().numpy())
-        real_list.append(img_grads_np[0])
-        fake_list.append(img_grads_np[1])
+    real_faces_mean = np.mean(real_faces_list)
+    fake_faces_mean = np.mean(fake_faces_list)
 
-    real_mean = np.mean(real_list)
-    fake_mean = np.mean(fake_list)
+    if real_faces_mean > 0.5:
+        preds = round(real_faces_mean * 100, 3)
+        text2 = f"The video is REAL. \nConfidence score is: {preds}%"
 
-    if real_mean > 0.5:
-        return f"The video is REAL. Confidence: {round(real_mean*100,2)}%"
     else:
-        return f"The video is FAKE. Confidence: {round(fake_mean*100,2)}%"
+        preds = round(fake_faces_mean * 100, 3)
+        text2 = f"The video is FAKE. \nConfidence score is: {preds}%"
+
+    return text2
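For reference, a brief hypothetical smoke test for the entry points this commit rewires (deepfakes_image_predict, deepfakes_video_predict, deepfakes_spec_predict). It is not part of the commit: the sample file paths, the librosa-based audio loading, and the 64600-sample crop (matching audio_args['nb_samp']) are assumptions for illustration only.

import cv2
import librosa  # assumed helper for loading audio; not imported by inference_2.py

# Image: a face crop read with OpenCV, since preprocess_img expects an HxWx3 array.
face = cv2.imread('sample_face.jpg')               # hypothetical path
print(deepfakes_image_predict(face))

# Video: the function takes a file path and samples 3 evenly spaced frames.
print(deepfakes_video_predict('sample_clip.mp4'))  # hypothetical path

# Audio: deepfakes_spec_predict unpacks a (waveform, sample_rate) pair.
waveform, sr = librosa.load('sample_audio.wav', sr=16000)  # hypothetical path
print(deepfakes_spec_predict((waveform[:64600], sr)))      # crop length assumed from nb_samp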