import google.generativeai as genai
import os
from dotenv import load_dotenv
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class InferenceManager:
    def __init__(self):
        # Load configuration
        with open('config.yaml', 'r') as file:
            self.config = yaml.safe_load(file)

        # Configure Gemini
        load_dotenv()
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.gemini = genai.GenerativeModel('gemini-pro')

        # Configure Mixtral (loaded on demand)
        self.mixtral = None
        self.mixtral_tokenizer = None

        # Initial state
        self.current_model = 'gemini'
        self.current_mode = 'seguros'
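
    # Illustrative sketch of the expected config.yaml layout. The key paths
    # (models.mixtral.name / max_length / temperature, modes.<mode>.context)
    # are inferred from how self.config is read in this class; the concrete
    # values below are assumptions, not part of the original repo:
    #
    #   models:
    #     mixtral:
    #       name: mistralai/Mixtral-8x7B-Instruct-v0.1
    #       max_length: 1024
    #       temperature: 0.7
    #   modes:
    #     seguros:
    #       context: "Eres un asistente experto en seguros..."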

    def load_mixtral(self):
        # Lazy-load the Mixtral tokenizer and weights on first use
        if not self.mixtral:
            self.mixtral_tokenizer = AutoTokenizer.from_pretrained(
                self.config['models']['mixtral']['name']
            )
            self.mixtral = AutoModelForCausalLM.from_pretrained(
                self.config['models']['mixtral']['name'],
                torch_dtype=torch.float16,
                device_map="auto"
            )

    def change_model(self, model_name):
        if model_name in ['gemini', 'mixtral']:
            self.current_model = model_name
            if model_name == 'mixtral':
                self.load_mixtral()
            return True
        return False

    def change_mode(self, mode):
        if mode in self.config['modes']:
            self.current_mode = mode
            return True
        return False

    def get_response(self, message):
        try:
            # Build the prompt from the current mode's context
            context = self.config['modes'][self.current_mode]['context']
            prompt = f"{context}\n\nUsuario: {message}\nAsistente:"

            if self.current_model == 'gemini':
                response = self.gemini.generate_content(prompt)
                return response.text
            else:  # mixtral
                self.load_mixtral()  # guard: ensure weights are loaded
                # Use the model's own device instead of hard-coding "cuda",
                # since device_map="auto" decides placement
                inputs = self.mixtral_tokenizer(prompt, return_tensors="pt").to(self.mixtral.device)
                outputs = self.mixtral.generate(
                    **inputs,
                    max_length=self.config['models']['mixtral']['max_length'],
                    do_sample=True,  # required for temperature to take effect
                    temperature=self.config['models']['mixtral']['temperature'],
                    pad_token_id=self.mixtral_tokenizer.eos_token_id
                )
                return self.mixtral_tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Inference error: {e}")
            return "Lo siento, hubo un error al procesar tu mensaje."