edouardfoussier committed on
Commit 85504aa · 1 Parent(s): 525a9ab

big update of custom chatbot + sidebar layout + sources generation

Files changed (6)
  1. app.py +131 -122
  2. assets/chatbot.png +3 -0
  3. helpers.py +109 -7
  4. rag/retrieval.py +40 -26
  5. rag/synth.py +40 -139
  6. rag/utils.py +11 -0
app.py CHANGED
@@ -1,147 +1,156 @@
- import os, time
  from dotenv import load_dotenv

- # Load environment variables BEFORE importing rag modules
  load_dotenv(override=True)

- import gradio as gr
- from rag.retrieval import search, embed
- from rag.synth import synth_answer_stream, render_sources
- from helpers import linkify_text_with_sources
-
- missing = []
- if not os.getenv("HF_API_TOKEN"): missing.append("HF_API_TOKEN (embeddings)")
- if not os.getenv("LLM_MODEL"): print("[INFO] LLM_MODEL not set, using default", flush=True)
- print("[ENV] Missing:", ", ".join(missing) or "None", flush=True)
- # HF_API_TOKEN = os.getenv("HF_API_TOKEN")
-
- # def sanity():
- #     ok = bool(os.getenv("HF_API_TOKEN"))
- #     v = embed("hello world")
- #     return f"Token set? {ok}\nEmbedding dim: {len(v)}"
-
- # def rag_chat(user_question, openai_key):
- #     if not openai_key:
- #         return "❌ Please provide your OpenAI API key."
-
- #     # Inject the key into environment so synth can use it
- #     os.environ["OPENAI_API_KEY"] = openai_key
-
- #     # Step 1: Retrieve top passages
- #     hits = search(user_question, top_k=8)
-
- #     if not hits:
- #         return "❌ Sorry, no relevant information found."
-
- #     # Step 2: Generate synthesized answer
- #     try:
- #         final_answer = synth_answer(user_question, hits[:5])
- #         final_answer = linkify(final_answer, hits[:5])
- #         final_answer += "\n\n---\n" + render_sources(hits[:5])
- #     except Exception as e:
- #         final_answer = f"❌ Error during synthesis: {e}"
-
- #     return final_answer
- # def rag_chat(user_question, openai_key):
- #     if not openai_key:
- #         yield "❌ Please provide your OpenAI API key."
- #         return
-
- #     os.environ["OPENAI_API_KEY"] = openai_key
-
- #     hits = search(user_question, top_k=8)
- #     if not hits:
- #         yield "❌ Sorry, no relevant information found."
- #         return
-
- #     acc = ""
- #     try:
- #         for piece in synth_answer_stream(user_question, hits[:5]):
- #             acc += piece or ""
- #             # stream raw text while typing (no links yet to avoid jumpiness)
- #             yield acc
- #     except Exception as e:
- #         partial = acc if acc.strip() else ""
- #         yield (partial + ("\n\n" if partial else "") + f"❌ Streaming error: {e}")
- #         return
-
- #     final_md = linkify_text_with_sources(acc, hits[:5])
- #     yield final_md
-
-
- # with gr.Blocks() as demo:
- #     gr.Markdown("## 🤖 HR Assistant (RAG)\nAsk your question below:")
-
- #     with gr.Row():
- #         api_key = gr.Textbox(label="🔑 Your OpenAI API Key", type="password")
-
- #     question = gr.Textbox(label="❓ Your Question", placeholder="e.g., Quels sont les droits à congés ?")
-
- #     answer = gr.Markdown(label="💡 Assistant Answer")

- #     submit_btn = gr.Button("Ask")

- #     submit_btn.click(fn=rag_chat, inputs=[question, api_key], outputs=answer)

- # if __name__ == "__main__":
- #     demo.launch()


- def rag_chat(user_question: str, openai_key: str):
-     """Generator: streams draft text to a Textbox, then yields final Markdown."""
-     if not openai_key:
-         yield "❌ Please provide your OpenAI API key.", None
          return

-     os.environ["OPENAI_API_KEY"] = openai_key.strip()

-     # Step 1: retrieve
-     yield "⏳ Recherche des passages pertinents…", None
-     hits = search(user_question, top_k=8)
-     if not hits:
-         yield "❌ Sorry, no relevant information found.", None
-         return

-     # Step 2: stream LLM synthesis
      acc = ""
      try:
-         for piece in synth_answer_stream(user_question, hits[:5]):
-             acc += piece or ""
-             # Stream into the draft textbox; keep markdown empty during typing
-             yield acc, None
      except Exception as e:
-         yield f"❌ Error during synthesis: {e}", None
          return

-     # Step 3: finalize + linkify citations in Markdown block
-     md = linkify_text_with_sources(acc, hits[:5])
-     yield acc, md
-
- with gr.Blocks() as demo:
-     gr.Markdown("## 🤖 HR Assistant (RAG)\nAsk your question below:")

      with gr.Row():
-         api_key = gr.Textbox(label="🔑 Your OpenAI API Key", type="password", placeholder="sk-…")
-         question = gr.Textbox(label="❓ Your Question", placeholder="e.g., Quels sont les droits à congés ?")
-
-     # live streaming target
-     draft_answer = gr.Markdown(label="💬 Réponse")
-     # final pretty markdown with clickable links
-     # final_answer = gr.Markdown()

-     with gr.Row():
-         submit_btn = gr.Button("Ask", variant="primary")
-         clear_btn = gr.Button("Clear")
-
-     submit_btn.click(
-         fn=rag_chat,
-         inputs=[question, api_key],
-         outputs=[draft_answer, final_answer],
-         show_progress="full",  # shows loader on the button
-     )
-     clear_btn.click(lambda: ("", ""), outputs=[draft_answer, final_answer])

  if __name__ == "__main__":
      demo.queue().launch()
+ import os
+ import gradio as gr
+ from gradio import update as gr_update  # tiny alias
+ from copy import deepcopy
  from dotenv import load_dotenv

  load_dotenv(override=True)

+ from rag.retrieval import search, ensure_ready
+ from rag.synth import synth_answer_stream
+ from helpers import _extract_cited_indices, linkify_text_with_sources, _group_sources_md


+ # ---------- Warm-Up ----------

+ def _warmup():
+     try:
+         ensure_ready()
+         return "✅ Modèles initialisés !"
+     except Exception as e:
+         return f"⚠️ Warmup a échoué : {e}"
+
+
+ # ---------- Chat step 1: add user message ----------
+ def add_user(user_msg: str, history: list[tuple]) -> tuple[str, list[tuple]]:
+     user_msg = (user_msg or "").strip()
+     if not user_msg:
+         return "", history
+     # append a placeholder assistant turn for streaming
+     history = history + [(user_msg, "")]
+     return "", history
+
+
+ # ---------- Chat step 2: stream assistant answer ----------
+ def bot(history: list[tuple], api_key: str, top_k: int):
+     """
+     Yields (history, sources_markdown) while streaming.
+     """
+     if not history:
+         yield history, "### Sources\n_(none)_"
+         return

+     if api_key:
+         os.environ["OPENAI_API_KEY"] = api_key.strip()

+     user_msg, _ = history[-1]

+     # Retrieval
+     k = int(max(top_k, 1))
+     try:
+         hits = search(user_msg, top_k=k)
+     except Exception as e:
+         history[-1] = (user_msg, f"❌ Retrieval error: {e}")
+         yield history, "### Sources\n_(none)_"
          return

+     # fallback listing of every retrieved passage (shown if synthesis fails)
+     sources_md = _group_sources_md(hits[:k], [])

+     # show a small "thinking" placeholder immediately
+     history[-1] = (user_msg, "⏳ Synthèse en cours…")
+     yield history, "### 📚 Sources"

+     # Streaming LLM
      acc = ""
      try:
+         for chunk in synth_answer_stream(user_msg, hits[:k]):
+             acc += chunk or ""
+             step_hist = deepcopy(history)
+             step_hist[-1] = (user_msg, acc)
+             yield step_hist, "### 📚 Sources"
      except Exception as e:
+         history[-1] = (user_msg, f"❌ Synthèse: {e}")
+         yield history, sources_md
          return

+     # Finalize + linkify citations
+     acc_linked = linkify_text_with_sources(acc, hits[:k])
+     history[-1] = (user_msg, acc_linked)
+
+     # Build the sources section from the [n] citations actually used
+     used = _extract_cited_indices(acc_linked, k)
+     grouped_sources = _group_sources_md(hits[:k], used)
+
+     yield history, grouped_sources
+     # yield history, sources_md
+
+
+ # ---------- UI ----------
+ with gr.Blocks(theme="soft", fill_height=True) as demo:
+     gr.Markdown("# 🇫🇷 Assistant RH — Chat RAG")
+     # Warmup status (put somewhere visible)
+     status = gr.Markdown("⏳ Initialisation des modèles du RAG…")
+
+     # Sidebar (no 'label' arg)
+     with gr.Sidebar(open=True):
+         gr.Markdown("## ⚙️ Paramètres")
+         api_key = gr.Textbox(
+             label="🔑 OpenAI API Key (BYOK — never stored)",
+             type="password",
+             placeholder="sk-… (optional if set in env)"
+         )
+         topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
+         # you can wire this later; not used now
+         save_history = gr.Checkbox(label="Ajouter un modèle reranker")

      with gr.Row():
+         with gr.Column(scale=4):
+             chat = gr.Chatbot(
+                 label="Chat Interface",
+                 height="65vh",
+                 show_copy_button=False,
+                 avatar_images=(
+                     "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/icons/huggingface-logo.svg",
+                     "assets/chatbot.png",
+                 ),
+                 render_markdown=True,
+                 show_label=False,
+                 placeholder="<p style='text-align: center;'>Bonjour 👋,</p><p style='text-align: center;'>Je suis votre assistant HR. Je me tiens prêt à répondre à vos questions.</p>"
+             )
+             # input row
+             with gr.Row(equal_height=True):
+                 msg = gr.Textbox(
+                     placeholder="Posez votre question…",
+                     show_label=False,
+                     scale=5,
+                 )
+                 send = gr.Button("Envoyer", variant="primary", scale=1)
+
+         with gr.Column(scale=1):
+             sources = gr.Markdown("### 📚 Sources\n_Ici, vous pourrez consulter les sources utilisées pour formuler la réponse._")
+
+     state = gr.State([])  # chat history: list[tuple(user, assistant)]
+
+     # wire events: user submits -> add_user -> bot streams
+     send_click = send.click(add_user, [msg, state], [msg, state])
+     send_click.then(
+         bot,
+         [state, api_key, topk],
+         [chat, sources],
+         show_progress="full",
+     ).then(lambda h: h, chat, state)
+
+     msg_submit = msg.submit(add_user, [msg, state], [msg, state])
+     msg_submit.then(
+         bot,
+         [state, api_key, topk],
+         [chat, sources],
+         show_progress="full",
+     ).then(lambda h: h, chat, state)
+
+     demo.load(_warmup, inputs=None, outputs=status)

  if __name__ == "__main__":
      demo.queue().launch()
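The core of this update is the two-step wiring: `add_user` appends a placeholder assistant turn, `bot` streams into it, and a final `.then` copies the Chatbot value back into `gr.State`. Here is a minimal self-contained sketch of the same pattern, with a dummy echo generator standing in for the RAG pipeline (illustration only, not the app's code):

```python
import time
import gradio as gr

def add_user(user_msg, history):
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return "", history
    # placeholder assistant turn that the streamer fills in
    return "", history + [(user_msg, "")]

def bot(history):
    user_msg, _ = history[-1]
    acc = ""
    for token in f"Echo: {user_msg}".split():
        acc += token + " "
        history[-1] = (user_msg, acc)
        time.sleep(0.05)
        yield history  # each yield repaints the Chatbot

with gr.Blocks() as demo:
    chat = gr.Chatbot()
    msg = gr.Textbox()
    state = gr.State([])
    msg.submit(add_user, [msg, state], [msg, state]).then(
        bot, state, chat
    ).then(lambda h: h, chat, state)  # keep State in sync with the widget

if __name__ == "__main__":
    demo.queue().launch()
```

Yielding a fresh `deepcopy` per chunk, as the real `bot` does, avoids mutating the same list object that Gradio may still be rendering.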
assets/chatbot.png ADDED

Git LFS Details

  • SHA256: 9daa93e27f8a3e5ea504737bebc879f7cd37a1895acfdc5ac5b092c9a7650e3e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
helpers.py CHANGED
@@ -1,11 +1,39 @@
  import re

  def linkify_text_with_sources(text: str, passages: list[dict]) -> str:
      """
-     Convert [1], [2]… in `text` to markdown links using the corresponding
-     passage payloads (expects top-5 `hits` from your retriever).
      """
-     # Build mapping: 1-based index -> (title, url)
      mapping = {}
      for i, h in enumerate(passages, start=1):
          p = h.get("payload", h) or {}
@@ -17,9 +45,83 @@ def linkify_text_with_sources(text: str, passages: list[dict]) -> str:
          idx = int(m.group(1))
          title, url = mapping.get(idx, (None, None))
          if url:
-             # turn [n] into [n](url "title")
-             return f"[{idx}]({url} \"{title}\")"
-         # leave as plain [n] if no URL
          return m.group(0)

-     return re.sub(r"\[(\d+)\]", _sub, text)
  import re
+ from collections import OrderedDict
+
+ CITATION_RE = re.compile(r"\[(\d+)\]")
+
+
+ def is_unknown_answer(txt: str) -> bool:
+     """Detect 'no answer' / 'reformulate' replies."""
+     s = (txt or "").lower()
+     # lowercase patterns so they can match the lowercased input
+     patterns = [
+         "je suis navré, je n'ai pas trouvé la réponse",
+         "je ne sais pas",
+         "je ne comprends pas la question",
+         "pourriez-vous reformuler",
+         "je n'ai pas trouvé d'information pertinente",
+     ]
+     return any(p in s for p in patterns)
+
+
+ def _extract_cited_indices(text: str, k: int) -> list[int]:
+     """Return the indices (1..k) actually cited in the text, deduplicated, in order of appearance."""
+     seen = OrderedDict()
+     for m in CITATION_RE.finditer(text or ""):
+         try:
+             n = int(m.group(1))
+             if 1 <= n <= k and n not in seen:
+                 seen[n] = True
+         except Exception:
+             pass
+     return list(seen.keys())

  def linkify_text_with_sources(text: str, passages: list[dict]) -> str:
      """
+     Convert [1], [2]… into real Markdown links to the sources.
      """
      mapping = {}
      for i, h in enumerate(passages, start=1):
          p = h.get("payload", h) or {}
  …
          idx = int(m.group(1))
          title, url = mapping.get(idx, (None, None))
          if url:
+             # simple markdown link: [n](url "title")
+             return f"[_[{idx}]_]({url} \"{title}\")"
          return m.group(0)

+     return re.sub(r"\[(\d+)\]", _sub, text)
+
+ def _group_sources_md(passages: list[dict], used_idxs: list[int]) -> str:
+     """
+     Build the grouped sources markdown:
+     ### 📚 Sources (N)
+     1. [Title](url) _(extraits #1, 3)_
+     2. [Other](url2) _(extrait #2)_
+     """
+     if not passages:
+         return "### 📚 Sources (0)\n_(aucune)_"
+
+     # Use the cited indices when available, else fall back to 1..len(passages)
+     if not used_idxs:
+         used_idxs = list(range(1, len(passages) + 1))
+
+     # Group by (url or normalized title)
+     groups = []  # [(key, title, url, [idxs])]
+     key_to_pos = {}
+
+     for idx in used_idxs:
+         p = passages[idx-1]
+         pl = p.get("payload", p) or {}
+         title = (pl.get("title") or pl.get("url") or f"Source {idx}").strip()
+         url = pl.get("url")
+
+         key = (url or "").strip().lower() or title.lower()
+         if key not in key_to_pos:
+             key_to_pos[key] = len(groups)
+             groups.append([key, title, url, []])
+         groups[key_to_pos[key]][3].append(idx)
+
+     # Sort each index list and build the markdown
+     lines = [f"### 📚 Sources ({len(groups)})"] if len(groups) > 1 else ["### 📚 Source"]
+     for i, (_, title, url, idxs) in enumerate(groups, start=1):
+         idxs = sorted(idxs)
+         idx_txt = ", ".join(map(str, idxs))
+         label = "extrait" if len(idxs) == 1 else "extraits"
+         suffix = f" _({label} #{idx_txt})_"
+         if url:
+             lines.append(f"{i}. [{title}]({url}){suffix}")
+         else:
+             lines.append(f"{i}. {title}{suffix}")
+     return "\n".join(lines)
+
+ # def sources_markdown(passages: list[dict]) -> str:
+ #     if not passages:
+ #         return "### Sources\n_(aucune)_"
+
+ #     lines = [f"### 📚 Sources ({len(passages)})"]
+ #     for i, h in enumerate(passages, start=1):
+ #         p = h.get("payload", h) or {}
+ #         title = (p.get("title") or p.get("url") or f"Source {i}").strip()
+ #         url = p.get("url")
+ #         score = h.get("score")
+ #         # snippet = (p.get("text") or "").strip().replace("\n", " ")

+ #         # # trim the snippet so it doesn't get too long
+ #         # if len(snippet) > 180:
+ #         #     snippet = snippet[:180] + "…"

+ #         # main line
+ #         if url:
+ #             line = f"{i}. [{title}]({url})"
+ #         else:
+ #             line = f"{i}. {title}"

+ #         # append the score (and snippet) in italics, less prominent
+ #         if isinstance(score, (int, float)):
+ #             line += f" _(score {score:.3f})_"
+ #         # if snippet:
+ #         #     line += f"\n > {snippet}"

+ #         lines.append(line)

+ #     return "\n".join(lines)
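A quick toy run of the new citation helpers (hypothetical passages and URLs; the `{"title", "url", "text"}` payload shape matches what `rag/retrieval.py` returns):

```python
from helpers import _extract_cited_indices, _group_sources_md, linkify_text_with_sources

hits = [
    {"score": 0.91, "payload": {"title": "Congés annuels", "url": "https://example.org/conges", "text": "…"}},
    {"score": 0.85, "payload": {"title": "Congés annuels", "url": "https://example.org/conges", "text": "…"}},
    {"score": 0.77, "payload": {"title": "Compte épargne-temps", "url": "https://example.org/cet", "text": "…"}},
]
answer = "Vous avez droit à 25 jours [1][3]. Le report est possible [2]."

used = _extract_cited_indices(answer, k=3)      # [1, 3, 2], in order of appearance
print(linkify_text_with_sources(answer, hits))  # each [n] becomes [_[n]_](url "title")
print(_group_sources_md(hits, used))
# ### 📚 Sources (2)
# 1. [Congés annuels](https://example.org/conges) _(extraits #1, 2)_
# 2. [Compte épargne-temps](https://example.org/cet) _(extrait #3)_
```

Note how duplicate URLs collapse into one grouped entry while keeping every cited excerpt index.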
rag/retrieval.py CHANGED
@@ -1,6 +1,6 @@
- import os, threading, ast
  from typing import List, Dict, Any, Optional, Tuple
-
  import numpy as np
  from datasets import load_dataset
  from huggingface_hub import InferenceClient
@@ -13,7 +13,6 @@ DATASETS = [
  HF_EMBED_MODEL = os.getenv("HF_EMBEDDINGS_MODEL", "BAAI/bge-m3")
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")

- # Try FAISS; fallback to NumPy if import fails
  _USE_FAISS = True
  try:
      import faiss  # type: ignore
@@ -21,8 +20,8 @@ except Exception:
      _USE_FAISS = False

  _embed_client: Optional[InferenceClient] = None
- _index = None  # faiss index or np.ndarray
- _payloads = None  # list[dict]
  _lock = threading.Lock()

  def _client() -> InferenceClient:
@@ -61,15 +60,28 @@ def _load_corpus() -> Tuple[np.ndarray, List[Dict[str, Any]]]:
      X = np.stack(vecs, axis=0)
      return X, payloads

  def _build_index():
      X, payloads = _load_corpus()
-     if _USE_FAISS:
-         dim = X.shape[1]
-         idx = faiss.IndexFlatIP(dim)
-         idx.add(X)
-         return idx, payloads
-     else:
-         return X, payloads  # NumPy fallback

  def _ensure():
      global _index, _payloads
@@ -80,24 +92,11 @@ def _ensure():

  def _search_numpy(X: np.ndarray, q: np.ndarray, k: int):
      scores = X @ q
-     k = min(k, len(scores))
      part = np.argpartition(-scores, k-1)[:k]
      order = part[np.argsort(-scores[part])]
      return scores[order], order

- def rerank_cosine(query_vec, hits, top_k=5):
-     # Re-embed candidate texts and compare? (expensive)
-     # or use retrieval scores only (already cosine). If using NumPy fallback,
-     # you can keep as is. For a tiny boost, score by length-normalized match:
-     scored = []
-     for h in hits:
-         txt = (h["payload"].get("text") or "")
-         # penalize super-long chunks a bit
-         penalty = 1.0 / (1.0 + len(txt)/1500.0)
-         scored.append((h["score"] * penalty, h))
-     scored.sort(key=lambda x: x[0], reverse=True)
-     return [h for _, h in scored[:top_k]]
-
  def search(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
      _ensure()
      q = embed(query)
@@ -113,3 +112,18 @@ def search(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
      hits.append({"score": float(s), "payload": p})
      return hits
+ # retrieval.py
+ import os, ast, threading
  from typing import List, Dict, Any, Optional, Tuple
  import numpy as np
  from datasets import load_dataset
  from huggingface_hub import InferenceClient
  …
  HF_EMBED_MODEL = os.getenv("HF_EMBEDDINGS_MODEL", "BAAI/bge-m3")
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")

  _USE_FAISS = True
  try:
      import faiss  # type: ignore
  …
      _USE_FAISS = False

  _embed_client: Optional[InferenceClient] = None
+ _index = None
+ _payloads = None
  _lock = threading.Lock()

  def _client() -> InferenceClient:
  …
      X = np.stack(vecs, axis=0)
      return X, payloads

+ CACHE = "/tmp/rag_index.npz"
+
+ def _faiss_from_X(X):
+     import faiss
+     idx = faiss.IndexFlatIP(X.shape[1])
+     idx.add(X)
+     return idx
+
  def _build_index():
+     # try cached index
+     if os.path.exists(CACHE):
+         try:
+             d = np.load(CACHE, allow_pickle=True)
+             X = d["X"]
+             payloads = d["payloads"].tolist()
+             return (_faiss_from_X(X) if _USE_FAISS else X), payloads
+         except Exception:
+             # cache corrupted → rebuild
+             try:
+                 os.remove(CACHE)
+             except OSError:
+                 pass
+     # build fresh
      X, payloads = _load_corpus()
+     np.savez_compressed(CACHE, X=X, payloads=np.array(payloads, dtype=object))
+     return (_faiss_from_X(X) if _USE_FAISS else X), payloads

  def _ensure():
      global _index, _payloads
  …
  def _search_numpy(X: np.ndarray, q: np.ndarray, k: int):
      scores = X @ q
+     k = max(1, min(k, len(scores)))
      part = np.argpartition(-scores, k-1)[:k]
      order = part[np.argsort(-scores[part])]
      return scores[order], order

  def search(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
      _ensure()
      q = embed(query)
  …
      hits.append({"score": float(s), "payload": p})
      return hits

+ # ---------- explicit warm-up helpers ----------
+ def warm_up_sync():
+     try:
+         _ = search("warmup", top_k=3)
+     except Exception:
+         pass
+
+ def warm_up_async():
+     t = threading.Thread(target=warm_up_sync, daemon=True)
+     t.start()
+
+ def ensure_ready():
+     """Build the index once and warm the embedding endpoint."""
+     _ensure()  # builds FAISS/NumPy index + loads payloads
+     _ = embed("warmup")  # hits HF Inference API once to avoid cold-start
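The new cache is just a compressed NumPy archive holding the embedding matrix plus the payload list as an object array. A standalone sketch of the round-trip `_build_index` relies on (dummy vectors and a demo path; the real file stores BGE-M3 embeddings at /tmp/rag_index.npz):

```python
import numpy as np

CACHE = "/tmp/rag_index_demo.npz"  # demo path, not the app's cache file

X = np.random.rand(4, 8).astype("float32")
payloads = [{"title": f"doc {i}", "url": None, "text": "…"} for i in range(4)]

# write: matrix + payload dicts packed as a pickled object array
np.savez_compressed(CACHE, X=X, payloads=np.array(payloads, dtype=object))

# read: allow_pickle is required to recover the object array
d = np.load(CACHE, allow_pickle=True)
X2, payloads2 = d["X"], d["payloads"].tolist()
assert np.allclose(X, X2) and payloads2[0]["title"] == "doc 0"
```

Because the archive is pickled, the `allow_pickle=True` load should only ever read a cache the app itself wrote.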
rag/synth.py CHANGED
@@ -1,157 +1,58 @@
  import os
  from openai import OpenAI

  LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
  LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")

  def _build_prompt(query, passages):
-     ctx = "\n\n".join([(p["payload"].get("text") or "") for p in passages])
      return (
-         "Tu es un assistant RH de la fonction publique française.\n"
-         "- Réponds de façon factuelle et concise.\n"
-         "- Cite les sources en fin de phrase avec [1], [2]… basées sur l’ordre des passages.\n"
-         "- Si l’info n’est pas dans les sources, réponds « Je ne sais pas ».\n\n"
-         f"Question: {query}\n\nSources (indexées):\n{ctx}\n\nRéponse:"
      )

  def synth_answer_stream(query, passages):
      client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
-     prompt = _build_prompt(query, passages)
      stream = client.chat.completions.create(
          model=LLM_MODEL,
          messages=[{"role": "user", "content": prompt}],
          temperature=0.2,
-         stream=True,  # 👈 IMPORTANT
      )
-     # The SDK yields events with deltas
-     for event in stream:
-         delta = getattr(getattr(event, "choices", [None])[0], "delta", None)
-         if delta and delta.content:
-             yield delta.content
-
- # def linkify(text, passages):
- #     # (optional) keep simple: return text as-is for now
- #     return text
-
- def render_sources(passages):
-     lines = []
-     for i, p in enumerate(passages, 1):
-         title = (p["payload"].get("title") or "").strip() or "Sans titre"
-         url = p["payload"].get("url") or ""
-         lines.append(f"[{i}] {title}{' – ' + url if url else ''}")
-     return "\n".join(lines)
-
- # def linkify_text_with_sources(text: str, passages):
- #     """
- #     Replace [1], [2]... with clickable links if the passage has a URL.
- #     Also append a Sources section as a numbered list.
- #     """
- #     # Build a map: 1-based index -> url
- #     urls = []
- #     for p in passages:
- #         url = (p["payload"].get("url") or "").strip()
- #         urls.append(url if url.startswith("http") else "")
-
- #     # Inline [n] -> [n](url) when available
- #     out = text
- #     for i, url in enumerate(urls, start=1):
- #         if url:
- #             out = out.replace(f"[{i}]", f"[{i}]({url})")
-
- #     # Add a Sources section
- #     lines = ["\n\n---\n**Sources**"]
- #     for i, p in enumerate(passages, start=1):
- #         title = (p["payload"].get("title") or "").strip() or "Sans titre"
- #         url = (p["payload"].get("url") or "").strip()
- #         if url.startswith("http"):
- #             lines.append(f"{i}. [{title}]({url})")
- #         else:
- #             lines.append(f"{i}. {title}")
- #     return out + "\n" + "\n".join(lines)
- # import os
- # from openai import OpenAI
-
- # LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
- # LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
-
- # def _first_k_chars(text, k=1200):
- #     t = text.strip()
- #     return t[:k] + ("…" if len(t) > k else "")
-
- # def _build_prompt(query, passages):
- #     chunks = []
- #     for i, p in enumerate(passages, 1):
- #         txt = p["payload"].get("text") or ""
- #         chunks.append(f"[{i}] {_first_k_chars(txt)}")
-
- #     # def _build_prompt(query, passages):
- #     #     chunks = []
- #     #     for i, p in enumerate(passages, 1):
- #     #         txt = p["payload"].get("text") or ""
- #     #         chunks.append(f"[{i}] {txt}")
- #     context = "\n\n".join(chunks)
-
- #     return f"""Tu es un assistant RH de la fonction publique française.
- # - Réponds de manière factuelle et concise.
- # - Cite tes sources en fin de phrase avec [n] correspondant aux extraits ci-dessous.
- # - Si l’information n’est pas dans les sources, réponds : “Je ne sais pas”.
- # - Ne fabrique pas de liens ni de références.
-
- # Question: {query}
-
- # Extraits indexés:
- # {context}
-
- # Réponse:"""
-
- # def synth_answer_stream(query, passages):
- #     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
- #     prompt = _build_prompt(query, passages)
-
- #     # ✅ Correct streaming usage
- #     stream = client.chat.completions.create(
- #         model=LLM_MODEL,
- #         messages=[{"role": "user", "content": prompt}],
- #         temperature=0.2,
- #         stream=True,  # <- this is key
- #     )
- #     for chunk in stream:
- #         delta = getattr(chunk.choices[0].delta, "content", None)
- #         if delta:
- #             acc.append(delta)
- #             yield delta  # stream piece by piece
- #     # def synth_answer(query, passages):
- #     #     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
- #     #     prompt = _build_prompt(query, passages)
-
- #     #     resp = client.chat.completions.create(
- #     #         model=LLM_MODEL,
- #     #         messages=[{"role": "user", "content": prompt}],
- #     #         temperature=0.2,
- #     #     )
- #     #     return resp.choices[0].message.content.strip()
-
- #     # --- HELPERS
-
- # def render_sources(passages):
- #     lines = []
- #     for i, p in enumerate(passages, 1):
- #         pl = p["payload"]
- #         title = (pl.get("title") or "Source").strip()
- #         url = pl.get("url") or ""
- #         lines.append(f"[{i}] {title}" + (f" — {url}" if url else ""))
- #     return "\n".join(lines)
-
- # def linkify(text, passages):
- #     # turn [1] -> markdown link when url exists
- #     for i, p in enumerate(passages, 1):
- #         url = p["payload"].get("url")
- #         if url:
- #             text = text.replace(f"[{i}]", f"[{i}]({url})")
- #     return text
+ # rag/synth.py
  import os
  from openai import OpenAI
+ from rag.utils import utf8_safe

  LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
  LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")

  def _build_prompt(query, passages):
+     # Build numbered, tagged source blocks
+     blocks = []
+     for i, h in enumerate(passages, start=1):
+         p = h.get("payload", h) or {}
+         title = (p.get("title") or p.get("url") or f"Source {i}").strip()
+         url = p.get("url") or ""
+         text = utf8_safe(p.get("text") or "")
+         # each block explicitly carries its index [i]
+         blocks.append(
+             f"### Source [{i}] — {title}\n"
+             f"{('URL: ' + url) if url else ''}\n"
+             f"{text}\n"
+         )
+
+     context = "\n\n".join(blocks)
+     query = utf8_safe(query)
+
      return (
+         "Tu es un assistant RH chargé de répondre à des questions dans le domaine des ressources humaines en t'appuyant sur les sources fournies.\n"
+         "Consignes :\n"
+         "- Réponds de manière factuelle, concise et polie (vouvoiement).\n"
+         "- Quand tu affirmes un fait, cite tes sources en fin de phrase avec le format [1], [2]… en te basant sur l'index de ces sources (ex: [1] est la source 1, [2] est la source 2, etc.)\n\n"
+         "- Si l'information n'est pas présente dans les sources, réponds : \"Je suis navré, je n'ai pas trouvé la réponse à cette question\".\n\n"
+         "- Si la question est mal formulée, réponds : \"Je ne comprends pas la question. Pourriez-vous reformuler ?\"\n\n"
+         "- Ne fabrique pas de liens ni de références.\n\n"
+         f"Question: {query}\n"
+         f"Sources (indexées) : {context}\n\n"
+         "Réponse:"
      )

  def synth_answer_stream(query, passages):
      client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
+     prompt = utf8_safe(_build_prompt(query, passages))
+
      stream = client.chat.completions.create(
          model=LLM_MODEL,
          messages=[{"role": "user", "content": prompt}],
          temperature=0.2,
+         stream=True,
      )

+     for event in stream:
+         if not getattr(event, "choices", None):
+             continue
+         delta = event.choices[0].delta
+         if delta and getattr(delta, "content", None):
+             yield utf8_safe(delta.content or "")
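Consuming the generator looks like this (sketch only; assumes `OPENAI_API_KEY` is set and uses a hypothetical hit with the retriever's payload shape):

```python
from rag.synth import synth_answer_stream

hits = [{"payload": {"title": "Congés annuels", "url": "https://example.org/conges",
                     "text": "Tout agent a droit à 25 jours de congés annuels."}}]

acc = ""
for chunk in synth_answer_stream("Combien de jours de congés par an ?", hits):
    acc += chunk  # same accumulation pattern app.py uses while streaming
print(acc)        # final answer; [1]-style citations are linkified later by helpers
```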
rag/utils.py ADDED
@@ -0,0 +1,11 @@
+ # rag/utils.py
+ import unicodedata
+
+ def utf8_safe(s: str) -> str:
+     if not isinstance(s, str):
+         s = str(s)
+     # normalize and replace the em dash with a plain '-'
+     s = unicodedata.normalize("NFC", s)
+     s = s.replace("\u2014", "-")
+     # if a downstream lib forces ASCII, still keep everything that is UTF-8 encodable
+     return s.encode("utf-8", "ignore").decode("utf-8", "ignore")
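A quick check of what `utf8_safe` does in practice:

```python
from rag.utils import utf8_safe

assert utf8_safe("congés \u2014 annuels") == "congés - annuels"  # em dash swapped for '-'
assert utf8_safe(42) == "42"                                     # non-str inputs coerced first
```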