# Streamlit app for the Arkham challenge: RAG-based QA over a contract.
import json
from typing import Dict

import boto3
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
st.set_page_config(layout="centered") | |
st.title("Arkham challenge") | |
st.markdown(""" | |
## Propuesta de solución | |
Para resolver el desafío usé RAG (Retrieval Augmented Generation). La idea general de RAG | |
es que el modelo recupera documentos contextuales de un conjunto de datos externo como parte | |
de su ejecución, estos documentos contextuales se utilizan junto con la entrada original | |
para producir la salida final. | |
 | |
Para este caso los documentos serán particiones del contrato (aunque fácilmente podemos agregar | |
más documentos) que extraigo con OCR. Estas particiones nos ayudarán a agregar información contextual | |
para que una LLM pueda contestar las preguntas que le hagamos. | |
""") | |
with open("credentials.json") as file: | |
credentials = json.load(file) | |
sagemaker_client = boto3.client( | |
service_name="sagemaker-runtime", | |
region_name="us-east-1", | |
aws_access_key_id=credentials["aws_access_key_id"], | |
aws_secret_access_key=credentials["aws_secret_access_key"], | |
) | |
translate_client = boto3.client( | |
service_name="translate", | |
region_name="us-east-1", | |
aws_access_key_id=credentials["aws_access_key_id"], | |
aws_secret_access_key=credentials["aws_secret_access_key"], | |
) | |
st.markdown("## QA sobre el contrato") | |
pregunta = st.text_input( | |
label="Escribe tu pregunta", | |
value="¿Quién es el depositario?", | |
help=""" | |
Escribe tu pregunta, por ejemplo: | |
- ¿Cuáles son las obligaciones del arrendatario? | |
- ¿Qué es FIRA? | |
""", | |
) | |
embeddings = HuggingFaceEmbeddings( | |
model_name="intfloat/multilingual-e5-small", | |
) | |
embeddings_db = FAISS.load_local("faiss_index", embeddings) | |
retriever = embeddings_db.as_retriever(search_kwargs={"k": 5}) | |
prompt_template = """ | |
Please answer the question below, using only the context below. | |
Don't invent facts, if you can't provide a factual answer, say you don't know what the answer is. | |
question: {question} | |
context: {context} | |
""" | |
prompt = PromptTemplate( | |
template=prompt_template, input_variables=["context", "question"] | |
) | |
# Endpoint de SageMaker | |
model_kwargs = { | |
"max_new_tokens": 512, | |
"top_p": 0.8, | |
"temperature": 0.8, | |
"repetition_penalty": 1.0, | |
} | |
class ContentHandler(LLMContentHandler): | |
content_type = "application/json" | |
accepts = "application/json" | |
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: | |
input_str = json.dumps( | |
# Template de prompt para Mistral, ver https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 | |
{"inputs": f"<s>[INST] {prompt} [/INST]", "parameters": {**model_kwargs}} | |
) | |
return input_str.encode("utf-8") | |
def transform_output(self, output: bytes) -> str: | |
response_json = json.loads(output.read().decode("utf-8")) | |
splits = response_json[0]["generated_text"].split("[/INST] ") | |
return splits[1] | |
content_handler = ContentHandler() | |
llm = SagemakerEndpoint( | |
endpoint_name="mistral-langchain", | |
model_kwargs=model_kwargs, | |
content_handler=content_handler, | |
client=sagemaker_client, | |
) | |
chain = RetrievalQA.from_chain_type( | |
llm=llm, | |
chain_type="stuff", | |
retriever=retriever, | |
chain_type_kwargs={"prompt": prompt}, | |
) | |
question = translate_client.translate_text( | |
Text=pregunta, | |
SourceLanguageCode="es", | |
TargetLanguageCode="en", | |
Settings={ | |
"Formality": "FORMAL", | |
}, | |
).get("TranslatedText") | |
answer = chain.run({"query": question}) | |
respuesta = translate_client.translate_text( | |
Text=answer, | |
SourceLanguageCode="en", | |
TargetLanguageCode="es", | |
Settings={ | |
"Formality": "FORMAL", | |
}, | |
).get("TranslatedText") | |
st.write(respuesta) |