crodri committed
Commit b163a23 · verified · 1 Parent(s): b87ec46

Upload 9 files

Files changed (9)
  1. README.md +6 -5
  2. app.py +261 -0
  3. faiss_index.bin/index.faiss +3 -0
  4. faiss_index.bin/index.pkl +3 -0
  5. handler.py +14 -0
  6. input_reader.py +22 -0
  7. rag.py +162 -0
  8. requirements.txt +14 -0
  9. utils.py +33 -0
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
-title: RAG4RENFE
-emoji:
-colorFrom: pink
-colorTo: green
+title: EADOP RAG
+emoji: 💻
+colorFrom: indigo
+colorTo: yellow
 sdk: gradio
-sdk_version: 5.20.1
+sdk_version: 4.24.0
 app_file: app.py
 pinned: false
+license: apache-2.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,261 @@
+import os
+import gradio as gr
+from gradio.components import Textbox, Button, Slider, Checkbox
+from AinaTheme import theme
+from urllib.error import HTTPError
+
+from rag import RAG
+from utils import setup
+
+MAX_NEW_TOKENS = 700
+SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
+
+setup()
+
+
+rag = RAG(
+    hf_token=os.getenv("HF_TOKEN"),
+    embeddings_model=os.getenv("EMBEDDINGS"),
+    model_name=os.getenv("MODEL"),
+    rerank_model=os.getenv("RERANK_MODEL"),
+    rerank_number_contexts=int(os.getenv("RERANK_NUMBER_CONTEXTS"))
+)
+
+
+def generate(prompt, model_parameters):
+    try:
+        output, context, source = rag.get_response(prompt, model_parameters)
+        return output, context, source
+    except HTTPError as err:
+        if err.code == 400:
+            gr.Warning(
+                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
+            )
+    except Exception:
+        gr.Warning(
+            "The inference endpoint is not available right now. Please try again later."
+        )
+    # Both error paths fall through to a triple so the caller can unpack safely.
+    return None, None, None
+
+
+def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
+    if input_.strip() == "":
+        gr.Warning("It is not possible to run inference on an empty input.")
+        return None, None, None
+
+    model_parameters = {
+        "NUM_CHUNKS": num_chunks,
+        "max_new_tokens": max_new_tokens,
+        "repetition_penalty": repetition_penalty,
+        "top_k": top_k,
+        "top_p": top_p,
+        "do_sample": do_sample,
+        "temperature": temperature
+    }
+
+    output, context, source = generate(input_, model_parameters)
+    if output is None:
+        # Inference failed; generate() has already shown a warning.
+        return None, None, None
+
+    sources_markup = ""
+    for url in source:
+        sources_markup += f'<a href="{url}" target="_blank">{url}</a><br>'
+
+    return output, sources_markup, context
+
+
+def change_interactive(text):
+    if len(text) == 0:
+        return gr.update(interactive=True), gr.update(interactive=False)
+    return gr.update(interactive=True), gr.update(interactive=True)
+
+
+def clear():
+    return (
+        None,
+        None,
+        None,
+        None,
+        gr.Slider(value=2.0),
+        gr.Slider(value=MAX_NEW_TOKENS),
+        gr.Slider(value=1.0),
+        gr.Slider(value=50),
+        gr.Slider(value=0.99),
+        gr.Checkbox(value=False),
+        gr.Slider(value=0.35),
+    )
+
+
+def gradio_app():
+    with gr.Blocks(theme=theme) as demo:
+        with gr.Row():
+            with gr.Column(scale=0.1):
+                gr.Image("rag_image.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button=False, show_share_button=False)
+            with gr.Column():
+                gr.Markdown(
+                    """# Retrieval-Augmented Generation demo for legal documents
+                    🔍 **Retrieval-Augmented Generation** (RAG) is an AI technique that lets you query a document repository with natural-language questions; it combines advanced information-retrieval techniques with generative models to draft an answer using only the information found in the repository's documents.
+
+                    🎯 **Goal:** This is a demonstrator for the legislation in force published in the Diari Oficial de la Generalitat de Catalunya, held in the EADOP (Entitat Autònoma del Diari Oficial i de Publicacions) repository. This version searches close to 2,000 documents in Catalan and generates the answer with the Salamandra-7b-aligned-EADOP model, i.e. BSC-LT/salamandra-7b-instruct aligned with the alinia/EADOP-RAG-out-of-domain dataset.
+
+                    ⚠️ **Warnings:** This version is experimental. The content generated by this model is not supervised and may be incorrect; please keep this in mind when exploring this resource. The inference endpoint behind this development demo does not run continuously. If you want to run tests, contact us at Langtech.
+
+                    👀 **More information in the reports:** [RAG](https://drive.google.com/file/d/11MgXQXAxfhkqbrx8syrKtmBrNP_6Qhx9/view?usp=sharing) and [Alignment](https://drive.google.com/file/d/1VUqHKO-gDmgMozK-Al83a2kh4Fr70pHh/view?usp=sharing), as PDFs (in English).
+                    """
+                )
+        with gr.Row(equal_height=True):
+            with gr.Column(variant="panel"):
+                input_ = Textbox(
+                    lines=11,
+                    label="Input",
+                    placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
+                )
+                with gr.Row(variant="panel"):
+                    clear_btn = Button(
+                        "Clear",
+                    )
+                    submit_btn = Button("Submit", variant="primary", interactive=False)
+
+                with gr.Row(variant="panel"):
+                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
+                        num_chunks = Slider(
+                            minimum=1,
+                            maximum=6,
+                            step=1,
+                            value=2,
+                            label="Number of chunks"
+                        )
+                        max_new_tokens = Slider(
+                            minimum=50,
+                            maximum=2000,
+                            step=1,
+                            value=MAX_NEW_TOKENS,
+                            label="Max tokens"
+                        )
+                        repetition_penalty = Slider(
+                            minimum=0.1,
+                            maximum=2.0,
+                            step=0.1,
+                            value=1.0,
+                            label="Repetition penalty"
+                        )
+                        top_k = Slider(
+                            minimum=1,
+                            maximum=100,
+                            step=1,
+                            value=50,
+                            label="Top k"
+                        )
+                        top_p = Slider(
+                            minimum=0.01,
+                            maximum=0.99,
+                            value=0.99,
+                            label="Top p"
+                        )
+                        do_sample = Checkbox(
+                            value=False,
+                            label="Do sample"
+                        )
+                        temperature = Slider(
+                            minimum=0.1,
+                            maximum=1,
+                            value=0.35,
+                            label="Temperature"
+                        )
+
+                parameters_components = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
+
+            with gr.Column(variant="panel"):
+                output = Textbox(
+                    lines=10,
+                    label="Output",
+                    interactive=False,
+                    show_copy_button=True
+                )
+                with gr.Accordion("Sources and context:", open=False):
+                    source_context = gr.Markdown(
+                        label="Sources",
+                        show_label=False,
+                    )
+                with gr.Accordion("See full context evaluation:", open=False):
+                    context_evaluation = gr.Markdown(
+                        label="Full context",
+                        show_label=False,
+                    )
+
+        input_.change(
+            fn=change_interactive,
+            inputs=[input_],
+            outputs=[clear_btn, submit_btn],
+            api_name=False,
+        )
+
+        input_.change(
+            fn=None,
+            inputs=[input_],
+            api_name=False,
+            js="""(i, m) => {
+                document.getElementById('inputlenght').textContent = i.length + ' '
+                document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
+            }""",
+        )
+
+        clear_btn.click(
+            fn=clear,
+            inputs=[],
+            outputs=[input_, output, source_context, context_evaluation] + parameters_components,
+            queue=False,
+            api_name=False
+        )
+
+        submit_btn.click(
+            fn=submit_input,
+            inputs=[input_] + parameters_components,
+            outputs=[output, source_context, context_evaluation],
+            api_name="get-results"
+        )
+
+        with gr.Row():
+            with gr.Column(scale=0.5):
+                gr.Examples(
+                    examples=[
+                        ["""Què és l'EADOP (Entitat Autònoma del Diari Oficial i de Publicacions)?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+                gr.Examples(
+                    examples=[
+                        ["""Com es pot inscriure una persona al Registre de catalans i catalanes residents a l'exterior?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+                gr.Examples(
+                    examples=[
+                        ["""Quina és la finalitat del Servei Meteorològic de Catalunya?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+
+    demo.launch(show_api=True)
+
+
+if __name__ == "__main__":
+    gradio_app()
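Note: the `api_name="get-results"` registration above exposes the whole pipeline over the Space's API. A minimal client sketch, assuming a running Space (the URL below is a placeholder; the argument order mirrors `[input_] + parameters_components`):

```python
# Hedged sketch: call the Space's "get-results" endpoint with gradio_client.
from gradio_client import Client

client = Client("https://example-org-eadop-rag.hf.space")  # hypothetical URL
output, sources, context = client.predict(
    "Quina és la finalitat del Servei Meteorològic de Catalunya?",  # input_
    2,      # num_chunks
    700,    # max_new_tokens
    1.0,    # repetition_penalty
    50,     # top_k
    0.99,   # top_p
    False,  # do_sample
    0.35,   # temperature
    api_name="/get-results",
)
print(output)
```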
faiss_index.bin/index.faiss ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c67af2904227e08af1cf2fca5dbb61bc2f7b3322651d0cebefe59f519dcb34e
+size 3793965
faiss_index.bin/index.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d74bb71eeb7484e51311f91c82bd27419bfc8f5ee65d383419fbae9e9538c87a
+size 2921777
handler.py ADDED
@@ -0,0 +1,14 @@
+import json
+
+
+class ContentHandler:
+    """Serializes prompts for, and deserializes responses from, a JSON inference endpoint."""
+
+    content_type = "application/json"
+    accepts = "application/json"
+
+    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
+        input_str = json.dumps({'inputs': prompt, 'parameters': model_kwargs})
+        return input_str.encode('utf-8')
+
+    def transform_output(self, output) -> str:
+        # `output` is a file-like response body, not raw bytes, hence the .read().
+        response_json = json.loads(output.read().decode("utf-8"))
+        return response_json[0]["generated_text"]
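`ContentHandler` is a small JSON adapter, so its round trip can be checked locally; a sketch, with the endpoint response faked via `BytesIO` (the strings are made up):

```python
# Illustrative round trip for ContentHandler; no real endpoint involved.
import io
from handler import ContentHandler

handler = ContentHandler()
payload = handler.transform_input("Hola!", {"max_new_tokens": 50})
# payload == b'{"inputs": "Hola!", "parameters": {"max_new_tokens": 50}}'

fake_body = io.BytesIO(b'[{"generated_text": "Bon dia!"}]')
print(handler.transform_output(fake_body))  # -> Bon dia!
```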
input_reader.py ADDED
@@ -0,0 +1,22 @@
+from typing import List
+
+from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
+from llama_index.core.readers import SimpleDirectoryReader
+from llama_index.core.schema import Document
+from llama_index.core import Settings
+
+
+class InputReader:
+    def __init__(self, input_dir: str) -> None:
+        self.reader = SimpleDirectoryReader(input_dir=input_dir)
+
+    def parse_documents(
+        self,
+        show_progress: bool = True,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    ) -> List[Document]:
+        Settings.chunk_size = chunk_size
+        Settings.chunk_overlap = chunk_overlap
+        documents = self.reader.load_data(show_progress=show_progress)
+        return documents
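`InputReader` wraps llama-index's `SimpleDirectoryReader` and sets the global chunking settings before loading. A usage sketch, assuming a local `data/` folder of source documents (the 1500/200 values mirror the chunking encoded in the vectorstore name used by `rag.py`):

```python
# Hedged sketch: load and chunk a hypothetical local document folder.
from input_reader import InputReader

reader = InputReader(input_dir="data/")  # hypothetical folder
documents = reader.parse_documents(chunk_size=1500, chunk_overlap=200)
print(f"Loaded {len(documents)} documents")
```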
rag.py ADDED
@@ -0,0 +1,162 @@
+import logging
+import os
+import requests
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+from openai import OpenAI
+
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+
+class RAG:
+    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."
+
+    # vectorstore = "index-intfloat_multilingual-e5-small-500-100-CA-ES"  # mixed
+    # vectorstore = "vectorestore"  # CA only
+    vectorstore = "index-BAAI_bge-m3-1500-200-recursive_splitter-CA_ES_UE"
+
+    def __init__(self, hf_token, embeddings_model, model_name, rerank_model, rerank_number_contexts):
+        self.model_name = model_name
+        self.hf_token = hf_token
+        self.rerank_model = rerank_model
+        self.rerank_number_contexts = rerank_number_contexts
+
+        # Load the vector store.
+        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
+        self.vector_store = FAISS.load_local(self.vectorstore, embeddings, allow_dangerous_deserialization=True)
+
+        logging.info("RAG loaded!")
+
+    def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
+        """Rerank the contexts based on their relevance to the given instruction."""
+        tokenizer = AutoTokenizer.from_pretrained(self.rerank_model)
+        model = AutoModelForSequenceClassification.from_pretrained(self.rerank_model)
+
+        def get_score(query, passage):
+            """Calculate the relevance score of a passage with respect to a query."""
+            inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            logits = outputs.logits
+            score = logits.view(-1, ).float()
+            return score
+
+        scores = [get_score(instruction, c[0].page_content) for c in contexts]
+        combined = list(zip(contexts, scores))
+        sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
+        sorted_texts, _ = zip(*sorted_combined)
+
+        return sorted_texts[:number_of_contexts]
+
+    def get_context(self, instruction, number_of_contexts=2):
+        """Retrieve the most relevant contexts for a given instruction."""
+        documentos = self.vector_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)
+        documentos = self.rerank_contexts(instruction, documentos, number_of_contexts=number_of_contexts)
+
+        logging.info("Reranked documents")
+        return documentos
+
+    def predict_dolly(self, instruction, context, model_parameters):
+        api_key = os.getenv("HF_TOKEN")
+
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json"
+        }
+
+        query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n "
+
+        payload = {
+            "inputs": query,
+            "parameters": model_parameters
+        }
+
+        response = requests.post(self.model_name, headers=headers, json=payload)
+
+        # Keep only the text after the final "###" marker, dropping the " Answer\n " prefix.
+        return response.json()[0]["generated_text"].split("###")[-1][8:]
+
+    def predict_completion(self, instruction, context, model_parameters):
+        client = OpenAI(
+            base_url=os.getenv("MODEL"),
+            api_key=os.getenv("HF_TOKEN")
+        )
+
+        # Prepend the retrieved context so the model answers from the repository documents.
+        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
+
+        chat_completion = client.chat.completions.create(
+            model="tgi",
+            messages=[
+                {"role": "user", "content": query}
+            ],
+            temperature=model_parameters["temperature"],
+            max_tokens=model_parameters["max_new_tokens"],
+            stream=False,
+            stop=["<|im_end|>"],
+            extra_body={
+                # Map the HF-style repetition_penalty (around 1.0) onto an
+                # OpenAI-style presence_penalty (around -1.0).
+                "presence_penalty": model_parameters["repetition_penalty"] - 2,
+                "do_sample": False
+            }
+        )
+
+        return chat_completion.choices[0].message.content
+
+    def beautiful_context(self, docs):
+        text_context = ""
+        full_context = ""
+        source_context = []
+        for doc in docs:
+            text_context += doc[0].page_content
+            full_context += doc[0].page_content + "\n"
+            full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
+            full_context += doc[0].metadata["url"] + "\n\n"
+            source_context.append(doc[0].metadata["url"])
+
+        return text_context, full_context, source_context
+
+    def get_response(self, prompt: str, model_parameters: dict):
+        try:
+            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
+            text_context, full_context, source = self.beautiful_context(docs)
+
+            del model_parameters["NUM_CHUNKS"]
+
+            response = self.predict_completion(prompt, text_context, model_parameters)
+
+            if not response:
+                return self.NO_ANSWER_MESSAGE, full_context, source
+
+            return response, full_context, source
+        except Exception as err:
+            logging.error(err)
+            raise
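`RAG.get_response` retrieves `rerank_number_contexts` candidates from FAISS, reranks them with the cross-encoder, keeps `NUM_CHUNKS` of them, and sends the concatenated context to the completion endpoint. Driving it outside Gradio looks roughly like this; the model names and numeric values are assumptions mirroring the app's defaults, not confirmed configuration:

```python
# Hedged sketch of using RAG directly; env vars and model names are assumptions.
import os
from rag import RAG

rag = RAG(
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model="BAAI/bge-m3",           # assumed: must match the FAISS index
    model_name=os.getenv("MODEL"),            # base URL of the completion endpoint
    rerank_model="BAAI/bge-reranker-v2-m3",   # assumed cross-encoder reranker
    rerank_number_contexts=16,                # candidates fetched before reranking
)

params = {
    "NUM_CHUNKS": 2,
    "max_new_tokens": 700,
    "repetition_penalty": 1.0,
    "top_k": 50,
    "top_p": 0.99,
    "do_sample": False,
    "temperature": 0.35,
}
answer, full_context, sources = rag.get_response("Què és l'EADOP?", params)
print(answer)
```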
requirements.txt ADDED
@@ -0,0 +1,14 @@
+gradio
+huggingface-hub
+python-dotenv
+llama-index
+llama-index-embeddings-huggingface
+llama-index-llms-huggingface
+sentence-transformers
+langchain
+faiss-cpu
+aina-gradio-theme
+
+langchain-community
+langchain-core
+openai
utils.py ADDED
@@ -0,0 +1,33 @@
+import logging
+import warnings
+
+from dotenv import load_dotenv
+
+from rag import RAG
+
+USER_INPUT = 100
+
+
+def setup():
+    load_dotenv()
+    warnings.filterwarnings("ignore")
+
+    logging.addLevelName(USER_INPUT, "USER_INPUT")
+    logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.INFO)
+
+
+def interactive(model: RAG):
+    logging.info("Write `exit` when you want to stop the model.")
+    print()
+
+    query = ""
+    while query.lower() != "exit":
+        logging.log(USER_INPUT, "Write the query or `exit`:")
+        query = input()
+
+        if query.lower() == "exit":
+            break
+
+        # RAG.get_response expects the model-parameters dict that app.py builds;
+        # these values mirror the UI defaults (an assumption for CLI use).
+        response, _, _ = model.get_response(query, {
+            "NUM_CHUNKS": 2,
+            "max_new_tokens": 700,
+            "repetition_penalty": 1.0,
+            "top_k": 50,
+            "top_p": 0.99,
+            "do_sample": False,
+            "temperature": 0.35,
+        })
+        print(response, end="\n\n")