File size: 6,288 Bytes
85504aa
 
 
 
525a9ab
 
 
845a31e
85504aa
 
c0df2ba
525a9ab
 
85504aa
525a9ab
85504aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8853ea0
85504aa
 
 
 
c0df2ba
85504aa
525a9ab
85504aa
 
525a9ab
85504aa
525a9ab
85504aa
 
 
 
 
 
c0df2ba
845a31e
525a9ab
f095082
525a9ab
85504aa
 
c0df2ba
525a9ab
85504aa
525a9ab
 
8853ea0
85504aa
 
 
c0df2ba
525a9ab
85504aa
c0df2ba
525a9ab
 
85504aa
 
 
 
c0df2ba
 
 
 
 
 
85504aa
 
c0df2ba
 
 
 
85504aa
 
c0df2ba
85504aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8853ea0
 
 
 
 
 
 
 
 
 
 
85504aa
 
525a9ab
 
85504aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8853ea0
85504aa
c0df2ba
85504aa
 
 
 
 
8853ea0
85504aa
c0df2ba
85504aa
 
 
 
525a9ab
845a31e
 
525a9ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import gradio as gr
from gradio import update as gr_update  # tiny alias
from copy import deepcopy
from dotenv import load_dotenv

load_dotenv(override=True)

from rag.retrieval import search, ensure_ready
from rag.synth import synth_answer_stream
from helpers import _extract_cited_indices, linkify_text_with_sources, _group_sources_md, is_unknown_answer


# ---------- Warm-Up ----------

def _warmup():
    """Eagerly initialize the RAG models; return a status string for the UI banner."""
    try:
        ensure_ready()
    except Exception as e:
        # Surface the failure in the status Markdown instead of crashing the app.
        return f"⚠️ Warmup a échoué : {e}"
    return "✅ Modèles initialisés !"


# ---------- Chat step 1: add user message ----------
# ---------- Chat step 1: add user message ----------
def add_user(user_msg: str, history: list[tuple]) -> tuple[str, list[tuple]]:
    """Append the user's turn with an empty assistant slot for streaming.

    Returns ("", new_history): the empty string clears the input textbox.
    Blank / None input leaves history untouched.
    """
    text = (user_msg or "").strip()
    if not text:
        return "", history
    # New list (no in-place mutation); "" is the placeholder the bot streams into.
    return "", [*history, (text, "")]


# ---------- Chat step 2: stream assistant answer ----------
def bot(history: list[tuple], api_key: str, top_k: int, model_name: str):
    """Stream the assistant's answer for the most recent user turn.

    Args:
        history: Chat turns as (user, assistant) tuples; the last turn's
            assistant slot is the placeholder added by ``add_user``.
        api_key: Optional BYOK OpenAI key; exported to the process env when set.
        top_k: Number of passages to retrieve (clamped to >= 1).
        model_name: OpenAI model identifier forwarded to the synthesizer.

    Yields:
        (history, sources_panel) pairs while streaming, where sources_panel is
        either the empty-sources placeholder markdown or a gr_update revealing
        the grouped sources section.
    """
    # Single definition of the "no sources yet" panel (was duplicated 7 times).
    empty_sources = "### 📚 Sources\n_Ici, vous pourrez consulter les sources utilisées pour formuler la réponse._"

    if not history:
        yield history, empty_sources
        return

    if api_key:
        # BYOK: the key only lives in this process's environment, never stored.
        os.environ["OPENAI_API_KEY"] = api_key.strip()

    user_msg, _ = history[-1]

    # Retrieval — clamp top_k to at least one passage.
    k = int(max(top_k, 1))
    try:
        hits = search(user_msg, top_k=k)
    except Exception as e:
        history[-1] = (user_msg, f"❌ Retrieval error: {e}")
        yield history, empty_sources
        return

    # Show a small "thinking" placeholder immediately.
    history[-1] = (user_msg, "⏳ Synthèse en cours…")
    yield history, empty_sources

    # Streaming LLM synthesis.
    acc = ""
    try:
        for chunk in synth_answer_stream(user_msg, hits[:k], model=model_name):
            acc += chunk or ""
            # History holds immutable (str, str) tuples, so a shallow list copy
            # is value-identical to the previous deepcopy but O(len(history))
            # instead of O(total text) per streamed chunk.
            step_hist = list(history)
            step_hist[-1] = (user_msg, acc)
            yield step_hist, empty_sources
    except Exception as e:
        history[-1] = (user_msg, f"❌ Synthèse: {e}")
        yield history, empty_sources
        return

    # Finalize: turn [n] citations in the answer into links.
    acc_linked = linkify_text_with_sources(acc, hits[:k])
    history[-1] = (user_msg, acc_linked)

    # "I don't know"-style answers get no sources panel.
    if is_unknown_answer(acc_linked):
        yield history, empty_sources
        return

    # Build the sources section only from citations actually present ([n]).
    used = _extract_cited_indices(acc_linked, k)
    if not used:
        yield history, empty_sources
        return

    grouped_sources = _group_sources_md(hits[:k], used)
    yield history, gr_update(visible=True, value=grouped_sources)


# ---------- UI ----------
# Top-level UI definition: builds the Blocks app and wires the event chains.
with gr.Blocks(theme="soft", fill_height=True) as demo:
    gr.Markdown("# 🇫🇷 Assistant RH — Chat RAG")
    # Warmup status banner; overwritten by _warmup() once demo.load fires.
    status = gr.Markdown("⏳ Initialisation des modèles du RAG…")

    # Sidebar with user-tunable parameters (gr.Sidebar takes no 'label' arg).
    with gr.Sidebar(open=True):
        gr.Markdown("## ⚙️ Paramètres")
        api_key = gr.Textbox(
            label="🔑 OpenAI API Key (BYOK — never stored)",
            type="password",
            placeholder="sk-… (optional if set in env)"
        )
        # Let the user pick which OpenAI model the synthesizer uses.
        model = gr.Dropdown(
            label="⚙️ OpenAI model",
            choices=[
                "gpt-4o-mini",
                "gpt-4o",
                "gpt-4.1-mini",
                "gpt-3.5-turbo"
            ],
            value="gpt-4o-mini"
        )
        topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
        # topk IS wired into bot() below as the retrieval depth.

    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.Chatbot(
                label="Chat Interface",
                height="65vh",
                show_copy_button=False,
                avatar_images=(
                    "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/icons/huggingface-logo.svg",
                    "assets/chatbot.png",
                ),
                render_markdown=True,
                show_label=False,
                placeholder="<p style='text-align: center;'>Bonjour 👋,</p><p style='text-align: center;'>Je suis votre assistant HR. Je me tiens prêt à répondre à vos questions.</p>"
            )
            # Input row: textbox + send button side by side.
            with gr.Row(equal_height=True):
                msg = gr.Textbox(
                    placeholder="Posez votre question…",
                    show_label=False,
                    scale=5,
                )
                send = gr.Button("Envoyer", variant="primary", scale=1)

        with gr.Column(scale=1):
            # Sources panel; bot() replaces this via gr_update once citations exist.
            sources = gr.Markdown("### 📚 Sources\n_Ici, vous pourrez consulter les sources utilisées pour formuler la réponse._")

    state = gr.State([])  # chat history: list[tuple(user, assistant)]

    # Wire events: user submits -> add_user appends the turn -> bot streams.
    send_click = send.click(add_user, [msg, state], [msg, state])
    send_click.then(
        bot,
        [state, api_key, topk, model],
        [chat, sources],
        show_progress="minimal",
    ).then(lambda h: h, chat, state)  # copy the streamed chat back into state

    # Same chain for Enter-key submit as for the button click.
    msg_submit = msg.submit(add_user, [msg, state], [msg, state])
    msg_submit.then(
        bot,
        [state, api_key, topk, model],
        [chat, sources],
        show_progress="minimal",
    ).then(lambda h: h, chat, state)  # keep state in sync with the chatbot


    # Run the model warmup on page load and show its result in the banner.
    demo.load(_warmup, inputs=None, outputs=status)


if __name__ == "__main__":
    # queue() is required for generator (streaming) event handlers like bot().
    demo.queue().launch()