import os

import google.generativeai as genai
import torch
import yaml
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer


class InferenceManager:
    def __init__(self):
        # Load configuration
        with open('config.yaml', 'r') as file:
            self.config = yaml.safe_load(file)

        # Configure Gemini
        load_dotenv()
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.gemini = genai.GenerativeModel('gemini-pro')

        # Configure Mixtral (loaded on demand)
        self.mixtral = None
        self.mixtral_tokenizer = None

        # Initial state
        self.current_model = 'gemini'
        self.current_mode = 'seguros'

    def load_mixtral(self):
        if not self.mixtral:
            self.mixtral_tokenizer = AutoTokenizer.from_pretrained(
                self.config['models']['mixtral']['name']
            )
            self.mixtral = AutoModelForCausalLM.from_pretrained(
                self.config['models']['mixtral']['name'],
                torch_dtype=torch.float16,
                device_map="auto"
            )

    def change_model(self, model_name):
        if model_name in ['gemini', 'mixtral']:
            self.current_model = model_name
            if model_name == 'mixtral':
                self.load_mixtral()
            return True
        return False

    def change_mode(self, mode):
        if mode in self.config['modes']:
            self.current_mode = mode
            return True
        return False

    def get_response(self, message):
        try:
            # Get the context for the current mode
            context = self.config['modes'][self.current_mode]['context']
            prompt = f"{context}\n\nUsuario: {message}\nAsistente:"

            if self.current_model == 'gemini':
                response = self.gemini.generate_content(prompt)
                return response.text
            else:  # mixtral
                # Move inputs to the model's device instead of assuming "cuda"
                inputs = self.mixtral_tokenizer(prompt, return_tensors="pt").to(self.mixtral.device)
                outputs = self.mixtral.generate(
                    **inputs,
                    max_length=self.config['models']['mixtral']['max_length'],
                    do_sample=True,  # required for temperature to take effect
                    temperature=self.config['models']['mixtral']['temperature'],
                    pad_token_id=self.mixtral_tokenizer.eos_token_id
                )
                # Decode only the newly generated tokens, skipping the prompt
                generated = outputs[0][inputs['input_ids'].shape[1]:]
                return self.mixtral_tokenizer.decode(generated, skip_special_tokens=True)
        except Exception as e:
            print(f"Error during inference: {e}")
            return "Lo siento, hubo un error al procesar tu mensaje."
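

# --- Illustrative usage (assumption) ---------------------------------------
# A minimal sketch of how InferenceManager might be driven, assuming a
# config.yaml with a 'models.mixtral' section (name, max_length, temperature)
# and a 'modes' section whose entries each define a 'context' string, e.g.:
#
#   models:
#     mixtral:
#       name: mistralai/Mixtral-8x7B-Instruct-v0.1
#       max_length: 1024
#       temperature: 0.7
#   modes:
#     seguros:
#       context: "Eres un asistente experto en seguros."
#
# The model name and values above are placeholders for illustration, not
# taken from the original project.

if __name__ == "__main__":
    manager = InferenceManager()

    # Default state: Gemini backend, 'seguros' mode
    print(manager.get_response("¿Qué cubre una póliza de hogar?"))

    # Switch to the local Mixtral backend (weights load on first use)
    if manager.change_model('mixtral'):
        print(manager.get_response("Resume las coberturas básicas."))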