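"""Gradio demo for isolated sign language recognition (ISLR).

Loads a per-dataset checkpoint from the Hugging Face Hub
(CristianLazoQuispe/SignERT), runs it on an uploaded or webcam video,
and shows the Top-1 prediction plus the 5 most likely classes.
"""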
import os

import gradio as gr
import torch
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download

from src.islr.islr_model import DummyISLRModel
from src.predict import predict_from_video

# Load variables from .env
load_dotenv()

# Read the Hub token from the environment
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
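
# Optional sanity check (added convenience, not in the original flow):
# hf_hub_download accepts token=None for public repos, but downloading from
# a private SignERT repo requires the token to be set.
if hf_token is None:
    print("Warning: HUGGINGFACE_HUB_TOKEN is not set; private Hub downloads will fail.")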
# Create the models folder if it does not exist
os.makedirs("models", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint path and class count per dataset
# (both entries currently point at the same demo checkpoint)
dataset_models = {
    "PERU": {"path": "models/demo_model.pt", "num_classes": 100},
    "WLASL": {"path": "models/demo_model.pt", "num_classes": 100},
}

# Example video paths and labels per dataset
dataset_examples = {
    "PERU": [
        {"label": "📘 **Gloss: `libro`**", "path": "videos/wlasl/book.mp4"},
        {"label": "🏠 **Gloss: `casa`**", "path": "videos/wlasl/book.mp4"},
        {"label": "📘 **Gloss: `libro2`**", "path": "videos/wlasl/book.mp4"},
        {"label": "🏠 **Gloss: `casa2`**", "path": "videos/wlasl/book.mp4"},
    ],
    "WLASL": [
        {"label": "📙 **Gloss: `read`**", "path": "videos/wlasl/book.mp4"},
        {"label": "🏫 **Gloss: `school`**", "path": "videos/wlasl/book.mp4"},
        {"label": "📙 **Gloss: `read2`**", "path": "videos/wlasl/book.mp4"},
        {"label": "🏫 **Gloss: `school2`**", "path": "videos/wlasl/book.mp4"},
    ],
}
# === Load the model and example videos for the selected dataset ===
def load_model_and_examples(dataset):
    config = dataset_models[dataset]
    print("Downloading..")
    model_path = hf_hub_download(
        repo_id="CristianLazoQuispe/SignERT",
        filename=config["path"],
        cache_dir="models",  # download and cache the file in this folder
        token=hf_token,
    )
    print("Downloaded!")
    model = DummyISLRModel(num_classes=config["num_classes"])
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"Model {dataset} Loaded!")

    # Fall back to four empty slots so the indexing below never fails
    examples = dataset_examples.get(dataset, [{"label": "", "path": ""}] * 4)
    return (
        model,
        gr.update(visible=True),
        gr.update(value=examples[0]["path"]),
        examples[0]["path"],
        gr.update(value=examples[0]["label"]),
        gr.update(value=examples[1]["path"]),
        examples[1]["path"],
        gr.update(value=examples[1]["label"]),
        gr.update(value=examples[2]["path"]),
        examples[2]["path"],
        gr.update(value=examples[2]["label"]),
        gr.update(value=examples[3]["path"]),
        examples[3]["path"],
        gr.update(value=examples[3]["label"]),
        gr.update(interactive=True),  # enable the classify button
    )
# === Classify using the model held in gr.State ===
def classify_video_with_model(video, model):
    # predict_from_video must accept the loaded model via the `model` kwarg
    top1, top5_df = predict_from_video(video, model=model)
    return f"Top-1: {top1}", top5_df
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 ISLR Demo with Mediapipe and 100 Classes")
    gr.Markdown("Upload a video or use the webcam. The model classifies the sign and shows the 5 most likely classes.")

    # === Dataset selector ===
    gr.Markdown("## 📁 Filter by Language")
    dataset_selector = gr.Dropdown(choices=list(dataset_examples.keys()), value=None, label="Select a language")

    # === Model and example-path state ===
    current_model = gr.State()
    video_path_1 = gr.State()
    video_path_2 = gr.State()
    video_path_3 = gr.State()
    video_path_4 = gr.State()

    # === Video input + outputs ===
    with gr.Row():
        video_input = gr.Video(sources=["upload", "webcam"], label="🎥 Input video", width=300, height=400)
        with gr.Column():
            output_text = gr.Text(label="Top-1 prediction")
            output_table = gr.Label(num_top_classes=5)

    button_classify = gr.Button("🔍 Classify", interactive=False)
    button_classify.click(
        fn=classify_video_with_model,
        inputs=[video_input, current_model],
        outputs=[output_text, output_table],
    )
    # === Dynamic examples container ===
    examples_output = gr.Column(visible=True)
    with examples_output:
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                m1 = gr.Markdown("📘 **Gloss: **")
                v1 = gr.Video(interactive=False, width=160, height=120)
                b1 = gr.Button("Use", scale=0)
            with gr.Column(scale=1, min_width=100):
                m2 = gr.Markdown("🏠 **Gloss: **")
                v2 = gr.Video(interactive=False, width=160, height=120)
                b2 = gr.Button("Use", scale=0)
            with gr.Column(scale=1, min_width=100):
                m3 = gr.Markdown("🏠 **Gloss: **")
                v3 = gr.Video(interactive=False, width=160, height=120)
                b3 = gr.Button("Use", scale=0)
            with gr.Column(scale=1, min_width=100):
                m4 = gr.Markdown("🏠 **Gloss: **")
                v4 = gr.Video(interactive=False, width=160, height=120)
                b4 = gr.Button("Use", scale=0)

    # Each "Use" button copies its example's path into the main video input
    b1.click(fn=lambda path: path, inputs=video_path_1, outputs=video_input)
    b2.click(fn=lambda path: path, inputs=video_path_2, outputs=video_input)
    b3.click(fn=lambda path: path, inputs=video_path_3, outputs=video_input)
    b4.click(fn=lambda path: path, inputs=video_path_4, outputs=video_input)
gr.Markdown("## 📁 Ejemplos de videos")
# === Al cambiar dataset, cargamos modelo + ejemplos
dataset_selector.change(
fn=load_model_and_examples,
inputs=dataset_selector,
outputs=[current_model, examples_output, v1,video_path_1,m1, v2, video_path_2, m2, v3, video_path_3, m3, v4, video_path_4, m4,
button_classify
]
)
if __name__ == "__main__":
demo.launch()#server_port=8080)
|