# Deepfake Detector — Gradio app for Hugging Face Spaces.
import argparse
import os
import re
import time
from pathlib import Path

import torch
import gradio as gr

from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
from training.zoo.classifiers import DeepFakeClassifier
def model_fn(model_dir):
    """Load the DeepFake classifier checkpoint from *model_dir* onto CPU.

    Parameters
    ----------
    model_dir : str
        Directory containing the ``b7_ns_best.pth`` checkpoint file.

    Returns
    -------
    DeepFakeClassifier
        The model, weights loaded, switched to eval mode on CPU.
    """
    model_path = os.path.join(model_dir, 'b7_ns_best.pth')
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")  # default: CPU
    checkpoint = torch.load(model_path, map_location="cpu")
    # Some checkpoints wrap weights under "state_dict"; fall back to the
    # checkpoint itself otherwise.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Checkpoints saved via DataParallel prefix every key with "module.".
    # Strip that prefix so keys match the bare model. The dot is escaped:
    # the original pattern "^module." would also mangle keys like "moduleX".
    model.load_state_dict(
        {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()},
        strict=True,
    )
    model.eval()
    del checkpoint  # release checkpoint memory before returning
    return model
def convert_result(pred, class_names=("Real", "Fake")):
    """Map a scalar prediction to a ``{class_name: probability}`` dict.

    ``pred`` is assigned to the first class name and its complement
    ``1 - pred`` to the second.
    NOTE(review): whether ``pred`` is the "Real" or "Fake" probability
    depends on upstream ``predict_on_video`` — confirm the ordering.

    Parameters
    ----------
    pred : float
        Scalar prediction in [0, 1].
    class_names : sequence of str, optional
        Exactly two labels. Default is ``("Real", "Fake")`` — a tuple, not
        a list, to avoid the mutable-default-argument pitfall.

    Returns
    -------
    dict[str, float]

    Raises
    ------
    ValueError
        If ``class_names`` does not contain exactly two entries.
        (Replaces the original ``assert``, which is stripped under ``-O``.)
    """
    preds = [pred, 1 - pred]
    if len(class_names) != len(preds):
        raise ValueError("Class / Prediction should have the same length")
    return {n: float(p) for n, p in zip(class_names, preds)}
def predict_fn(video):
    """Classify *video* as Real/Fake and report how long inference took.

    Uses the module-level ``model`` and ``meta`` configuration. Returns a
    ``(label_probabilities, seconds)`` pair matching the Gradio Label and
    Number outputs.
    """
    t0 = time.time()
    raw_score = predict_on_video(
        face_extractor=meta["face_extractor"],
        video_path=video,
        batch_size=meta["fps"],
        input_size=meta["input_size"],
        models=model,
        strategy=meta["strategy"],
        apply_compression=False,
        device='cpu',
    )
    elapsed = round(time.time() - t0, 2)
    return convert_result(raw_score), elapsed
# --- UI strings --------------------------------------------------------------
title = "Deepfake Detector (private)"
description = "A video Deepfake Classifier (code: https://github.com/selimsef/dfdc_deepfake_challenge)"
# Every file in examples/ becomes a clickable Gradio example.
example_list = [str(p) for p in Path("examples").iterdir()]

# --- Environment / model configuration ---------------------------------------
model_dir = 'weights'
frames_per_video = 32
video_reader = VideoReader()


def video_read_fn(path):
    # Named function instead of a lambda bound to a name (PEP 8 E731);
    # behavior is identical.
    return video_reader.read_frames(path, num_frames=frames_per_video)


face_extractor = FaceExtractor(video_read_fn)
input_size = 380  # EfficientNet-B7 input resolution
strategy = confident_strategy
class_names = ["Real", "Fake"]
# Bundle inference settings so predict_fn takes a single config source.
meta = {
    "fps": 32,
    "face_extractor": face_extractor,
    "input_size": input_size,
    "strategy": strategy,
}
# Load the model once at import time so the Gradio worker reuses it.
model = model_fn(model_dir)
""" | |
if __name__ == '__main__': | |
video_path = "examples/nlurbvsozt.mp4" | |
model = model_fn(model_dir) | |
a, b = predict_fn(video_path) | |
print(a, b) | |
""" | |
# Create the Gradio demo | |
demo = gr.Interface(fn=predict_fn, # mapping function from input to output | |
inputs=gr.Video(), | |
outputs=[gr.Label(num_top_classes=2, label="Predictions"), # what are the outputs? | |
gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs | |
examples=example_list, | |
title=title, | |
description=description) | |
# Launch the demo! | |
demo.launch(debug=False,) # Hugging face space don't need shareable_links | |