# Spaces: Runtime error  (page-status text that appears to have been captured along with this script; not program code)
import os
import tempfile

import cv2
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from training.detectors import DETECTOR
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def load_model(model_name, config_path, weights_path):
    """Build the detector registered under ``model_name`` and load its weights.

    The YAML file at ``config_path`` is parsed and its ``model_name`` entry is
    set to ``model_name``.  The matching class from ``DETECTOR`` is then
    instantiated with that config and moved to the module-level ``device``.
    Finally the checkpoint at ``weights_path`` is loaded as a strict state
    dict.

    Returns:
        The detector, switched to evaluation mode.
    """
    with open(config_path, 'r') as config_file:
        detector_config = yaml.safe_load(config_file)
    detector_config['model_name'] = model_name

    detector = DETECTOR[model_name](detector_config)
    detector = detector.to(device)

    # The checkpoint object is passed straight to load_state_dict, so with
    # strict=True every key has to match the model's parameters exactly.
    state_dict = torch.load(weights_path, map_location=device)
    detector.load_state_dict(state_dict, strict=True)
    detector.eval()
    return detector
def preprocess_video(video_path, output_dir, frame_num=32):
    """Sample up to ``frame_num`` evenly spaced frames from a video as PNGs.

    Frames are written to ``<output_dir>/frames`` (created if missing) as
    ``frame_<index>.png``, where ``<index>`` is the zero-padded position of
    the frame inside the video.

    Args:
        video_path: Path of the video file to read.
        output_dir: Directory under which the ``frames`` folder is created.
        frame_num: Number of sampling positions spread across the video.

    Returns:
        Paths of the frames that were decoded and written successfully, in
        ascending frame order.  The list is empty when the video cannot be
        opened or reports no frames.
    """
    os.makedirs(output_dir, exist_ok=True)
    frames_dir = os.path.join(output_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        # An unreadable or empty video has nothing to sample; bail out before
        # linspace is asked to span the range [0, -1].
        if not cap.isOpened():
            return frames
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return frames

        frame_indices = np.linspace(0, total_frames - 1, frame_num, dtype=int)
        # Videos shorter than frame_num make linspace repeat positions;
        # np.unique keeps each position once (sorted) so no frame is decoded,
        # written and returned more than once.
        for idx in np.unique(frame_indices).tolist():
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                continue
            frame_path = os.path.join(frames_dir, f"frame_{idx:04d}.png")
            # Only report frames that actually reached the disk.
            if cv2.imwrite(frame_path, frame):
                frames.append(frame_path)
    finally:
        # Release the capture handle even if decoding or writing raised.
        cap.release()
    return frames
def infer_video(video_path, model, device):
    """Classify a video as "Fake" or "Real" by averaging per-frame scores.

    Frames are sampled by ``preprocess_video`` into a temporary directory
    that is removed once inference finishes.  Each frame is resized to
    256x256, converted to a tensor, normalised with mean 0.5 / std 0.5 per
    channel and passed through the model as a batch of one.

    Args:
        video_path: Path of the video to analyse.
        model: Detector called as ``model(data_dict, inference=True)``; the
            returned mapping must hold the class logits under ``"cls"``.
        device: Torch device the input tensors are moved to.

    Returns:
        ``(prediction, avg_prob)`` where ``avg_prob`` is the mean softmax
        probability of class index 1 (treated as the "fake" class) over all
        frames, and ``prediction`` is ``"Fake"`` when that mean exceeds 0.5,
        otherwise ``"Real"``.

    Raises:
        ValueError: If no frame could be extracted from the video.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # The labels are placeholders supplied alongside the image in data_dict;
    # they are identical for every frame, so build them once.
    dummy_label = torch.tensor([0]).to(device)
    dummy_label_spe = torch.tensor([0]).to(device)

    probs = []
    # A private temporary directory avoids leaving extracted frames behind
    # in the working directory and keeps separate runs from colliding.
    with tempfile.TemporaryDirectory(prefix="temp_video_frames_") as output_dir:
        frames = preprocess_video(video_path, output_dir)
        if not frames:
            # Averaging an empty list yields NaN, which would otherwise be
            # reported as a "Real" verdict.
            raise ValueError(
                f"No frames could be extracted from '{video_path}'."
            )

        with torch.no_grad():
            for frame_path in frames:
                with Image.open(frame_path) as img:
                    frame = transform(img.convert("RGB")).unsqueeze(0).to(device)
                data_dict = {
                    "image": frame,
                    "label": dummy_label,  # Dummy label
                    "label_spe": dummy_label_spe,  # Dummy specific label
                }
                pred_dict = model(data_dict, inference=True)
                logits = pred_dict["cls"]  # Shape: [batch_size, num_classes]
                # Probability of being "fake" for the single frame in the batch.
                probs.append(torch.softmax(logits, dim=1)[:, 1].item())

    avg_prob = float(np.mean(probs))
    prediction = "Fake" if avg_prob > 0.5 else "Real"
    return prediction, avg_prob
def main(video_filename, model_name,
         base_dir="/teamspace/studios/this_studio/DeepfakeBench"):
    """Run terminal-based inference for one video with one detector.

    The detector config is expected at
    ``<base_dir>/training/config/detector/<model_name>.yaml`` and its weights
    at ``<base_dir>/training/weights/<model_name>_best.pth``.  The video is
    looked up relative to the current working directory.  A missing config,
    weights file or video prints an error and returns without running
    inference.

    Args:
        video_filename: Video path, resolved against the current directory.
        model_name: Detector name used to locate the config and weights.
        base_dir: Root of the DeepfakeBench checkout; defaults to the
            location this script was originally written for.

    Returns:
        None.  The result is printed to standard output.
    """
    config_path = f"{base_dir}/training/config/detector/{model_name}.yaml"
    weights_path = f"{base_dir}/training/weights/{model_name}_best.pth"
    if not os.path.exists(config_path):
        print(f"Error: Config file for model '{model_name}' not found at {config_path}.")
        return
    if not os.path.exists(weights_path):
        print(f"Error: Weights file for model '{model_name}' not found at {weights_path}.")
        return

    # Validate the video before loading the model so a mistyped filename
    # fails immediately instead of after an expensive weight load.
    video_path = os.path.join(os.getcwd(), video_filename)
    if not os.path.exists(video_path):
        print(f"Error: Video file '{video_filename}' not found in the current directory.")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, config_path, weights_path)
    prediction, confidence = infer_video(video_path, model, device)
    print(f"Model: {model_name}")
    # NOTE: the value printed as "Confidence" is the second element returned
    # by infer_video (the averaged fake-class probability), so for a "Real"
    # verdict a low number means a confident prediction.
    print(f"Prediction: {prediction} (Confidence: {confidence:.4f})")
if __name__ == "__main__":
    import sys

    # Exactly two positional arguments are expected after the script name;
    # anything else prints the usage help instead of running inference.
    if len(sys.argv) == 3:
        video_filename, model_name = sys.argv[1:]
        main(video_filename, model_name)
    else:
        print("Usage: python inference_script.py <video_filename> <model_name>")
        print("Available models: xception, meso4, meso4Inception, efficientnetb4, ucf, etc.")