|
import gradio as gr |
|
import torch |
|
from src.predict import predict_from_video |
|
from src.islr.islr_model import DummyISLRModel |
|
from huggingface_hub import hf_hub_download |
|
import torch |
|
import os |
|
from dotenv import load_dotenv |
|
import os |
|
|
|
|
|
# Pull configuration (e.g. the Hugging Face token) from a local .env file, if present.
load_dotenv()

# Token for authenticated Hugging Face Hub downloads; None when unset
# (public repos still work without it).
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")

# Local cache directory for downloaded model checkpoints.
os.makedirs("models", exist_ok=True)

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
# Per-language model registry: checkpoint path inside the HF repo plus the
# classifier head size. Both languages currently share the same demo checkpoint.
dataset_models = {
    language: {"path": "models/demo_model.pt", "num_classes": 100}
    for language in ("PERU", "WLASL")
}
|
|
|
|
|
# Four example clips per language, shown below the main input. Every entry
# currently points at the same placeholder clip — presumably to be replaced
# with per-gloss videos later; verify before shipping.
dataset_examples = {
    "PERU": [
        {"label": caption, "path": "videos/wlasl/book.mp4"}
        for caption in (
            "📘 **Glosa: `libro`**",
            "🏠 **Glosa: `casa`**",
            "📘 **Glosa: `libro2`**",
            "🏠 **Glosa: `casa2`**",
        )
    ],
    "WLASL": [
        {"label": caption, "path": "videos/wlasl/book.mp4"}
        for caption in (
            "📙 **Glosa: `read`**",
            "🏫 **Glosa: `school`**",
            "📙 **Glosa: `read2`**",
            "🏫 **Glosa: `school2`**",
        )
    ],
}
|
|
|
|
|
def load_model_and_examples(dataset):
    """Download and load the checkpoint for *dataset*, then populate the example UI.

    Parameters
    ----------
    dataset : str
        Key into ``dataset_models`` / ``dataset_examples`` ("PERU" or "WLASL").

    Returns
    -------
    tuple
        15 elements, in the exact order the ``dataset_selector.change`` outputs
        list expects: the loaded model, a visibility update for the examples
        column, then for each of the 4 example slots a (video update, raw path,
        label update) triple, and finally an update enabling the classify button.

    Raises
    ------
    KeyError
        If *dataset* is not a known key of ``dataset_models``.
        (The previous ``.get(...)['path']`` pattern raised an opaque TypeError.)
    """
    entry = dataset_models[dataset]
    num_classes = entry["num_classes"]

    print("Downloading..")
    model_path = hf_hub_download(
        repo_id="CristianLazoQuispe/SignERT",
        filename=entry["path"],
        cache_dir="models",
        token=hf_token,
    )
    print("Downloaded!")

    model = DummyISLRModel(num_classes=num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"Model {dataset} Loaded!")

    # Fall back to FOUR empty slots: the loop below fills 4 example widgets,
    # so a shorter default (the old code used 2) would raise IndexError.
    examples = dataset_examples.get(dataset) or [{"label": "", "path": ""} for _ in range(4)]

    outputs = [model, gr.update(visible=True)]
    for example in examples[:4]:
        outputs.extend([
            gr.update(value=example["path"]),  # example video widget
            example["path"],                   # hidden State holding the raw path
            gr.update(value=example["label"]), # markdown caption
        ])
    outputs.append(gr.update(interactive=True))  # enable the classify button
    return tuple(outputs)
|
|
|
|
|
def classify_video_with_model(video, model):
    """Classify *video* with *model* and return the top-1 text plus top-5 scores."""
    prediction, score_table = predict_from_video(video, model=model)
    headline = "Top-1: {}".format(prediction)
    return headline, score_table
|
|
|
with gr.Blocks() as demo:
    # --- Header ---
    gr.Markdown("# 🧠 ISLR Demo con Mediapipe y 100 Clases")
    gr.Markdown("Sube un video o usa la webcam. El modelo clasificará la seña y mostrará las 5 clases más probables.")

    # --- Language picker (drives model download + example refresh) ---
    gr.Markdown("## 📁 Filtrar por Language")
    dataset_selector = gr.Dropdown(choices=list(dataset_examples.keys()), value=None, label="Selecciona el lenguaje")

    # Hidden state: the loaded model and the raw file path behind each example slot.
    current_model = gr.State()
    video_path_1 = gr.State()
    video_path_2 = gr.State()
    video_path_3 = gr.State()
    video_path_4 = gr.State()

    # --- Main input / prediction panel ---
    with gr.Row():
        video_input = gr.Video(sources=["upload", "webcam"], label="🎥 Video de entrada", width=300, height=400)
        with gr.Column():
            output_text = gr.Text(label="Predicción Top-1")
            output_table = gr.Label(num_top_classes=5)
            button_classify = gr.Button("🔍 Clasificar", interactive=False)

    button_classify.click(
        fn=classify_video_with_model,
        inputs=[video_input, current_model],
        outputs=[output_text, output_table],
    )

    # --- Example gallery (4 slots, filled in by load_model_and_examples) ---
    examples_output = gr.Column(visible=True)

    placeholder_captions = ("📘 **Glosa: **", "🏠 **Glosa: **", "🏠 **Glosa: **", "🏠 **Glosa: **")
    captions, previews, use_buttons = [], [], []
    with examples_output:
        with gr.Row():
            for placeholder in placeholder_captions:
                with gr.Column(scale=1, min_width=100):
                    captions.append(gr.Markdown(placeholder))
                    previews.append(gr.Video(interactive=False, width=160, height=120))
                    use_buttons.append(gr.Button("Usar", scale=0))
    m1, m2, m3, m4 = captions
    v1, v2, v3, v4 = previews
    b1, b2, b3, b4 = use_buttons

    # Each "Usar" button copies its slot's stored path into the main input.
    # (The lambda only uses its own parameter, so no closure binding issues.)
    for use_button, path_state in zip(use_buttons, (video_path_1, video_path_2, video_path_3, video_path_4)):
        use_button.click(fn=lambda path: path, inputs=path_state, outputs=video_input)

    gr.Markdown("## 📁 Ejemplos de videos")

    # Changing the language loads the model and repopulates every example slot.
    # Output order must match the tuple returned by load_model_and_examples.
    dataset_selector.change(
        fn=load_model_and_examples,
        inputs=dataset_selector,
        outputs=[
            current_model, examples_output,
            v1, video_path_1, m1,
            v2, video_path_2, m2,
            v3, video_path_3, m3,
            v4, video_path_4, m4,
            button_classify,
        ],
    )
|
|
|
if __name__ == "__main__":

    # Start the Gradio server only when executed as a script (not on import).
    demo.launch()
|
|