import streamlit as st
import yaml
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class ModelManager:
    def __init__(self, model_name="microsoft/Phi-4-mini-instruct"):
        # Map display names shown in the UI to Hugging Face repo ids.
        self.models = {
            "microsoft/Phi-4-mini-instruct": "microsoft/Phi-4-mini-instruct",
            "microsoft/Phi-4-multimodal": "microsoft/Phi-4-multimodal",
            "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct"
        }
        self.current_model_name = model_name
        self.tokenizer = None
        self.model = None
        self.load_model(model_name)

    def load_model(self, model_name):
        self.current_model_name = model_name
        model_path = self.models[model_name]
        st.info(f"Loading model: {model_name} ...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Load in half precision on GPU when available to reduce memory use.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype).to(device)

    def generate(self, prompt, max_length=50, temperature=0.7):
        # Move the encoded prompt to the same device as the model and pass the
        # attention mask along with the input ids.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # do_sample=True is required for the temperature setting to have any effect.
        outputs = self.model.generate(**inputs, max_length=max_length, do_sample=True, temperature=temperature)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def switch_model(self, model_name):
        if model_name in self.models:
            self.load_model(model_name)
        else:
            raise ValueError(f"Model {model_name} is not available.")

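# A minimal standalone sketch of the class above (outside Streamlit), assuming
# the default Phi-4-mini-instruct checkpoint fits in local memory:
#
#     manager = ModelManager()
#     print(manager.generate("Write a haiku about the sea", max_length=80))
#     manager.switch_model("meta-llama/Llama-3.3-70B-Instruct")
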
@st.cache_data
def load_prompts():
    with open("prompt.yml", "r", encoding="utf-8") as f:
        prompts = yaml.safe_load(f)
    return prompts

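# prompt.yml is not shown here; main() below assumes its top-level keys are the
# selectable model names, plus an optional "default_prompt" fallback, so a
# plausible sketch of the file would be:
#
#     microsoft/Phi-4-mini-instruct: "You are a concise technical assistant."
#     meta-llama/Llama-3.3-70B-Instruct: "You are a detailed, careful reviewer."
#     default_prompt: "You are a helpful assistant."
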
@st.cache_resource
def get_model_manager(model_name):
    # Cache the manager so the chosen model is loaded once per session instead
    # of being re-initialized on every Streamlit rerun.
    return ModelManager(model_name=model_name)


def main():
    st.title("Transformers Model Switcher")

    prompts_config = load_prompts()

    # The sidebar options come from the keys of prompt.yml, which are expected
    # to match the model names known to ModelManager.
    st.sidebar.title("Model Selection")
    model_choice = st.sidebar.selectbox("Select a model", list(prompts_config.keys()))

    model_manager = get_model_manager(model_choice)

    style_prompt = prompts_config.get(model_choice, prompts_config.get("default_prompt", ""))

    st.write(f"**Model in use:** {model_choice}")

    user_prompt = st.text_area("Enter your prompt:", value=style_prompt)

    max_length = st.slider("Maximum length", min_value=10, max_value=200, value=50)
    temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7)

    if st.button("Generate response"):
        result = model_manager.generate(user_prompt, max_length=max_length, temperature=temperature)
        st.text_area("Output", value=result, height=200)


if __name__ == "__main__":
    main()
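# To try the app locally (assuming this file is saved as app.py and streamlit,
# transformers, torch and pyyaml are installed):
#
#     streamlit run app.py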