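"""Gradio chat UI for a French HR RAG assistant.

Flow: user message -> vector search (rag.retrieval.search) -> streamed LLM
synthesis (rag.synth.synth_answer_stream) -> citations in the answer are
linkified and a sources panel is built from the [n] references actually used.
"""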
import os
from copy import deepcopy

import gradio as gr
from dotenv import load_dotenv

# Load .env before importing the rag modules so they see env vars
# (e.g. OPENAI_API_KEY) at import time.
load_dotenv(override=True)

from rag.retrieval import search, ensure_ready
from rag.synth import synth_answer_stream
from helpers import (
    _extract_cited_indices,
    linkify_text_with_sources,
    _group_sources_md,
    is_unknown_answer,
)
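# Expected helper contracts (a sketch inferred from how they are used below,
# not the actual implementations in helpers.py):
#   _extract_cited_indices(text, k)       -> indices n with "[n]" in text, 1 <= n <= k
#   linkify_text_with_sources(text, hits) -> text with [n] turned into source links
#   _group_sources_md(hits, used)         -> markdown listing the cited sources
#   is_unknown_answer(text)               -> True for "I don't know"-style answers

# Shared placeholder for the sources panel, shown until real citations exist.
SOURCES_PLACEHOLDER = (
    "### 📚 Sources\n"
    "_The sources used to formulate the answer will appear here._"
)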
# ---------- Warm-up ----------
def _warmup():
    """Load the retrieval models once at startup so the first query is fast."""
    try:
        ensure_ready()
        return "✅ Models initialized!"
    except Exception as e:
        return f"⚠️ Warmup failed: {e}"
# ---------- Chat step 1: add user message ----------
def add_user(user_msg: str, history: list[tuple]) -> tuple[str, list[tuple]]:
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return "", history
    # Append a placeholder assistant turn for streaming
    history = history + [(user_msg, "")]
    return "", history
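# Splitting submit handling into add_user -> bot is the usual Gradio pattern
# for streaming: the first step clears the textbox and echoes the user turn
# immediately, then the generator below streams tokens into that same turn.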
# ---------- Chat step 2: stream assistant answer ----------
def bot(history: list[tuple], api_key: str, top_k: int, model_name: str):
    """Yield (history, sources_markdown) pairs while the answer streams."""
    if not history:
        yield history, SOURCES_PLACEHOLDER
        return
    # BYOK: a key typed in the sidebar overrides the one from the environment.
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key.strip()
    user_msg, _ = history[-1]
    # Retrieval
    k = int(max(top_k, 1))
    try:
        hits = search(user_msg, top_k=k)
    except Exception as e:
        history[-1] = (user_msg, f"❌ Retrieval error: {e}")
        yield history, SOURCES_PLACEHOLDER
        return

    # Show a small "thinking" placeholder immediately
    history[-1] = (user_msg, "⏳ Synthesizing…")
    yield history, SOURCES_PLACEHOLDER
    # Streaming LLM
    acc = ""
    try:
        for chunk in synth_answer_stream(user_msg, hits[:k], model=model_name):
            acc += chunk or ""
            # Yield a fresh copy each step so Gradio re-renders the update.
            step_hist = deepcopy(history)
            step_hist[-1] = (user_msg, acc)
            yield step_hist, SOURCES_PLACEHOLDER
    except Exception as e:
        history[-1] = (user_msg, f"❌ Synthesis error: {e}")
        yield history, SOURCES_PLACEHOLDER
        return
    # Finalize: turn the [n] citations in the answer into source links
    acc_linked = linkify_text_with_sources(acc, hits[:k])
    history[-1] = (user_msg, acc_linked)

    # Decide whether to show sources
    if is_unknown_answer(acc_linked):
        # No sources for "I don't know" / please-reformulate answers
        yield history, SOURCES_PLACEHOLDER
        return

    # Build the sources section from the [n] citations actually used
    used = _extract_cited_indices(acc_linked, k)
    if not used:
        yield history, SOURCES_PLACEHOLDER
        return
    grouped_sources = _group_sources_md(hits[:k], used)
    yield history, gr.update(visible=True, value=grouped_sources)
# ---------- UI ----------
with gr.Blocks(theme="soft", fill_height=True) as demo:
    gr.Markdown("# 🇫🇷 HR Assistant — RAG Chat")

    # Warmup status, kept visible near the top of the page
    status = gr.Markdown("⏳ Initializing the RAG models…")

    # Sidebar (takes no 'label' arg)
    with gr.Sidebar(open=True):
        gr.Markdown("## ⚙️ Settings")
        api_key = gr.Textbox(
            label="🔑 OpenAI API Key (BYOK — never stored)",
            type="password",
            placeholder="sk-… (optional if set in env)",
        )
        # Let the user pick the OpenAI model used for synthesis
        model = gr.Dropdown(
            label="⚙️ OpenAI model",
            choices=[
                "gpt-4o-mini",
                "gpt-4o",
                "gpt-4.1-mini",
                "gpt-3.5-turbo",
            ],
            value="gpt-4o-mini",
        )
        topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.Chatbot(
                label="Chat Interface",
                height="65vh",
                show_copy_button=False,
                avatar_images=(
                    # (user, assistant) avatars; the second is a repo asset
                    "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/icons/huggingface-logo.svg",
                    "assets/chatbot.png",
                ),
                render_markdown=True,
                show_label=False,
                placeholder=(
                    "<p style='text-align: center;'>Hello 👋,</p>"
                    "<p style='text-align: center;'>I am your HR assistant, "
                    "ready to answer your questions.</p>"
                ),
            )
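            # Note: this Chatbot uses the default (user, assistant) tuple
            # format, which matches the list[tuple] history kept in gr.State.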
            # Input row
            with gr.Row(equal_height=True):
                msg = gr.Textbox(
                    placeholder="Ask your question…",
                    show_label=False,
                    scale=5,
                )
                send = gr.Button("Send", variant="primary", scale=1)
        with gr.Column(scale=1):
            sources = gr.Markdown(SOURCES_PLACEHOLDER)

    state = gr.State([])  # chat history: list[tuple(user, assistant)]
    # Wire events: user submits -> add_user -> bot streams
    send_click = send.click(add_user, [msg, state], [msg, state])
    send_click.then(
        bot,
        [state, api_key, topk, model],
        [chat, sources],
        show_progress="minimal",
    ).then(lambda h: h, chat, state)

    msg_submit = msg.submit(add_user, [msg, state], [msg, state])
    msg_submit.then(
        bot,
        [state, api_key, topk, model],
        [chat, sources],
        show_progress="minimal",
    ).then(lambda h: h, chat, state)
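    # Each chain above ends by copying the fully streamed chat back into
    # gr.State, so the next turn builds on the complete history.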
    demo.load(_warmup, inputs=None, outputs=status)
if __name__ == "__main__":
    demo.queue().launch()
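# queue() enables the request queue Gradio uses to serve generator-based
# (streaming) handlers like bot(); launch() then starts the web server.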