# chatbot-web-app / inference.py
import os

import google.generativeai as genai
import torch
import yaml
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer

class InferenceManager:
    """Routes chat messages to Gemini or Mixtral, with a per-mode prompt context."""

    def __init__(self):
        # Load configuration
        with open('config.yaml', 'r') as file:
            self.config = yaml.safe_load(file)
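
        # Expected config.yaml shape, inferred from the keys this class reads.
        # A sketch only; the checkpoint name and values are illustrative assumptions:
        #
        #   models:
        #     mixtral:
        #       name: mistralai/Mixtral-8x7B-Instruct-v0.1  # assumed checkpoint
        #       max_length: 1024                            # assumed value
        #       temperature: 0.7                            # assumed value
        #   modes:
        #     seguros:
        #       context: "Eres un asistente experto en seguros..."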

        # Configure Gemini
        load_dotenv()
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.gemini = genai.GenerativeModel('gemini-pro')

        # Mixtral is loaded lazily, on first use
        self.mixtral = None
        self.mixtral_tokenizer = None

        # Initial state
        self.current_model = 'gemini'
        self.current_mode = 'seguros'

    def load_mixtral(self):
        # Lazy-load the Mixtral weights and tokenizer on first request
        if self.mixtral is None:
            self.mixtral_tokenizer = AutoTokenizer.from_pretrained(
                self.config['models']['mixtral']['name']
            )
            self.mixtral = AutoModelForCausalLM.from_pretrained(
                self.config['models']['mixtral']['name'],
                torch_dtype=torch.float16,
                device_map="auto",  # requires the `accelerate` package
            )
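
    # Note: Mixtral-8x7B in float16 is roughly 90 GB of weights, so
    # device_map="auto" will shard or offload it across whatever hardware
    # is available. A lighter alternative, assuming the `bitsandbytes`
    # package is installed, would be to load in 4-bit precision:
    #
    #   from transformers import BitsAndBytesConfig
    #   quant = BitsAndBytesConfig(load_in_4bit=True)
    #   AutoModelForCausalLM.from_pretrained(
    #       name, quantization_config=quant, device_map="auto")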

    def change_model(self, model_name):
        if model_name in ['gemini', 'mixtral']:
            self.current_model = model_name
            if model_name == 'mixtral':
                self.load_mixtral()
            return True
        return False

    def change_mode(self, mode):
        if mode in self.config['modes']:
            self.current_mode = mode
            return True
        return False

    def get_response(self, message):
        try:
            # Build the prompt from the current mode's context
            # (role labels stay in Spanish to match the bot's language)
            context = self.config['modes'][self.current_mode]['context']
            prompt = f"{context}\n\nUsuario: {message}\nAsistente:"

            if self.current_model == 'gemini':
                response = self.gemini.generate_content(prompt)
                return response.text
            else:  # mixtral
                # Move inputs to wherever device_map placed the model,
                # instead of hard-coding "cuda"
                inputs = self.mixtral_tokenizer(prompt, return_tensors="pt").to(self.mixtral.device)
                outputs = self.mixtral.generate(
                    **inputs,
                    max_length=self.config['models']['mixtral']['max_length'],
                    do_sample=True,  # temperature only takes effect when sampling
                    temperature=self.config['models']['mixtral']['temperature'],
                    pad_token_id=self.mixtral_tokenizer.eos_token_id,
                )
                # Decode only the newly generated tokens, not the echoed prompt
                new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
                return self.mixtral_tokenizer.decode(new_tokens, skip_special_tokens=True)
        except Exception as e:
            print(f"Inference error: {e}")
            # User-facing reply stays in Spanish
            return "Lo siento, hubo un error al procesar tu mensaje."