truens66 committed
Commit 9797501 (verified)
Parent(s): fa179db

Update app.py

Files changed (1):
  app.py +171 -80
app.py CHANGED
@@ -124,103 +124,194 @@ import numpy as np
 import mediapipe as mp
 from torchvision import models, transforms
 from tempfile import NamedTemporaryFile
-
-# Initialize MediaPipe Face Detection and Face Mesh
-mp_face_detection = mp.solutions.face_detection
-mp_face_mesh = mp.solutions.face_mesh
-face_detection = mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)
-face_mesh = mp_face_mesh.FaceMesh(static_image_mode=False, max_num_faces=1, min_detection_confidence=0.5)
-
-# Initialize ResNet-34 model with random weights
-def create_model():
-    model = models.resnet34(pretrained=False)
-    model.fc = torch.nn.Linear(model.fc.in_features, 2)
-    return model
-
-model = create_model()
-model.eval()
-
-# Define transformation for face images
-transform = transforms.Compose([
-    transforms.ToPILImage(),
-    transforms.Resize((224, 224)),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-])
-
-def get_face_bbox(landmarks, frame_shape):
-    h, w = frame_shape[:2]
-    xs = [lm.x * w for lm in landmarks.landmark]
-    ys = [lm.y * h for lm in landmarks.landmark]
-    return int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))
-
-def process_video(video_path: str):
-    cap = cv2.VideoCapture(video_path)
-    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    fps = int(cap.get(cv2.CAP_PROP_FPS))
-
-    output_path = video_path.replace(".mp4", "_processed.mp4")
-    output_video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
-
-    while cap.isOpened():
-        ret, frame = cap.read()
-        if not ret:
-            break
-
+from pathlib import Path
+import logging
+from typing import Tuple, Optional
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class DeepfakeDetector:
+    def __init__(self, detection_confidence: float = 0.5, max_faces: int = 1):
+        """Initialize the DeepfakeDetector with MediaPipe and ResNet model."""
+        self.mp_face_detection = mp.solutions.face_detection
+        self.mp_face_mesh = mp.solutions.face_mesh
+
+        # Initialize face detection and mesh
+        self.face_detection = self.mp_face_detection.FaceDetection(
+            model_selection=1,
+            min_detection_confidence=detection_confidence
+        )
+        self.face_mesh = self.mp_face_mesh.FaceMesh(
+            static_image_mode=False,
+            max_num_faces=max_faces,
+            min_detection_confidence=detection_confidence
+        )
+
+        # Initialize model and transform
+        self.model = self._create_model()
+        self.transform = self._create_transform()
+
+    @staticmethod
+    def _create_model() -> torch.nn.Module:
+        """Create and configure the ResNet model."""
+        model = models.resnet34(weights=None)
+        model.fc = torch.nn.Linear(model.fc.in_features, 2)
+        model.eval()
+        return model
+
+    @staticmethod
+    def _create_transform() -> transforms.Compose:
+        """Create the image transformation pipeline."""
+        return transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406],
+                std=[0.229, 0.224, 0.225]
+            )
+        ])
+
+    def get_face_bbox(self, landmarks, frame_shape: Tuple[int, int]) -> Tuple[int, int, int, int]:
+        """Extract face bounding box from landmarks."""
+        h, w = frame_shape[:2]
+        xs = [lm.x * w for lm in landmarks.landmark]
+        ys = [lm.y * h for lm in landmarks.landmark]
+        return (
+            max(0, int(min(xs))),
+            max(0, int(min(ys))),
+            min(w, int(max(xs))),
+            min(h, int(max(ys)))
+        )
+
+    def process_frame(self, frame: np.ndarray) -> np.ndarray:
+        """Process a single frame to detect deepfakes."""
         rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
-        # Face detection
-        results = face_detection.process(rgb_frame)
-        if results.detections:
-            for detection in results.detections:
-                # Get face landmarks
-                mesh_results = face_mesh.process(rgb_frame)
-                if mesh_results.multi_face_landmarks:
-                    for face_landmarks in mesh_results.multi_face_landmarks:
-                        x_min, y_min, x_max, y_max = get_face_bbox(face_landmarks, frame.shape)
-
-                        face_crop = rgb_frame[y_min:y_max, x_min:x_max]
-                        if face_crop.size == 0:
-                            continue
-
-                        face_tensor = transform(face_crop).unsqueeze(0)
-                        with torch.no_grad():
-                            output = torch.softmax(model(face_tensor), dim=1)
-                            fake_confidence = output[0, 1].item() * 100
-                        label = "Fake" if fake_confidence > 50 else "Real"
-                        color = (0, 0, 255) if label == "Fake" else (0, 255, 0)
-                        label_text = f"{label} ({fake_confidence:.2f}%)"
-
-                        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)
-                        cv2.putText(frame, label_text, (x_min, y_min - 10),
-                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
-
-        output_video.write(frame)
-
-    cap.release()
-    output_video.release()
-    return output_path
+        # Detect faces
+        detection_results = self.face_detection.process(rgb_frame)
+        if not detection_results.detections:
+            return frame
+
+        # Process each detected face
+        for detection in detection_results.detections:
+            mesh_results = self.face_mesh.process(rgb_frame)
+            if not mesh_results.multi_face_landmarks:
+                continue
+
+            for face_landmarks in mesh_results.multi_face_landmarks:
+                frame = self._analyze_face(frame, rgb_frame, face_landmarks)
+
+        return frame
+
+    def _analyze_face(self, frame: np.ndarray, rgb_frame: np.ndarray,
+                      face_landmarks) -> np.ndarray:
+        """Analyze a single face and draw results on frame."""
+        # Get face bbox
+        x_min, y_min, x_max, y_max = self.get_face_bbox(
+            face_landmarks, frame.shape
+        )
+
+        # Crop and transform face
+        face_crop = rgb_frame[y_min:y_max, x_min:x_max]
+        if face_crop.size == 0:
+            return frame
+
+        # Run inference
+        try:
+            face_tensor = self.transform(face_crop).unsqueeze(0)
+            with torch.no_grad():
+                output = torch.softmax(self.model(face_tensor), dim=1)
+                fake_confidence = output[0, 1].item() * 100
+        except Exception as e:
+            logger.error(f"Error during inference: {str(e)}")
+            return frame
+
+        # Draw results
+        label = "Fake" if fake_confidence > 50 else "Real"
+        color = (0, 0, 255) if label == "Fake" else (0, 255, 0)
+        label_text = f"{label} ({fake_confidence:.2f}%)"
+
+        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)
+        cv2.putText(frame, label_text, (x_min, y_min - 10),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
+
+        return frame
+
+    def process_video(self, video_path: str) -> Optional[str]:
+        """Process a video file and return path to processed video."""
+        try:
+            cap = cv2.VideoCapture(video_path)
+            if not cap.isOpened():
+                logger.error("Error opening video file")
+                return None
+
+            # Get video properties
+            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+            fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+            # Set up output video
+            output_path = str(Path(video_path).with_suffix('')) + "_processed.mp4"
+            output_video = cv2.VideoWriter(
+                output_path,
+                cv2.VideoWriter_fourcc(*'mp4v'),
+                fps,
+                (width, height)
+            )
+
+            # Process frames
+            while cap.isOpened():
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                processed_frame = self.process_frame(frame)
+                output_video.write(processed_frame)
+
+            # Clean up
+            cap.release()
+            output_video.release()
+
+            return output_path
+
+        except Exception as e:
+            logger.error(f"Error processing video: {str(e)}")
+            return None
 
 def gradio_interface(video_file):
+    """Gradio interface function."""
     if video_file is None:
         return "Error: No video uploaded."
-
+
+    detector = DeepfakeDetector()
+
     with NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
         temp_file_path = temp_file.name
         with open(video_file, "rb") as uploaded_file:
             temp_file.write(uploaded_file.read())
-
-    output_path = process_video(temp_file_path)
+
+    output_path = detector.process_video(temp_file_path)
+    if output_path is None:
+        return "Error processing video"
+
     return output_path
 
+# Create Gradio interface
 iface = gr.Interface(
     fn=gradio_interface,
     inputs=gr.Video(label="Upload Video"),
     outputs=gr.Video(label="Processed Video"),
     title="Deepfake Detection",
-    description="Upload a video to detect deepfakes using MediaPipe face detection and ResNet-34 model."
+    description="Upload a video to detect deepfakes using MediaPipe face detection and ResNet-34 model.",
+    examples=[], # Add example videos here if available
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch(
+        server_name="0.0.0.0",
+        share=True, # Set to True to create a public link
+        debug=True
+    )
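
For reference, a minimal sketch of how the `DeepfakeDetector` class introduced by this commit could be driven outside the Gradio UI. The module name `app` and the input path `sample.mp4` are illustrative assumptions, not part of the commit; note also that `_create_model()` builds the ResNet-34 with `weights=None` and loads no checkpoint, so the Fake/Real scores come from an untrained classifier.

# Usage sketch; the `app` module name and sample.mp4 path are assumptions
from app import DeepfakeDetector

detector = DeepfakeDetector(detection_confidence=0.5, max_faces=1)
result = detector.process_video("sample.mp4")  # returns None on failure
if result is None:
    print("Processing failed; check the logged errors")
else:
    print(f"Annotated video written to {result}")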