edouardfoussier commited on
Commit
8853ea0
·
1 Parent(s): f095082

added Model dropdown in sidebar

Browse files
Files changed (2) hide show
  1. app.py +15 -4
  2. rag/synth.py +16 -4
app.py CHANGED
@@ -32,7 +32,7 @@ def add_user(user_msg: str, history: list[tuple]) -> tuple[str, list[tuple]]:
32
 
33
 
34
  # ---------- Chat step 2: stream assistant answer ----------
35
- def bot(history: list[tuple], api_key: str, top_k: int):
36
  """
37
  Yields (history, sources_markdown) while streaming.
38
  """
@@ -63,7 +63,7 @@ def bot(history: list[tuple], api_key: str, top_k: int):
63
  # Streaming LLM
64
  acc = ""
65
  try:
66
- for chunk in synth_answer_stream(user_msg, hits[:k]):
67
  acc += chunk or ""
68
  step_hist = deepcopy(history)
69
  step_hist[-1] = (user_msg, acc)
@@ -109,6 +109,17 @@ with gr.Blocks(theme="soft", fill_height=True) as demo:
109
  type="password",
110
  placeholder="sk-… (optional if set in env)"
111
  )
 
 
 
 
 
 
 
 
 
 
 
112
  topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
113
  # you can wire this later; not used now
114
 
@@ -144,7 +155,7 @@ with gr.Blocks(theme="soft", fill_height=True) as demo:
144
  send_click = send.click(add_user, [msg, state], [msg, state])
145
  send_click.then(
146
  bot,
147
- [state, api_key, topk],
148
  [chat, sources],
149
  show_progress="minimal",
150
  ).then(lambda h: h, chat, state)
@@ -152,7 +163,7 @@ with gr.Blocks(theme="soft", fill_height=True) as demo:
152
  msg_submit = msg.submit(add_user, [msg, state], [msg, state])
153
  msg_submit.then(
154
  bot,
155
- [state, api_key, topk],
156
  [chat, sources],
157
  show_progress="minimal",
158
  ).then(lambda h: h, chat, state)
 
32
 
33
 
34
  # ---------- Chat step 2: stream assistant answer ----------
35
+ def bot(history: list[tuple], api_key: str, top_k: int, model_name: str):
36
  """
37
  Yields (history, sources_markdown) while streaming.
38
  """
 
63
  # Streaming LLM
64
  acc = ""
65
  try:
66
+ for chunk in synth_answer_stream(user_msg, hits[:k], model=model_name):
67
  acc += chunk or ""
68
  step_hist = deepcopy(history)
69
  step_hist[-1] = (user_msg, acc)
 
109
  type="password",
110
  placeholder="sk-… (optional if set in env)"
111
  )
112
+ # let user choose the OpenAI model
113
+ model = gr.Dropdown(
114
+ label="⚙️ OpenAI model",
115
+ choices=[
116
+ "gpt-4o-mini",
117
+ "gpt-4o",
118
+ "gpt-4.1-mini",
119
+ "gpt-3.5-turbo"
120
+ ],
121
+ value="gpt-4o-mini"
122
+ )
123
  topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
124
  # you can wire this later; not used now
125
 
 
155
  send_click = send.click(add_user, [msg, state], [msg, state])
156
  send_click.then(
157
  bot,
158
+ [state, api_key, topk, model],
159
  [chat, sources],
160
  show_progress="minimal",
161
  ).then(lambda h: h, chat, state)
 
163
  msg_submit = msg.submit(add_user, [msg, state], [msg, state])
164
  msg_submit.then(
165
  bot,
166
+ [state, api_key, topk, model],
167
  [chat, sources],
168
  show_progress="minimal",
169
  ).then(lambda h: h, chat, state)
rag/synth.py CHANGED
@@ -8,6 +8,14 @@ from datetime import date
8
  LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
9
  LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
10
 
 
 
 
 
 
 
 
 
11
  def _build_prompt(query, passages):
12
 
13
  # Construire des blocs numérotés et balisés
@@ -26,12 +34,15 @@ def _build_prompt(query, passages):
26
 
27
  context = "\n\n".join(blocks)
28
  query = utf8_safe(query)
29
- today = date.today().strftime("%d %B %Y") # e.g. "27 août 2025"
30
-
31
 
32
  return (
33
  "Tu es un assistant RH chargé de répondre à des questions dans le domaine des ressources humaines en t'appuyant sur les sources fournies.\n"
34
- "La date d'aujourd'hui est : {today}.\n\n"
 
 
 
 
35
  "Consignes :\n"
36
  "- Réponds de manière factuelle, concise et polie (vouvoiement).\n"
37
  "- Quand tu affirmes un fait, cite tes sources en fin de phrase avec le format [1], [2]… en te basant sur l'index de ces sources (ex: [1] est la source 1, [2] est la source 2, etc.)\n\n"
@@ -43,8 +54,9 @@ def _build_prompt(query, passages):
43
  "Réponse:"
44
  )
45
 
46
- def synth_answer_stream(query, passages):
47
  client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
 
48
  prompt = utf8_safe(_build_prompt(query, passages))
49
 
50
  stream = client.chat.completions.create(
 
8
  LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
9
  LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
10
 
11
+ # _MONTHS_FR = {
12
+ # 1:"janvier", 2:"février", 3:"mars", 4:"avril", 5:"mai", 6:"juin",
13
+ # 7:"juillet", 8:"août", 9:"septembre", 10:"octobre", 11:"novembre", 12:"décembre"
14
+ # }
15
+ # def today_fr():
16
+ # d = date.today()
17
+ # return f"{d.day} {_MONTHS_FR[d.month]} {d.year}"
18
+
19
  def _build_prompt(query, passages):
20
 
21
  # Construire des blocs numérotés et balisés
 
34
 
35
  context = "\n\n".join(blocks)
36
  query = utf8_safe(query)
37
+ today = date.today().strftime("%d %m %Y") # e.g. "27 août 2025"
 
38
 
39
  return (
40
  "Tu es un assistant RH chargé de répondre à des questions dans le domaine des ressources humaines en t'appuyant sur les sources fournies.\n"
41
+ f"La date d'aujourd'hui est : {today} (au format AAAA-MM-JJ).\n\n"
42
+ "⚠️ Consigne temporelle : Les textes sources peuvent avoir été rédigés avant aujourd'hui "
43
+ "et mentionner des changements à venir. Interprète ces formulations en fonction de la date actuelle. "
44
+ "Si une mesure annoncée est déjà en vigueur aujourd'hui, écris-la au présent ou au passé, "
45
+ "jamais au futur.\n\n"
46
  "Consignes :\n"
47
  "- Réponds de manière factuelle, concise et polie (vouvoiement).\n"
48
  "- Quand tu affirmes un fait, cite tes sources en fin de phrase avec le format [1], [2]… en te basant sur l'index de ces sources (ex: [1] est la source 1, [2] est la source 2, etc.)\n\n"
 
54
  "Réponse:"
55
  )
56
 
57
+ def synth_answer_stream(query, passages, model: str | None = None):
58
  client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
59
+ model = model or LLM_MODEL
60
  prompt = utf8_safe(_build_prompt(query, passages))
61
 
62
  stream = client.chat.completions.create(