truens66 committed on
Commit
8c6a131
·
verified ·
1 Parent(s): 35709e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -171
app.py CHANGED
@@ -124,194 +124,103 @@ import numpy as np
124
  import mediapipe as mp
125
  from torchvision import models, transforms
126
  from tempfile import NamedTemporaryFile
127
- from pathlib import Path
128
- import logging
129
- from typing import Tuple, Optional
130
-
131
- # Set up logging
132
- logging.basicConfig(level=logging.INFO)
133
- logger = logging.getLogger(__name__)
134
-
135
- class DeepfakeDetector:
136
- def __init__(self, detection_confidence: float = 0.5, max_faces: int = 1):
137
- """Initialize the DeepfakeDetector with MediaPipe and ResNet model."""
138
- self.mp_face_detection = mp.solutions.face_detection
139
- self.mp_face_mesh = mp.solutions.face_mesh
140
-
141
- # Initialize face detection and mesh
142
- self.face_detection = self.mp_face_detection.FaceDetection(
143
- model_selection=1,
144
- min_detection_confidence=detection_confidence
145
- )
146
- self.face_mesh = self.mp_face_mesh.FaceMesh(
147
- static_image_mode=False,
148
- max_num_faces=max_faces,
149
- min_detection_confidence=detection_confidence
150
- )
151
-
152
- # Initialize model and transform
153
- self.model = self._create_model()
154
- self.transform = self._create_transform()
155
-
156
- @staticmethod
157
- def _create_model() -> torch.nn.Module:
158
- """Create and configure the ResNet model."""
159
- model = models.resnet34(weights=None)
160
- model.fc = torch.nn.Linear(model.fc.in_features, 2)
161
- model.eval()
162
- return model
163
-
164
- @staticmethod
165
- def _create_transform() -> transforms.Compose:
166
- """Create the image transformation pipeline."""
167
- return transforms.Compose([
168
- transforms.ToPILImage(),
169
- transforms.Resize((224, 224)),
170
- transforms.ToTensor(),
171
- transforms.Normalize(
172
- mean=[0.485, 0.456, 0.406],
173
- std=[0.229, 0.224, 0.225]
174
- )
175
- ])
176
-
177
- def get_face_bbox(self, landmarks, frame_shape: Tuple[int, int]) -> Tuple[int, int, int, int]:
178
- """Extract face bounding box from landmarks."""
179
- h, w = frame_shape[:2]
180
- xs = [lm.x * w for lm in landmarks.landmark]
181
- ys = [lm.y * h for lm in landmarks.landmark]
182
- return (
183
- max(0, int(min(xs))),
184
- max(0, int(min(ys))),
185
- min(w, int(max(xs))),
186
- min(h, int(max(ys)))
187
- )
188
-
189
- def process_frame(self, frame: np.ndarray) -> np.ndarray:
190
- """Process a single frame to detect deepfakes."""
191
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
192
 
193
- # Detect faces
194
- detection_results = self.face_detection.process(rgb_frame)
195
- if not detection_results.detections:
196
- return frame
197
-
198
- # Process each detected face
199
- for detection in detection_results.detections:
200
- mesh_results = self.face_mesh.process(rgb_frame)
201
- if not mesh_results.multi_face_landmarks:
202
- continue
203
-
204
- for face_landmarks in mesh_results.multi_face_landmarks:
205
- frame = self._analyze_face(frame, rgb_frame, face_landmarks)
206
-
207
- return frame
208
-
209
- def _analyze_face(self, frame: np.ndarray, rgb_frame: np.ndarray,
210
- face_landmarks) -> np.ndarray:
211
- """Analyze a single face and draw results on frame."""
212
- # Get face bbox
213
- x_min, y_min, x_max, y_max = self.get_face_bbox(
214
- face_landmarks, frame.shape
215
- )
216
-
217
- # Crop and transform face
218
- face_crop = rgb_frame[y_min:y_max, x_min:x_max]
219
- if face_crop.size == 0:
220
- return frame
221
-
222
- # Run inference
223
- try:
224
- face_tensor = self.transform(face_crop).unsqueeze(0)
225
- with torch.no_grad():
226
- output = torch.softmax(self.model(face_tensor), dim=1)
227
- fake_confidence = output[0, 1].item() * 100
228
- except Exception as e:
229
- logger.error(f"Error during inference: {str(e)}")
230
- return frame
231
-
232
- # Draw results
233
- label = "Fake" if fake_confidence > 50 else "Real"
234
- color = (0, 0, 255) if label == "Fake" else (0, 255, 0)
235
- label_text = f"{label} ({fake_confidence:.2f}%)"
236
-
237
- cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)
238
- cv2.putText(frame, label_text, (x_min, y_min - 10),
239
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
240
-
241
- return frame
242
-
243
- def process_video(self, video_path: str) -> Optional[str]:
244
- """Process a video file and return path to processed video."""
245
- try:
246
- cap = cv2.VideoCapture(video_path)
247
- if not cap.isOpened():
248
- logger.error("Error opening video file")
249
- return None
250
-
251
- # Get video properties
252
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
253
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
254
- fps = int(cap.get(cv2.CAP_PROP_FPS))
255
-
256
- # Set up output video
257
- output_path = str(Path(video_path).with_suffix('')) + "_processed.mp4"
258
- output_video = cv2.VideoWriter(
259
- output_path,
260
- cv2.VideoWriter_fourcc(*'mp4v'),
261
- fps,
262
- (width, height)
263
- )
264
-
265
- # Process frames
266
- while cap.isOpened():
267
- ret, frame = cap.read()
268
- if not ret:
269
- break
270
-
271
- processed_frame = self.process_frame(frame)
272
- output_video.write(processed_frame)
273
-
274
- # Clean up
275
- cap.release()
276
- output_video.release()
277
-
278
- return output_path
279
-
280
- except Exception as e:
281
- logger.error(f"Error processing video: {str(e)}")
282
- return None
283
 
284
  def gradio_interface(video_file):
285
- """Gradio interface function."""
286
  if video_file is None:
287
  return "Error: No video uploaded."
288
-
289
- detector = DeepfakeDetector()
290
-
291
  with NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
292
  temp_file_path = temp_file.name
293
  with open(video_file, "rb") as uploaded_file:
294
  temp_file.write(uploaded_file.read())
295
-
296
- output_path = detector.process_video(temp_file_path)
297
- if output_path is None:
298
- return "Error processing video"
299
-
300
  return output_path
301
 
302
- # Create Gradio interface
303
  iface = gr.Interface(
304
  fn=gradio_interface,
305
  inputs=gr.Video(label="Upload Video"),
306
  outputs=gr.Video(label="Processed Video"),
307
  title="Deepfake Detection",
308
- description="Upload a video to detect deepfakes",
309
- examples=[], # Add example videos here if available
310
  )
311
 
312
  if __name__ == "__main__":
313
- iface.launch(
314
- server_name="0.0.0.0",
315
- share=True, # Set to True to create a public link
316
- debug=True
317
- )
 
124
  import mediapipe as mp
125
  from torchvision import models, transforms
126
  from tempfile import NamedTemporaryFile
127
+
128
# MediaPipe solution modules used to locate faces in each video frame.
mp_face_detection = mp.solutions.face_detection
mp_face_mesh = mp.solutions.face_mesh

# Shared, module-level detector instance reused for every processed frame.
face_detection = mp_face_detection.FaceDetection(
    model_selection=1,
    min_detection_confidence=0.5,
)

# The mesh supplies the landmarks from which the face bounding box is derived;
# it is configured to report at most one face per frame.
face_mesh = mp_face_mesh.FaceMesh(
    static_image_mode=False,
    max_num_faces=1,
    min_detection_confidence=0.5,
)
133
+
134
+ # Initialize ResNet-34 model with random weights
135
def create_model():
    """Build a ResNet-34 whose final layer is a 2-class linear head.

    No trained weights are loaded: the backbone and the new head are both
    randomly initialized, so the scores this model produces are placeholders
    until a trained checkpoint is loaded into it.

    Returns:
        The ``torch.nn.Module``, left in its default (training) mode; the
        caller is responsible for calling ``.eval()``.
    """
    # `weights=None` is the current torchvision spelling of the deprecated
    # `pretrained=False` and yields the same randomly initialized network.
    model = models.resnet34(weights=None)
    # Replace the stock classification head with a 2-output layer
    # (index 1 is read as the "fake" score by the video-processing code).
    model.fc = torch.nn.Linear(model.fc.in_features, 2)
    return model
139
+
140
# Build the classifier once at import time and switch it to inference mode
# (Module.eval() returns the module itself, so the call can be chained).
model = create_model().eval()
142
+
143
# Preprocessing applied to each face crop before it is fed to the model:
# array -> PIL image -> 224x224 resize -> tensor -> per-channel normalization.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
150
+
151
def get_face_bbox(landmarks, frame_shape):
    """Return the pixel bounding box enclosing a set of face landmarks.

    Args:
        landmarks: Object exposing a ``landmark`` iterable whose items carry
            ``x`` and ``y`` attributes expressed as fractions of the frame
            width and height.
        frame_shape: Shape of the frame; only the first two entries
            (height, width) are used.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as ints, each clamped to the frame
        so the values are always safe to use as slice bounds.

    Raises:
        ValueError: If ``landmarks.landmark`` is empty.
    """
    h, w = frame_shape[:2]
    xs = [lm.x * w for lm in landmarks.landmark]
    ys = [lm.y * h for lm in landmarks.landmark]

    def _clamp(value, upper):
        # Landmarks may lie outside the image; an unclamped negative value
        # would wrap around when used as an array index.
        return max(0, min(upper, int(value)))

    return (
        _clamp(min(xs), w),
        _clamp(min(ys), h),
        _clamp(max(xs), w),
        _clamp(max(ys), h),
    )
156
+
157
# Frame rate used when the container does not report a usable one.
_FALLBACK_FPS = 30.0


def _annotate_face(frame, rgb_frame, face_landmarks):
    """Classify one face and draw its box and label onto ``frame`` in place."""
    frame_h, frame_w = frame.shape[:2]
    x_min, y_min, x_max, y_max = get_face_bbox(face_landmarks, frame.shape)

    # Landmarks can fall outside the image. Clamp here as well so that a
    # negative coordinate never wraps around when slicing the frame.
    x_min, y_min = max(0, x_min), max(0, y_min)
    x_max, y_max = min(frame_w, x_max), min(frame_h, y_max)
    if x_max <= x_min or y_max <= y_min:
        return  # nothing of the face is inside the frame

    face_crop = rgb_frame[y_min:y_max, x_min:x_max]
    face_tensor = transform(face_crop).unsqueeze(0)
    with torch.no_grad():
        output = torch.softmax(model(face_tensor), dim=1)
    # Output index 1 is read as the "fake" class probability.
    fake_confidence = output[0, 1].item() * 100

    label = "Fake" if fake_confidence > 50 else "Real"
    color = (0, 0, 255) if label == "Fake" else (0, 255, 0)
    label_text = f"{label} ({fake_confidence:.2f}%)"

    cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)
    # Keep the caption inside the image when the face touches the top edge.
    cv2.putText(frame, label_text, (x_min, max(y_min - 10, 15)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)


def process_video(video_path: str):
    """Annotate every face in a video with a Real/Fake label.

    Each frame is scanned with the module-level MediaPipe detectors; every
    face found is cropped, scored by the module-level ``model`` and outlined
    with its label. The annotated frames are written to a new MP4 file next
    to the input.

    Args:
        video_path: Path of the video file to read.

    Returns:
        Path of the annotated video, ``<input stem>_processed.mp4`` in the
        same directory as ``video_path``.

    Raises:
        RuntimeError: If the input cannot be opened or the output cannot be
            created.
    """
    # Local import so this function does not depend on the file's import block.
    from pathlib import Path

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep the fractional rate (e.g. 29.97) instead of truncating it, and
    # fall back to a sane default when the container reports 0 or NaN.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = _FALLBACK_FPS

    # Build the name from the stem so the output can never equal the input
    # and directory names containing ".mp4" are left untouched.
    source = Path(video_path)
    output_path = str(source.with_name(f"{source.stem}_processed.mp4"))
    output_video = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
    )
    if not output_video.isOpened():
        cap.release()
        output_video.release()
        raise RuntimeError(f"Could not create output video: {output_path}")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Run the mesh once per frame, and only when the cheaper detector
            # saw a face; re-running it per detection repeats the same work.
            if face_detection.process(rgb_frame).detections:
                mesh_results = face_mesh.process(rgb_frame)
                for face_landmarks in mesh_results.multi_face_landmarks or ():
                    _annotate_face(frame, rgb_frame, face_landmarks)

            output_video.write(frame)
    finally:
        # Release both handles even if a frame fails mid-way.
        cap.release()
        output_video.release()

    return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
def gradio_interface(video_file):
    """Gradio callback: run deepfake annotation on an uploaded video.

    Args:
        video_file: Filesystem path of the uploaded video, or ``None`` when
            nothing was uploaded.

    Returns:
        Path of the annotated video produced by ``process_video``.

    Raises:
        gr.Error: If no video was uploaded. The output component is a video,
            so an error string cannot be returned as a value; raising lets
            Gradio show the message to the user instead.
    """
    # Local imports so this function does not depend on the file's import block.
    import os
    import shutil

    if video_file is None:
        raise gr.Error("Error: No video uploaded.")

    # Work on a private copy so the processed file is written next to it
    # rather than next to the upload.
    with NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
        temp_file_path = temp_file.name
        with open(video_file, "rb") as uploaded_file:
            # Stream in chunks instead of loading the whole video in memory.
            shutil.copyfileobj(uploaded_file, temp_file)

    try:
        output_path = process_video(temp_file_path)
    finally:
        # The copy was created with delete=False, so remove it explicitly.
        # Cleanup is best effort and must not mask a processing error.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass

    return output_path
216
 
 
217
# Web UI: one video upload in, one annotated video out.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.Video(label="Processed Video"),
    title="Deepfake Detection",
    description="Upload a video to detect deepfakes using MediaPipe face detection and ResNet-34 model.",
)


if __name__ == "__main__":
    # share=True asks Gradio to create a public link for the app.
    iface.launch(share=True)