PraneshJs committed
Commit 9ce33d8 · verified · 1 Parent(s): c715f39

Update inference_2.py

Files changed (1)
  1. inference_2.py +74 -69
inference_2.py CHANGED
@@ -1,28 +1,24 @@
 import os
 import cv2
+import onnx
 import torch
 import numpy as np
+from types import SimpleNamespace
 from onnx2pytorch import ConvertModel
 from models.TMC import ETMC
 from models import image
-import onnx

-# -------------------
-# Load ONNX -> PyTorch model for image modality
-# -------------------
-onnx_model_path = 'checkpoints/efficientnet.onnx'
-onnx_model = onnx.load(onnx_model_path)
-img_model = ConvertModel(onnx_model)
-img_model.eval()
+# -----------------------------
+# Load ONNX -> PyTorch safely
+# -----------------------------
+onnx_model = onnx.load('checkpoints/efficientnet.onnx')
+pytorch_model = ConvertModel(onnx_model, strict=False)

-# -------------------
-# Set random seed for reproducibility
-# -------------------
 torch.manual_seed(42)

-# -------------------
-# Audio model configuration
-# -------------------
+# -----------------------------
+# Audio model arguments
+# -----------------------------
 audio_args = {
     'nb_samp': 64600,
     'first_conv': 1024,
@@ -32,14 +28,17 @@ audio_args = {
     'nb_fc_node': 1024,
     'gru_node': 1024,
     'nb_gru_layer': 3,
-    'nb_classes': 2
+    'nb_classes': 2,
+    'device': 'cpu'
 }

-# -------------------
+audio_args_obj = SimpleNamespace(**audio_args)
+
+# -----------------------------
 # Load Audio Model
-# -------------------
+# -----------------------------
 def load_audio_model():
-    spec_model = image.RawNet(audio_args)
+    spec_model = image.RawNet(audio_args_obj)
     ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
     spec_model.load_state_dict(ckpt['spec_encoder'], strict=True)
     spec_model.eval()
@@ -47,76 +46,82 @@ def load_audio_model():

 spec_model = load_audio_model()

-# -------------------
-# Preprocessing Functions
-# -------------------
+# -----------------------------
+# Load Image Model
+# -----------------------------
+def load_image_model():
+    rgb_encoder = pytorch_model
+    ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
+    rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict=True)
+    rgb_encoder.eval()
+    return rgb_encoder
+
+img_model = load_image_model()
+
+# -----------------------------
+# Preprocessing functions
+# -----------------------------
 def preprocess_img(face):
     face = face / 255.0
     face = cv2.resize(face, (256, 256))
-    face_tensor = torch.unsqueeze(torch.Tensor(face), dim=0)
-    return face_tensor
+    face_pt = torch.unsqueeze(torch.Tensor(face), dim=0)
+    return face_pt

 def preprocess_audio(audio_file):
-    audio_tensor = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
-    return audio_tensor
+    audio_pt = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
+    return audio_pt

 def preprocess_video(input_video, n_frames=3):
-    cap = cv2.VideoCapture(input_video)
-    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    sample = np.linspace(0, total_frames-1, n_frames).astype(int)
-
+    v_cap = cv2.VideoCapture(input_video)
+    v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    sample = np.linspace(0, v_len-1, n_frames).astype(int)
     frames = []
-    for i in range(total_frames):
-        success = cap.grab()
-        if i in sample:
-            success, frame = cap.retrieve()
+    for j in range(v_len):
+        success = v_cap.grab()
+        if j in sample:
+            success, frame = v_cap.retrieve()
             if not success:
                 continue
             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             frame = preprocess_img(frame)
             frames.append(frame)
-    cap.release()
+    v_cap.release()
     return frames

-# -------------------
-# Prediction Functions
-# -------------------
+# -----------------------------
+# Inference functions
+# -----------------------------
+def deepfakes_spec_predict(input_audio):
+    audio = preprocess_audio(input_audio)
+    spec_grads = spec_model.forward(audio)
+    spec_grads_np = np.exp(spec_grads.cpu().detach().numpy().squeeze())
+    max_value = np.argmax(spec_grads_np)
+    if max_value > 0.5:
+        text2 = f"The audio is REAL."
+    else:
+        text2 = f"The audio is FAKE."
+    return text2
+
 def deepfakes_image_predict(input_image):
     face = preprocess_img(input_image)
-    with torch.no_grad():
-        preds = img_model.forward(face).cpu().numpy().squeeze()
-
-    if preds[0] > 0.5:
-        score = round(preds[0] * 100, 3)
-        return f"The image is REAL. Confidence score: {score}%"
+    img_grads = img_model.forward(face).cpu().detach().numpy().squeeze()
+    if img_grads[0] > 0.5:
+        text2 = f"The image is REAL. Confidence: {img_grads[0]*100:.3f}%"
     else:
-        score = round(preds[1] * 100, 3)
-        return f"The image is FAKE. Confidence score: {score}%"
+        text2 = f"The image is FAKE. Confidence: {img_grads[1]*100:.3f}%"
+    return text2

 def deepfakes_video_predict(input_video):
-    frames = preprocess_video(input_video)
-    real_scores, fake_scores = [], []
-
-    with torch.no_grad():
-        for frame in frames:
-            preds = img_model.forward(frame).cpu().numpy().squeeze()
-            real_scores.append(preds[0])
-            fake_scores.append(preds[1])
-
-    real_mean = np.mean(real_scores)
-    fake_mean = np.mean(fake_scores)
-
+    video_frames = preprocess_video(input_video)
+    real_list, fake_list = [], []
+    for face in video_frames:
+        img_grads = img_model.forward(face).cpu().detach().numpy().squeeze()
+        real_list.append(img_grads[0])
+        fake_list.append(img_grads[1])
+    real_mean = np.mean(real_list)
+    fake_mean = np.mean(fake_list)
     if real_mean > 0.5:
-        return f"The video is REAL. Confidence score: {round(real_mean*100, 3)}%"
-    else:
-        return f"The video is FAKE. Confidence score: {round(fake_mean*100, 3)}%"
-
-def deepfakes_spec_predict(input_audio):
-    audio_tensor = preprocess_audio(input_audio)
-    with torch.no_grad():
-        preds = spec_model.forward(audio_tensor).cpu().numpy().squeeze()
-
-    if preds[0] > 0.5:
-        return "The audio is REAL."
+        text2 = f"The video is REAL. Confidence: {real_mean*100:.3f}%"
     else:
-        return "The audio is FAKE."
+        text2 = f"The video is FAKE. Confidence: {fake_mean*100:.3f}%"
+    return text2
 
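For orientation, a minimal sketch of how the three updated helpers might be driven end to end. The file paths and the noise stand-in for audio are hypothetical, and it assumes inference_2.py imports cleanly with the checkpoints in place; the Space's actual entry point may differ.

import cv2
import numpy as np
from inference_2 import (deepfakes_image_predict,
                         deepfakes_video_predict,
                         deepfakes_spec_predict)

# Image: OpenCV reads BGR, so convert to RGB first,
# matching what preprocess_video does for sampled frames.
img = cv2.cvtColor(cv2.imread('sample.jpg'), cv2.COLOR_BGR2RGB)  # placeholder path
print(deepfakes_image_predict(img))

# Video: frame sampling and per-frame scoring happen inside the helper.
print(deepfakes_video_predict('sample.mp4'))  # placeholder path

# Audio: RawNet is configured for nb_samp = 64600 samples;
# random noise is only a shape-correct placeholder.
audio = np.random.randn(64600).astype(np.float32)
print(deepfakes_spec_predict(audio))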