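"""Gradio demo for isolated sign language recognition (ISLR).

Picking a language/dataset (PERU or WLASL) downloads the matching checkpoint
from the Hugging Face Hub, loads it into the model, and refreshes the example
videos. Uploaded or webcam videos are then classified into the top-1 and
top-5 most likely signs.
"""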
import os

import gradio as gr
import torch
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download

from src.islr.islr_model import DummyISLRModel
from src.predict import predict_from_video
# Load variables from .env
load_dotenv()

# Read the token from the environment
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
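# Note: the token may be None when the variable is unset; hf_hub_download
# accepts token=None, but downloading a private checkpoint will then fail.
if hf_token is None:
    print("Warning: HUGGINGFACE_HUB_TOKEN is not set; private downloads will fail.")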
# Create the model cache folder if it does not exist
os.makedirs("models", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Model checkpoint path and class count per dataset
# (both entries currently point to the same demo checkpoint)
dataset_models = {
    "PERU": {"path": "models/demo_model.pt", "num_classes": 100},
    "WLASL": {"path": "models/demo_model.pt", "num_classes": 100},
}
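# Note: num_classes must match the classifier head of the downloaded
# checkpoint, or load_state_dict below will raise a shape mismatch.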
# Example video paths and labels per dataset
# (all slots currently reuse book.mp4 as a placeholder)
dataset_examples = {
    "PERU": [
        {"label": "📘 **Gloss: `libro`**", "path": "videos/wlasl/book.mp4"},
        {"label": "🏠 **Gloss: `casa`**", "path": "videos/wlasl/book.mp4"},
        {"label": "📘 **Gloss: `libro2`**", "path": "videos/wlasl/book.mp4"},
        {"label": "🏠 **Gloss: `casa2`**", "path": "videos/wlasl/book.mp4"},
    ],
    "WLASL": [
        {"label": "📙 **Gloss: `read`**", "path": "videos/wlasl/book.mp4"},
        {"label": "🏫 **Gloss: `school`**", "path": "videos/wlasl/book.mp4"},
        {"label": "📙 **Gloss: `read2`**", "path": "videos/wlasl/book.mp4"},
        {"label": "🏫 **Gloss: `school2`**", "path": "videos/wlasl/book.mp4"},
    ],
}
# === Load the model and example videos for the selected dataset ===
def load_model_and_examples(dataset):
    model_path = dataset_models[dataset]["path"]
    num_classes = dataset_models[dataset]["num_classes"]

    print("Downloading..")
    model_path = hf_hub_download(
        repo_id="CristianLazoQuispe/SignERT",
        filename=model_path,
        cache_dir="models",  # download and cache the file in this folder
        token=hf_token,
    )
    print("Downloaded!")

    model = DummyISLRModel(num_classes=num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)  # move the weights onto the selected device
    model.eval()
    print(f"Model {dataset} Loaded!")

    # Fall back to four empty slots so the indexing below never fails
    examples = dataset_examples.get(dataset, [{"label": "", "path": ""}] * 4)
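    # The tuple below must mirror the `outputs` list wired to
    # dataset_selector.change: model state, examples-container visibility,
    # then (video, path state, label) for each of the four examples, and
    # finally the Classify button update.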
    return (
        model,
        gr.update(visible=True),
        gr.update(value=examples[0]["path"]),
        examples[0]["path"],
        gr.update(value=examples[0]["label"]),
        gr.update(value=examples[1]["path"]),
        examples[1]["path"],
        gr.update(value=examples[1]["label"]),
        gr.update(value=examples[2]["path"]),
        examples[2]["path"],
        gr.update(value=examples[2]["label"]),
        gr.update(value=examples[3]["path"]),
        examples[3]["path"],
        gr.update(value=examples[3]["label"]),
        gr.update(interactive=True),  # enable the Classify button
    )
# === Use the model stored in the State ===
def classify_video_with_model(video, model):
    # make sure `predict_from_video` accepts the model as an argument
    top1, top5_df = predict_from_video(video, model=model)
    return f"Top-1: {top1}", top5_df
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 ISLR Demo with MediaPipe and 100 Classes")
    gr.Markdown("Upload a video or use the webcam. The model will classify the sign and show the 5 most likely classes.")

    # === Dataset selector ===
    gr.Markdown("## 📁 Filter by Language")
    dataset_selector = gr.Dropdown(choices=list(dataset_examples.keys()), value=None, label="Select a language")

    # === Model and example-path state ===
    current_model = gr.State()
    video_path_1 = gr.State()
    video_path_2 = gr.State()
    video_path_3 = gr.State()
    video_path_4 = gr.State()

    # === Video input + outputs ===
    with gr.Row():
        video_input = gr.Video(sources=["upload", "webcam"], label="🎥 Input video", width=300, height=400)
        with gr.Column():
            output_text = gr.Text(label="Top-1 prediction")
            output_table = gr.Label(num_top_classes=5)

    # Disabled until a dataset (and thus a model) has been selected
    button_classify = gr.Button("🔍 Classify", interactive=False)

    button_classify.click(
        fn=classify_video_with_model,
        inputs=[video_input, current_model],
        outputs=[output_text, output_table],
    )
    # === Dynamic examples container ===
    gr.Markdown("## 📁 Example videos")
    examples_output = gr.Column(visible=True)
    with examples_output:
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                m1 = gr.Markdown("📘 **Gloss: **")
                v1 = gr.Video(interactive=False, width=160, height=120)
                b1 = gr.Button("Use", scale=0)
            with gr.Column(scale=1, min_width=100):
                m2 = gr.Markdown("🏠 **Gloss: **")
                v2 = gr.Video(interactive=False, width=160, height=120)
                b2 = gr.Button("Use", scale=0)
            with gr.Column(scale=1, min_width=100):
                m3 = gr.Markdown("🏠 **Gloss: **")
                v3 = gr.Video(interactive=False, width=160, height=120)
                b3 = gr.Button("Use", scale=0)
            with gr.Column(scale=1, min_width=100):
                m4 = gr.Markdown("🏠 **Gloss: **")
                v4 = gr.Video(interactive=False, width=160, height=120)
                b4 = gr.Button("Use", scale=0)
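    # Copy the stored example path into the main video input. The paths live
    # in gr.State holders so the "Use" buttons still point at the right clip
    # after a dataset change swaps the examples.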
    b1.click(fn=lambda path: path, inputs=video_path_1, outputs=video_input)
    b2.click(fn=lambda path: path, inputs=video_path_2, outputs=video_input)
    b3.click(fn=lambda path: path, inputs=video_path_3, outputs=video_input)
    b4.click(fn=lambda path: path, inputs=video_path_4, outputs=video_input)
    # === On dataset change: load the model and refresh the examples ===
    dataset_selector.change(
        fn=load_model_and_examples,
        inputs=dataset_selector,
        outputs=[
            current_model, examples_output,
            v1, video_path_1, m1,
            v2, video_path_2, m2,
            v3, video_path_3, m3,
            v4, video_path_4, m4,
            button_classify,
        ],
    )
if __name__ == "__main__":
    demo.launch()  # optionally: demo.launch(server_port=8080)