Jens Grivolla committed on
Commit
f005840
·
1 Parent(s): 99162ec

make sys prompt configurable

Browse files
Files changed (2) hide show
  1. app.py +13 -7
  2. rag.py +5 -4
app.py CHANGED
@@ -22,9 +22,9 @@ rag = RAG(
22
  )
23
 
24
 
25
- def generate(prompt, model_parameters):
26
  try:
27
- output, context, source = rag.get_response(prompt, model_parameters)
28
  return output, context, source
29
  except HTTPError as err:
30
  if err.code == 400:
@@ -37,7 +37,7 @@ def generate(prompt, model_parameters):
37
  )
38
 
39
 
40
- def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
41
  if input_.strip() == "":
42
  gr.Warning("Not possible to inference an empty input")
43
  return None
@@ -53,7 +53,7 @@ def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k,
53
  "temperature": temperature
54
  }
55
 
56
- output, context, source = generate(input_, model_parameters)
57
  sources_markup = ""
58
 
59
  for url in source:
@@ -112,6 +112,12 @@ def gradio_app():
112
  placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
113
  # value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
114
  )
 
 
 
 
 
 
115
  with gr.Row(variant="panel"):
116
  clear_btn = Button(
117
  "Clear",
@@ -201,8 +207,8 @@ def gradio_app():
201
  inputs=[input_],
202
  api_name=False,
203
  js="""(i, m) => {
204
- document.getElementById('inputlenght').textContent = i.length + ' '
205
- document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
206
  }""",
207
  )
208
 
@@ -216,7 +222,7 @@ def gradio_app():
216
 
217
  submit_btn.click(
218
  fn=submit_input,
219
- inputs=[input_]+ parameters_compontents,
220
  outputs=[output, source_context, context_evaluation],
221
  api_name="get-results"
222
  )
 
22
  )
23
 
24
 
25
+ def generate(prompt, sys_prompt, model_parameters):
26
  try:
27
+ output, context, source = rag.get_response(prompt, sys_prompt, model_parameters)
28
  return output, context, source
29
  except HTTPError as err:
30
  if err.code == 400:
 
37
  )
38
 
39
 
40
+ def submit_input(input_, sysprompt_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
41
  if input_.strip() == "":
42
  gr.Warning("Not possible to inference an empty input")
43
  return None
 
53
  "temperature": temperature
54
  }
55
 
56
+ output, context, source = generate(input_, sysprompt_, model_parameters)
57
  sources_markup = ""
58
 
59
  for url in source:
 
112
  placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
113
  # value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
114
  )
115
+ sysprompt_ = Textbox(
116
+ lines=2,
117
+ label="System",
118
+ placeholder="Below is a question that you should answer based on the given context. Write a response that answers the question using only information provided in the context.",
119
+ value = "Below is a question that you should answer based on the given context. Write a response that answers the question using only information provided in the context."
120
+ )
121
  with gr.Row(variant="panel"):
122
  clear_btn = Button(
123
  "Clear",
 
207
  inputs=[input_],
208
  api_name=False,
209
  js="""(i, m) => {
210
+ document.getElementById('inputlength').textContent = i.length + ' '
211
+ document.getElementById('inputlength').style.color = (i.length > m) ? "#ef4444" : "";
212
  }""",
213
  )
214
 
 
222
 
223
  submit_btn.click(
224
  fn=submit_input,
225
+ inputs=[input_, sysprompt_]+ parameters_compontents,
226
  outputs=[output, source_context, context_evaluation],
227
  api_name="get-results"
228
  )
rag.py CHANGED
@@ -32,7 +32,7 @@ class RAG:
32
 
33
  return documentos
34
 
35
- def predict(self, instruction, context, model_parameters):
36
 
37
  from openai import OpenAI
38
 
@@ -42,9 +42,10 @@ class RAG:
42
  api_key=os.getenv("HF_TOKEN")
43
  )
44
 
45
- sys_prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"
46
  #query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
47
  query = f"{sys_prompt}\n\nContext:\n{context}\n\nQuestion:\n{instruction}"
 
48
  #query = f"{sys_prompt}\n\nQuestion:\n{instruction}\n\nContext:\n{context}"
49
  chat_completion = client.chat.completions.create(
50
  model="tgi",
@@ -77,14 +78,14 @@ class RAG:
77
 
78
  return text_context, full_context, source_context
79
 
80
- def get_response(self, prompt: str, model_parameters: dict) -> str:
81
  try:
82
  docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
83
  text_context, full_context, source = self.beautiful_context(docs)
84
 
85
  del model_parameters["NUM_CHUNKS"]
86
 
87
- response = self.predict(prompt, text_context, model_parameters)
88
 
89
  if not response:
90
  return self.NO_ANSWER_MESSAGE
 
32
 
33
  return documentos
34
 
35
+ def predict(self, instruction, sys_prompt, context, model_parameters):
36
 
37
  from openai import OpenAI
38
 
 
42
  api_key=os.getenv("HF_TOKEN")
43
  )
44
 
45
+ #sys_prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"
46
  #query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
47
  query = f"{sys_prompt}\n\nContext:\n{context}\n\nQuestion:\n{instruction}"
48
+ print(query)
49
  #query = f"{sys_prompt}\n\nQuestion:\n{instruction}\n\nContext:\n{context}"
50
  chat_completion = client.chat.completions.create(
51
  model="tgi",
 
78
 
79
  return text_context, full_context, source_context
80
 
81
+ def get_response(self, prompt: str, sys_prompt: str, model_parameters: dict) -> str:
82
  try:
83
  docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
84
  text_context, full_context, source = self.beautiful_context(docs)
85
 
86
  del model_parameters["NUM_CHUNKS"]
87
 
88
+ response = self.predict(prompt, sys_prompt, text_context, model_parameters)
89
 
90
  if not response:
91
  return self.NO_ANSWER_MESSAGE