File size: 4,091 Bytes
7b2ccf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6319a21
7b2ccf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import streamlit as st

import json
import boto3
from typing import Dict

from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler

st.set_page_config(layout="centered")
st.title("Arkham challenge")

st.markdown("""
## Propuesta de solución
Para resolver el desafío usé RAG (Retrieval Augmented Generation). La idea general de RAG 
es que el modelo recupera documentos contextuales de un conjunto de datos externo como parte 
de su ejecución, estos documentos contextuales se utilizan junto con la entrada original 
para producir la salida final.

![RAG](https://huggingface.co/blog/assets/12_ray_rag/rag_gif.gif "RAG")

Para este caso los documentos serán particiones del contrato (aunque fácilmente podemos agregar 
más documentos) que extraigo con OCR. Estas particiones nos ayudarán a agregar información contextual 
para que una LLM pueda contestar las preguntas que le hagamos.
""")

with open("credentials.json") as file:
    credentials = json.load(file)

sagemaker_client = boto3.client(
    service_name="sagemaker-runtime",
    region_name="us-east-1",
    aws_access_key_id=credentials["aws_access_key_id"],
    aws_secret_access_key=credentials["aws_secret_access_key"],
)

translate_client = boto3.client(
    service_name="translate",
    region_name="us-east-1",
    aws_access_key_id=credentials["aws_access_key_id"],
    aws_secret_access_key=credentials["aws_secret_access_key"],
)

st.markdown("## QA sobre el contrato")
pregunta = st.text_input(
    label="Escribe tu pregunta",
    value="¿Quién es el depositario?",
    help="""
    Escribe tu pregunta, por ejemplo:
    - ¿Cuáles son las obligaciones del arrendatario?
    - ¿Qué es FIRA?
    """,
)

embeddings = HuggingFaceEmbeddings(
    model_name="intfloat/multilingual-e5-small",
)
embeddings_db = FAISS.load_local("faiss_index", embeddings)
retriever = embeddings_db.as_retriever(search_kwargs={"k": 5})

prompt_template = """
Please answer the question below, using only the context below. 
Don't invent facts, if you can't provide a factual answer, say you don't know what the answer is.

question: {question}

context: {context}
"""
prompt = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

# Endpoint de SageMaker
model_kwargs = {
    "max_new_tokens": 512,
    "top_p": 0.8,
    "temperature": 0.8,
    "repetition_penalty": 1.0,
}


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps(
            # Template de prompt para Mistral, ver https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
            {"inputs": f"<s>[INST] {prompt} [/INST]", "parameters": {**model_kwargs}}
        )
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        splits = response_json[0]["generated_text"].split("[/INST] ")
        return splits[1]


content_handler = ContentHandler()

llm = SagemakerEndpoint(
    endpoint_name="mistral-langchain",
    model_kwargs=model_kwargs,
    content_handler=content_handler,
    client=sagemaker_client,
)

chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},
)

question = translate_client.translate_text(
    Text=pregunta,
    SourceLanguageCode="es",
    TargetLanguageCode="en",
    Settings={
        "Formality": "FORMAL",
    },
).get("TranslatedText")

answer = chain.run({"query": question})
respuesta = translate_client.translate_text(
    Text=answer,
    SourceLanguageCode="en",
    TargetLanguageCode="es",
    Settings={
        "Formality": "FORMAL",
    },
).get("TranslatedText")

st.write(respuesta)