File size: 10,177 Bytes
15ecf69 45ff066 15ecf69 45ff066 15ecf69 3cb18fb 15ecf69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
import os
import requests
import json
import re
import gradio as gr
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
class MultilingualLlamaAgent:
"""
A multilingual chatbot powered by Llama hosted on Hugging Face with RAG capabilities.
"""
def __init__(self):
"""Initialize the Hugging Face API client for Llama 3.2 and RAG components."""
print("Initializing Llama 3.2 multilingual agent with RAG...")
# Set up the model ID and API token
self.model_id = os.environ.get('MODEL')
self.api_token = os.environ.get("HF_TOKEN")
self.api_url = f"https://api-inference.huggingface.co/models/{self.model_id}"
# Parameters for text generation
self.max_new_tokens = 540
self.temperature = 0.7
self.top_p = 0.9
# Add greeting message
self.greeting_message = """Hola, entiendo que estás buscando información y asesoramiento. Estoy aquí para ayudarte.
Para que esta conversación sea lo más cómoda para ti, ¿cómo prefieres que te llame o cuáles son tus pronombres?. Si prefieres mantener tu anonimato, puedes usar un nombre ficticio."""
# RAG components
self.embedding_model = SentenceTransformer(
"paraphrase-multilingual-MiniLM-L12-v2"
)
self.knowledge_base = self.load_knowledge_base(os.environ.get('PROTOCOLO'))
self.knowledge_embeddings = self.embed_knowledge_base()
def load_knowledge_base(self, knowledge_base):
"""Load the knowledge base from a provided string."""
try:
# Split the content into chunks (paragraphs)
chunks = [
chunk.strip() for chunk in self.knowledge_base.split("\n\n") if chunk.strip()
]
return chunks
except Exception as e:
print(f"Error processing knowledge base: {str(e)}")
return []
def embed_knowledge_base(self):
"""Create embeddings for the knowledge base chunks."""
if not self.knowledge_base:
return []
return self.embedding_model.encode(self.knowledge_base)
def retrieve_relevant_info(self, query, top_k=3, threshold=0.5):
"""Retrieve the most relevant information from the knowledge base."""
if not self.knowledge_base or not self.knowledge_embeddings.size:
return ""
# Encode the query
query_embedding = self.embedding_model.encode([query])[0]
# Calculate similarity
similarities = cosine_similarity([query_embedding], self.knowledge_embeddings)[
0
]
# Get top-k most similar chunks above threshold
relevant_indices = np.where(similarities > threshold)[0]
if len(relevant_indices) == 0:
return ""
top_indices = relevant_indices[
np.argsort(-similarities[relevant_indices])[:top_k]
]
# Combine the relevant information
relevant_info = "\n\n".join([self.knowledge_base[i] for i in top_indices])
return relevant_info
def extract_answer(self, response_or_json):
try:
# Handle different input types
if hasattr(response_or_json, "json"): # If it's a Response object
data = response_or_json.json()
elif isinstance(response_or_json, str): # If it's a JSON string
data = json.loads(response_or_json)
else: # If it's already a Python object
data = response_or_json
print("data-", data)
# Get the generated text from the first item
generated_text = data[0]["generated_text"]
pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>\s*(.*?)(?:<\|eot_id\|>|$)"
match = re.search(pattern, generated_text, re.DOTALL)
if match:
return match.group(1).strip()
else:
return generated_text # Return full text if pattern not found
except Exception as e:
return f"Error processing the input: {str(e)}"
def generate_response(self, user_input: str) -> str:
"""Generate a response using the Hugging Face Inference API and RAG."""
# Extract the most recent user query from the full context
query = user_input.split("Usuario: ")[-1].split("\nAsistente:")[0].strip()
# Retrieve relevant information from the knowledge base
relevant_info = self.retrieve_relevant_info(query)
tono = os.environ.get('TONO')
tono = f"""
{tono}
"""
# If relevant information is found, include it in the prompt
if relevant_info:
system_context = f"""
Eres un asistente a victimas de violencia laboral que sigue las siguientes instrucciones de tono al reponder las preguntas de los usuarios {tono}
Información relevante para responder a la consulta del usuario:
{relevant_info}
Utiliza la información proporcionada para dar una respuesta más precisa y útil, pero siempre manteniendo el tono y enfoque adecuados.
"""
else:
system_context = f"""
Eres un asistente a victimas de violencia laboral que sigue las siguientes instrucciones de tono al reponder las preguntas de los usuarios {tono}
"""
prompt = f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_context}<|eot_id|><|start_header_id|>user<|end_header_id|>
{user_input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
try:
# Prepare the payload for the API request
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
},
}
# Set up headers with authorization
headers = {"Authorization": f"Bearer {self.api_token}"}
# Make the API request
response = requests.post(self.api_url, headers=headers, json=payload)
# Check for successful response
if response.status_code == 200:
result = response.json()
print("result-", result)
return self.extract_answer(result)
else:
return f"Error: {response.status_code} - {response.text}"
except Exception as e:
return f"An error occurred: {str(e)}"
def chat_with_agent(message, history):
"""Handle user input and generate a response for the Gradio interface."""
if not agent.api_token:
return history + [
[
message,
"Error: Hugging Face API token is missing. Please set the HF_TOKEN environment variable.",
]
]
# Construct full history for context
full_context = ""
for h in history:
full_context += f"Usuario: {h[0]}\nAsistente: {h[1]}\n"
full_context += f"Usuario: {message}\nAsistente:"
response = agent.generate_response(full_context)
# Return updated history with new message pair
return history + [[message, response]]
# Initialize the agent
agent = MultilingualLlamaAgent()
# Create the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("""
# 🤖 Chatbot basado en Llama para atencion a victimas de acoso laboral.
## ¡Hola!
Gracias por contactarnos. Entendemos que has pasado por una situación incómoda y estamos acá para ofrecerte un espacio seguro y confiable para que puedas compartir tu experiencia.
Antes de empezar, queremos informarte que estás conversando con un chatbot con inteligencia artificial diseñado para ofrecerte información, recursos, apoyo y acompañamiento. Si en algún momento necesitas hablar con una persona real, te indicaremos cómo hacerlo.
Además, queremos asegurarte que toda la información que compartas con nosotros será tratada con la máxima **confidencialidad**. Nadie más tendrá acceso a esta información sin tu consentimiento expreso en esta primera etapa. La información que proporciones se utilizará únicamente para entender mejor lo que te ocurrió y buscar las mejores soluciones para ti. También queremos que sepas que nos guiamos por principios de derechos humanos para que este espacio esté libre de prejuicios, sesgos y estereotipos. Creemos que todas las personas merecen ser tratadas con respeto e igualdad, independientemente de su género, orientación sexual, origen étnico, color de piel, religión o cualquier otra condición. No toleramos ninguna forma de discriminación.
Aquí encontrarás información útil sobre la violencia laboral, tus derechos y los recursos disponibles para que puedas tomar las mejores decisiones de manera informada.
""")
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(height=500, value=[[None, agent.greeting_message]])
msg = gr.Textbox(placeholder="Escribe tu mensaje aquí...", show_label=False)
with gr.Row():
submit_btn = gr.Button("Enviar")
clear_btn = gr.Button("Limpiar chat")
with gr.Column(scale=1):
gr.Markdown("""
- Este chatbot esta entrenado sobre un modelo Llama.
- Sigue protocolos creados para atencion a victimas de acoso laboral por expertos en la materia.
""")
# Set up event handlers
submit_btn.click(chat_with_agent, [msg, chatbot], [chatbot])
msg.submit(chat_with_agent, [msg, chatbot], [chatbot])
clear_btn.click(
lambda: [[None, agent.greeting_message]], None, chatbot, queue=False
) # Modified to keep greeting
submit_btn.click(lambda: "", None, msg, queue=False)
msg.submit(lambda: "", None, msg, queue=False)
# Launch the app
if __name__ == "__main__":
demo.launch(share=True) |