salomonsky committed on
Commit 47a004e (verified)
1 Parent(s): c0f5a6c

Upload inference.py with huggingface_hub

Files changed (1)
  1. inference.py +73 -0
inference.py ADDED
@@ -0,0 +1,73 @@
import os

import google.generativeai as genai
import torch
import yaml
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer


class InferenceManager:
    def __init__(self):
        # Load configuration
        with open('config.yaml', 'r') as file:
            self.config = yaml.safe_load(file)

        # Configure Gemini
        load_dotenv()
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.gemini = genai.GenerativeModel('gemini-pro')

        # Mixtral is loaded on demand
        self.mixtral = None
        self.mixtral_tokenizer = None

        # Initial state
        self.current_model = 'gemini'
        self.current_mode = 'seguros'

    def load_mixtral(self):
        if not self.mixtral:
            self.mixtral_tokenizer = AutoTokenizer.from_pretrained(
                self.config['models']['mixtral']['name']
            )
            self.mixtral = AutoModelForCausalLM.from_pretrained(
                self.config['models']['mixtral']['name'],
                torch_dtype=torch.float16,
                device_map="auto"
            )

    def change_model(self, model_name):
        if model_name in ['gemini', 'mixtral']:
            self.current_model = model_name
            if model_name == 'mixtral':
                self.load_mixtral()
            return True
        return False

    def change_mode(self, mode):
        if mode in self.config['modes']:
            self.current_mode = mode
            return True
        return False

    def get_response(self, message):
        try:
            # Build the prompt from the current mode's context
            context = self.config['modes'][self.current_mode]['context']
            prompt = f"{context}\n\nUsuario: {message}\nAsistente:"

            if self.current_model == 'gemini':
                response = self.gemini.generate_content(prompt)
                return response.text
            else:  # mixtral
                # Move the inputs to wherever device_map="auto" placed the model,
                # instead of hard-coding "cuda" (which fails on CPU-only hosts)
                inputs = self.mixtral_tokenizer(prompt, return_tensors="pt").to(self.mixtral.device)
                outputs = self.mixtral.generate(
                    **inputs,
                    max_length=self.config['models']['mixtral']['max_length'],
                    temperature=self.config['models']['mixtral']['temperature'],
                    do_sample=True,  # temperature has no effect under greedy decoding
                    pad_token_id=self.mixtral_tokenizer.eos_token_id
                )
                return self.mixtral_tokenizer.decode(outputs[0], skip_special_tokens=True)

        except Exception as e:
            print(f"Inference error: {e}")
            return "Lo siento, hubo un error al procesar tu mensaje."