ankush13r committed
Commit 929fe0b · verified · 1 Parent(s): 9c18562

Update rag.py

Files changed (1)
  1. rag.py (+5, -43)
rag.py CHANGED
@@ -32,37 +32,7 @@ class RAG:
 
         return documentos
 
-    def predict(self, instruction, sys_prompt, context, model_parameters):
-
-        from openai import OpenAI
-
-        # init the client but point it to TGI
-        client = OpenAI(
-            base_url=os.getenv("MODEL")+ "/v1/",
-            api_key=os.getenv("HF_TOKEN")
-        )
-
-        #sys_prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"
-        #query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
-        query = f"Context:\n{context}\n\nQuestion:\n{instruction}\n\n{sys_prompt}"
-        print(query)
-        #query = f"{sys_prompt}\n\nQuestion:\n{instruction}\n\nContext:\n{context}"
-        chat_completion = client.chat.completions.create(
-            model="tgi",
-            messages=[
-                #{"role": "system", "content": sys_prompt },
-                {"role": "user", "content": query}
-            ],
-            max_tokens=model_parameters['max_new_tokens'],  # TODO: map other parameters
-            frequency_penalty=model_parameters['repetition_penalty'],  # this doesn't appear to do much, not a replacement for repetition penalty
-            # presence_penalty=model_parameters['repetition_penalty'],
-            # extra_body=model_parameters,
-            stream=False,
-            stop=["<|im_end|>", "<|end_header_id|>", "<|eot_id|>", "<|reserved_special_token"]
-        )
-        return(chat_completion.choices[0].message.content)
-
-
+
     def beautiful_context(self, docs):
 
         text_context = ""
@@ -71,25 +41,17 @@ class RAG:
         source_context = []
         for doc in docs:
             text_context += doc[0].page_content
-            full_context += doc[0].page_content + "\n"
             full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
             full_context += doc[0].metadata["url"] + "\n\n"
+            full_context += doc[0].page_content + "\n"
             source_context.append(doc[0].metadata["url"])
 
         return text_context, full_context, source_context
 
-    def get_response(self, prompt: str, sys_prompt: str, model_parameters: dict) -> str:
+    def get_context(self, prompt: str, model_parameters: dict) -> str:
         try:
             docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
-            text_context, full_context, source = self.beautiful_context(docs)
-
-            del model_parameters["NUM_CHUNKS"]
-
-            response = self.predict(prompt, sys_prompt, text_context, model_parameters)
-
-            if not response:
-                return self.NO_ANSWER_MESSAGE
-
-            return response, full_context, source
+            return self.beautiful_context(docs)
         except Exception as err:
             print(err)
+            return None, None, None
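
For reference, a minimal sketch (not part of this commit) of how a caller might drive the flow after this refactor: RAG.get_context now only retrieves and formats documents, returning (None, None, None) on failure, so the OpenAI/TGI completion call that lived in the removed predict() has to happen at the call site. The client setup and request parameters below mirror the deleted code; the function name answer_with_context and the rag argument are hypothetical.

import os
from openai import OpenAI

def answer_with_context(rag, instruction, sys_prompt, model_parameters):
    # Retrieve and format context; the refactored get_context returns
    # (None, None, None) when retrieval raises.
    text_context, full_context, sources = rag.get_context(instruction, model_parameters)
    if text_context is None:
        return None, None, None

    # Point the OpenAI client at the same TGI endpoint the deleted predict() used.
    client = OpenAI(
        base_url=os.getenv("MODEL") + "/v1/",
        api_key=os.getenv("HF_TOKEN"),
    )

    # Same prompt layout and sampling parameters as the removed predict().
    query = f"Context:\n{text_context}\n\nQuestion:\n{instruction}\n\n{sys_prompt}"
    completion = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": query}],
        max_tokens=model_parameters["max_new_tokens"],
        frequency_penalty=model_parameters["repetition_penalty"],
        stream=False,
        stop=["<|im_end|>", "<|end_header_id|>", "<|eot_id|>", "<|reserved_special_token"],
    )
    return completion.choices[0].message.content, full_context, sources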