import google.generativeai as genai
import os
from dotenv import load_dotenv
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class InferenceManager:
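    """Routes chat requests to either Gemini (remote API) or Mixtral (local
    Hugging Face model), prepending the per-mode context read from config.yaml."""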
    def __init__(self):
        # Load configuration
        with open('config.yaml', 'r') as file:
            self.config = yaml.safe_load(file)

        # Configure Gemini
        load_dotenv()
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.gemini = genai.GenerativeModel('gemini-pro')

        # Configure Mixtral (loaded on demand)
        self.mixtral = None
        self.mixtral_tokenizer = None

        # Initial state
        self.current_model = 'gemini'
        self.current_mode = 'seguros'
        
    def load_mixtral(self):
        # Lazy-load the Mixtral tokenizer and model (fp16, layers placed automatically)
        if not self.mixtral:
            self.mixtral_tokenizer = AutoTokenizer.from_pretrained(
                self.config['models']['mixtral']['name']
            )
            self.mixtral = AutoModelForCausalLM.from_pretrained(
                self.config['models']['mixtral']['name'],
                torch_dtype=torch.float16,
                device_map="auto"
            )
    
    def change_model(self, model_name):
        if model_name in ['gemini', 'mixtral']:
            self.current_model = model_name
            if model_name == 'mixtral':
                self.load_mixtral()
            return True
        return False
    
    def change_mode(self, mode):
        if mode in self.config['modes']:
            self.current_mode = mode
            return True
        return False
    
    def get_response(self, message):
        try:
            # Build the prompt from the current mode's context
            context = self.config['modes'][self.current_mode]['context']
            prompt = f"{context}\n\nUsuario: {message}\nAsistente:"

            if self.current_model == 'gemini':
                response = self.gemini.generate_content(prompt)
                return response.text
            else:  # mixtral
                # Move inputs to the device the model was actually loaded on
                inputs = self.mixtral_tokenizer(prompt, return_tensors="pt").to(self.mixtral.device)
                outputs = self.mixtral.generate(
                    **inputs,
                    max_length=self.config['models']['mixtral']['max_length'],
                    do_sample=True,  # required for temperature to take effect
                    temperature=self.config['models']['mixtral']['temperature'],
                    pad_token_id=self.mixtral_tokenizer.eos_token_id
                )
                # Decode only the newly generated tokens, not the echoed prompt
                new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
                return self.mixtral_tokenizer.decode(new_tokens, skip_special_tokens=True)
                
        except Exception as e:
            print(f"Error en inferencia: {e}")
            return "Lo siento, hubo un error al procesar tu mensaje."