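"""RAG helper: retrieves relevant chunks from a local FAISS vector store and
answers questions with a chat model served by TGI through its
OpenAI-compatible API."""
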
import logging
import os

from openai import OpenAI

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings


class RAG:
    # Fallback answer (Catalan): "Sorry, I could not answer your question."
    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."

    #vectorstore = "index-intfloat_multilingual-e5-small-500-100-CA-ES" # mixed
    #vectorstore = "vectorestore" # CA only
    vectorstore = "index-BAAI_bge-m3-1500-200-recursive_splitter-CA_ES_UE"

    def __init__(self, hf_token, embeddings_model, model_name):

        self.model_name = model_name
        self.hf_token = hf_token

        # Load the FAISS vector store using the given embeddings model (on CPU)
        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
        self.vectore_store = FAISS.load_local(self.vectorstore, embeddings, allow_dangerous_deserialization=True)

        logging.info("RAG loaded!")
    
    def get_context(self, instruction, number_of_contexts=2):

        # Retrieve the k most similar chunks (with scores) for the query
        documentos = self.vectore_store.similarity_search_with_score(instruction, k=number_of_contexts)

        return documentos
        
    def predict(self, instruction, sys_prompt, context, model_parameters):

        # Init the OpenAI-compatible client but point it at the TGI endpoint
        client = OpenAI(
            base_url=os.getenv("MODEL") + "/v1/",
            api_key=os.getenv("HF_TOKEN")
        )

        #sys_prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"
        #query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
        query = f"Context:\n{context}\n\nQuestion:\n{instruction}\n\n{sys_prompt}"
        logging.info(query)
        #query = f"{sys_prompt}\n\nQuestion:\n{instruction}\n\nContext:\n{context}"
        chat_completion = client.chat.completions.create(
            model="tgi",
            messages=[
                #{"role": "system", "content": sys_prompt },
                {"role": "user", "content": query}
            ],
            max_tokens=model_parameters['max_new_tokens'],  # TODO: map other parameters
            frequency_penalty=model_parameters['repetition_penalty'],  # this doesn't appear to do much, not a replacement for repetition penalty
            # presence_penalty=model_parameters['repetition_penalty'],
            # extra_body=model_parameters,
            stream=False,
            stop=["<|im_end|>", "<|end_header_id|>", "<|eot_id|>", "<|reserved_special_token"]
        )
        return chat_completion.choices[0].message.content

    
    def beautiful_context(self, docs):
        # Build three views of the retrieved documents:
        #   text_context   - plain concatenation of chunk texts (fed to the model)
        #   full_context   - chunk texts plus their title and URL metadata (for display)
        #   source_context - list of source URLs
        text_context = ""
        full_context = ""
        source_context = []
        for doc in docs:
            text_context += doc[0].page_content
            full_context += doc[0].page_content + "\n"
            full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
            full_context += doc[0].metadata["url"] + "\n\n"
            source_context.append(doc[0].metadata["url"])

        return text_context, full_context, source_context

    def get_response(self, prompt: str, sys_prompt: str, model_parameters: dict) -> tuple:
        try:
            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
            text_context, full_context, source = self.beautiful_context(docs)

            del model_parameters["NUM_CHUNKS"]

            response = self.predict(prompt, sys_prompt, text_context, model_parameters)

            # Always return (answer, full_context, sources) so callers can unpack safely
            if not response:
                return self.NO_ANSWER_MESSAGE, full_context, source

            return response, full_context, source
        except Exception as err:
            logging.error(err)
            return self.NO_ANSWER_MESSAGE, "", []
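

# ---------------------------------------------------------------------------
# Usage sketch (illustrative): a minimal example of how this class might be
# wired up. The embeddings model name and sample parameters below are
# assumptions, not values mandated by this module; the MODEL and HF_TOKEN
# environment variables must point at the TGI endpoint and its token.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    rag = RAG(
        hf_token=os.getenv("HF_TOKEN"),
        embeddings_model="BAAI/bge-m3",  # assumed: should match the model used to build the index
        model_name="tgi",
    )

    model_parameters = {
        "NUM_CHUNKS": 2,            # how many chunks to retrieve
        "max_new_tokens": 256,      # mapped to max_tokens in predict()
        "repetition_penalty": 1.0,  # mapped to frequency_penalty in predict()
    }

    answer, full_context, sources = rag.get_response(
        "What does the regulation say about data retention?",
        "Answer the question using only the provided context.",
        model_parameters,
    )
    print(answer)
    print(sources)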