import os
import gradio as gr
from gradio import update as gr_update # tiny alias
from copy import deepcopy
from dotenv import load_dotenv
load_dotenv(override=True)
from rag.retrieval import search, ensure_ready
from rag.synth import synth_answer_stream
from helpers import _extract_cited_indices, linkify_text_with_sources, _group_sources_md, is_unknown_answer
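
# Shared placeholder for the sources panel, reused by every yield that has no citations yet.
SOURCES_PLACEHOLDER = (
    "### 📚 Sources\n"
    "_Ici, vous pourrez consulter les sources utilisées pour formuler la réponse._"
)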

# ---------- Warm-Up ----------
def _warmup():
    try:
        ensure_ready()
        return "✅ Modèles initialisés !"
    except Exception as e:
        return f"⚠️ Warmup a échoué : {e}"

# ---------- Chat step 1: add user message ----------
def add_user(user_msg: str, history: list[tuple]) -> tuple[str, list[tuple]]:
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return "", history
    # append a placeholder assistant turn for streaming
    history = history + [(user_msg, "")]
    return "", history

# ---------- Chat step 2: stream assistant answer ----------
def bot(history: list[tuple], api_key: str, top_k: int, model_name: str):
    """
    Yields (history, sources_markdown) while streaming.
    """
    if not history:
        yield history, SOURCES_PLACEHOLDER
        return
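
    # BYOK: a key supplied in the sidebar overrides the environment variable for this process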
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key.strip()

    user_msg, _ = history[-1]

    # Retrieval
    k = int(max(top_k, 1))
    try:
        hits = search(user_msg, top_k=k)
    except Exception as e:
        history[-1] = (user_msg, f"❌ Retrieval error: {e}")
        yield history, SOURCES_PLACEHOLDER
        return

    # show a small “thinking” placeholder immediately
    history[-1] = (user_msg, "⏳ Synthèse en cours…")
    yield history, SOURCES_PLACEHOLDER

    # Streaming LLM
    acc = ""
    try:
        for chunk in synth_answer_stream(user_msg, hits[:k], model=model_name):
            acc += chunk or ""
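            # yield a fresh copy so snapshots already sent to the UI are not mutated by later chunks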
            step_hist = deepcopy(history)
            step_hist[-1] = (user_msg, acc)
            yield step_hist, SOURCES_PLACEHOLDER
    except Exception as e:
        history[-1] = (user_msg, f"❌ Synthèse: {e}")
        yield history, SOURCES_PLACEHOLDER
        return

    # Finalize + linkify citations
    acc_linked = linkify_text_with_sources(acc, hits[:k])
    history[-1] = (user_msg, acc_linked)

    # Decide whether to show sources
    if is_unknown_answer(acc_linked):
        # no sources when the model says it does not know or asks to reformulate
        yield history, SOURCES_PLACEHOLDER
        return

    # Build the sources section from the citations [n] actually present in the answer
    used = _extract_cited_indices(acc_linked, k)
    if not used:
        yield history, SOURCES_PLACEHOLDER
        return
    grouped_sources = _group_sources_md(hits[:k], used)
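    # gr.update targets the sources Markdown component: make it visible and fill in the grouped sources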
    yield history, gr_update(visible=True, value=grouped_sources)

# ---------- UI ----------
with gr.Blocks(theme="soft", fill_height=True) as demo:
    gr.Markdown("# 🇫🇷 Assistant RH — Chat RAG")

    # Warmup status banner
    status = gr.Markdown("⏳ Initialisation des modèles du RAG…")

    # Sidebar with settings
    with gr.Sidebar(open=True):
        gr.Markdown("## ⚙️ Paramètres")
        api_key = gr.Textbox(
            label="🔑 OpenAI API Key (BYOK — never stored)",
            type="password",
            placeholder="sk-… (optional if set in env)",
        )
        # let the user choose the OpenAI model
        model = gr.Dropdown(
            label="⚙️ OpenAI model",
            choices=[
                "gpt-4o-mini",
                "gpt-4o",
                "gpt-4.1-mini",
                "gpt-3.5-turbo",
            ],
            value="gpt-4o-mini",
        )
        topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")

    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.Chatbot(
                label="Chat Interface",
                height="65vh",
                show_copy_button=False,
                avatar_images=(
                    "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/icons/huggingface-logo.svg",
                    "assets/chatbot.png",
                ),
                render_markdown=True,
                show_label=False,
                placeholder="<p style='text-align: center;'>Bonjour 👋,</p><p style='text-align: center;'>Je suis votre assistant RH. Je me tiens prêt à répondre à vos questions.</p>",
            )
            # input row
            with gr.Row(equal_height=True):
                msg = gr.Textbox(
                    placeholder="Posez votre question…",
                    show_label=False,
                    scale=5,
                )
                send = gr.Button("Envoyer", variant="primary", scale=1)

        with gr.Column(scale=1):
            sources = gr.Markdown(SOURCES_PLACEHOLDER)

    state = gr.State([])  # chat history: list[tuple(user, assistant)]

    # wire events: user submits -> add_user -> bot streams
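    # the trailing .then(lambda h: h, chat, state) copies the streamed chat back into the session state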
    send_click = send.click(add_user, [msg, state], [msg, state])
    send_click.then(
        bot,
        [state, api_key, topk, model],
        [chat, sources],
        show_progress="minimal",
    ).then(lambda h: h, chat, state)

    msg_submit = msg.submit(add_user, [msg, state], [msg, state])
    msg_submit.then(
        bot,
        [state, api_key, topk, model],
        [chat, sources],
        show_progress="minimal",
    ).then(lambda h: h, chat, state)

    demo.load(_warmup, inputs=None, outputs=status)

if __name__ == "__main__":
    demo.queue().launch()