edouardfoussier committed
Commit 525a9ab · 1 parent: adf9604

gradio chat app working - streaming

Files changed (8)
  1. .gitattributes +1 -0
  2. .gitignore +10 -0
  3. app.py +141 -210
  4. helpers.py +25 -0
  5. rag/__init__.py +0 -0
  6. rag/retrieval.py +115 -0
  7. rag/synth.py +157 -0
  8. requirements.txt +2 -1
.gitattributes ADDED
@@ -0,0 +1 @@
+ assets/chatbot.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
+ .env
+ .venv
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ *.pyw
+ *.pyz
+ *.pywz
+ *.pyzw
app.py CHANGED
@@ -1,216 +1,147 @@
- import os, ast, threading
- from typing import List, Dict, Any, Optional, Tuple
 
  import gradio as gr
- import numpy as np
- from datasets import load_dataset
- from huggingface_hub import InferenceClient
-
- # -------------------------------
- # Config
- # -------------------------------
- EMBED_COL = os.getenv("EMBED_COL", "embeddings_bge-m3")
- DATASETS = [
-     ("edouardfoussier/travail-emploi-clean", "train"),
-     ("edouardfoussier/service-public-filtered", "train"),
- ]
-
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
- HF_EMBED_MODEL = os.getenv("HF_EMBED_MODEL", "BAAI/bge-m3")
- HF_LLM_MODEL = os.getenv("HF_LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
-
- if not HF_API_TOKEN:
-     raise RuntimeError("HF_API_TOKEN not set. Add it in Space → Settings → Variables.")
-
- # Try FAISS; fallback to NumPy if not available
- _USE_FAISS = True
- try:
-     import faiss  # type: ignore
- except Exception:
-     _USE_FAISS = False
-
- # -------------------------------
- # Globals
- # -------------------------------
- _embed_client: Optional[InferenceClient] = None
- _gen_client: Optional[InferenceClient] = None
-
- _index = None      # FAISS index or dense matrix (NumPy)
- _payloads = None   # list[dict]
- _dim = None
- _lock = threading.Lock()
-
- def _get_embed_client() -> InferenceClient:
-     global _embed_client
-     if _embed_client is None:
-         _embed_client = InferenceClient(token=HF_API_TOKEN)
-     return _embed_client
-
- def _get_gen_client() -> InferenceClient:
-     global _gen_client
-     if _gen_client is None:
-         _gen_client = InferenceClient(token=HF_API_TOKEN)
-     return _gen_client
-
- def _to_vec(x):
-     if isinstance(x, list):
-         return np.asarray(x, dtype=np.float32)
-     if isinstance(x, str):
-         return np.asarray(ast.literal_eval(x), dtype=np.float32)
-     raise TypeError(f"Unsupported embedding type: {type(x)}")
-
- def _normalize(v: np.ndarray) -> np.ndarray:
-     v = v.astype(np.float32, copy=False)
-     n = np.linalg.norm(v) + 1e-12
-     return v / n
-
- def _embed_query(text: str) -> np.ndarray:
-     # HF feature-extraction
-     vec = _get_embed_client().feature_extraction(text, model=HF_EMBED_MODEL)
-     v = np.asarray(vec, dtype=np.float32)
-     if v.ndim == 2:
-         v = v[0]
-     return _normalize(v)
-
- def _load_datasets() -> Tuple[np.ndarray, List[Dict[str, Any]]]:
-     vecs, payloads = [], []
-     for name, split in DATASETS:
-         ds = load_dataset(name, split=split)
-         for row in ds:
-             v = _normalize(_to_vec(row[EMBED_COL]))
-             vecs.append(v)
-             p = dict(row); p.pop(EMBED_COL, None)
-             payloads.append(p)
-     X = np.stack(vecs, axis=0)
-     return X, payloads
-
- def _build_index() -> Tuple[Any, List[Dict[str, Any]], int]:
-     X, payloads = _load_datasets()
-     dim = X.shape[1]
-     if _USE_FAISS:
-         idx = faiss.IndexFlatIP(dim)
-         idx.add(X)
-     else:
-         idx = X  # NumPy matrix
-     return idx, payloads, dim
-
- def _ensure_index():
-     global _index, _payloads, _dim
-     if _index is not None:
          return
-     with _lock:
-         if _index is None:
-             _index, _payloads, _dim = _build_index()
-
- def _search_numpy(X: np.ndarray, q: np.ndarray, k: int):
-     scores = X @ q  # cosine/IP (normalized)
-     k = min(k, len(scores))
-     part = np.argpartition(-scores, k-1)[:k]
-     order = part[np.argsort(-scores[part])]
-     return scores[order], order
-
- def retrieve(query: str, top_k: int = 6) -> List[Dict[str, Any]]:
-     _ensure_index()
-     q = _embed_query(query)
-     if _USE_FAISS:
-         D, I = _index.search(q[None, :], top_k)
-         scores, idxs = D[0], I[0]
-     else:
-         scores, idxs = _search_numpy(_index, q, top_k)
-     out = []
-     for idx, sc in zip(idxs, scores):
-         if idx == -1:
-             continue
-         p = _payloads[int(idx)]
-         out.append({"score": float(sc), "payload": p})
-     return out
-
- def build_prompt(query: str, passages: List[Dict[str, Any]]) -> str:
-     chunks = []
-     for i, h in enumerate(passages, 1):
-         p = h["payload"]
-         text = p.get("text") or p.get("chunk_text") or ""
-         source = p.get("source") or "unknown"
-         title = p.get("title") or ""
-         url = p.get("url") or ""
-         chunks.append(f"[{i}] ({source}) {title}\n{text}\nURL: {url}\n")
-     context = "\n\n".join(chunks)
-     return f"""You are a helpful HR assistant. Answer the question strictly using the CONTEXT.
- If the CONTEXT is not enough, say you don't know.
-
- QUESTION:
- {query}
-
- CONTEXT:
- {context}
-
- Answer in French. Cite sources inline like [1], [2] where relevant.
- """
-
- def stream_llm(prompt: str):
-     # Stream tokens from HF Inference API text generation
-     client = _get_gen_client()
-     # temperature/params small so result is stable
-     stream = client.text_generation(
-         model=HF_LLM_MODEL,
-         prompt=prompt,
-         max_new_tokens=512,
-         temperature=0.2,
-         top_p=0.9,
-         stream=True,
-         stop=None,
-     )
-     for chunk in stream:
-         # chunk is a string token or piece; just yield it
-         yield chunk
-
- def format_sources(passages: List[Dict[str, Any]]) -> str:
-     lines = []
-     for i, h in enumerate(passages, 1):
-         p = h["payload"]
-         title = (p.get("title") or "").strip() or "(Sans titre)"
-         url = p.get("url") or ""
-         src = p.get("source") or "unknown"
-         lines.append(f"[{i}] **{title}** — _{src}_ " + (f"[lien]({url})" if url else ""))
-     return "\n".join(lines)
-
- # -------------------------------
- # Gradio Chat handler
- # -------------------------------
- def respond(message, history):
-     # Retrieve
-     passages = retrieve(message, top_k=6)
-     prompt = build_prompt(message, passages)
-
-     # Stream answer
-     answer_so_far = ""
-     for token in stream_llm(prompt):
-         answer_so_far += token
-         yield answer_so_far
-
-     # Append sources as an expandable block (return another message)
-     sources_md = format_sources(passages)
-     yield answer_so_far + "\n\n---\n**Sources**\n" + sources_md
-
- with gr.Blocks(fill_height=True) as demo:
-     gr.Markdown("## 🔎 Assistant RH — RAG Chatbot")
-     gr.Markdown(
-         f"**Embeddings:** `{HF_EMBED_MODEL}`   |   **LLM:** `{HF_LLM_MODEL}`"
-     )
-     chat = gr.ChatInterface(
-         fn=respond,
-         type="messages",
-         title="Assistant RH",
-         examples=[
-             "Quels sont les droits à congés pour un agent contractuel ?",
-             "Comment déclarer l’embauche d’un salarié (DPAE) ?",
-             "Quelles sont les obligations de l’employeur pour le télétravail ?",
-         ],
-         retry_btn="Reformuler",
-         undo_btn=None,
-         clear_btn="Effacer",
-         description="Posez une question RH. Réponse générée avec récupération documentaire.",
      )
 
  if __name__ == "__main__":
-     demo.queue(concurrency_count=2).launch(server_name="0.0.0.0", server_port=7860)
 
+ import os, time
+ from dotenv import load_dotenv
+
+ # Load environment variables BEFORE importing rag modules
+ load_dotenv(override=True)
 
  import gradio as gr
+ from rag.retrieval import search, embed
+ from rag.synth import synth_answer_stream, render_sources
+ from helpers import linkify_text_with_sources
+
+ missing = []
+ if not os.getenv("HF_API_TOKEN"): missing.append("HF_API_TOKEN (embeddings)")
+ if not os.getenv("LLM_MODEL"): print("[INFO] LLM_MODEL not set, using default", flush=True)
+ print("[ENV] Missing:", ", ".join(missing) or "None", flush=True)
+ # HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+
+ # def sanity():
+ #     ok = bool(os.getenv("HF_API_TOKEN"))
+ #     v = embed("hello world")
+ #     return f"Token set? {ok}\nEmbedding dim: {len(v)}"
+
+ # def rag_chat(user_question, openai_key):
+ #     if not openai_key:
+ #         return "❌ Please provide your OpenAI API key."
+
+ #     # Inject the key into environment so synth can use it
+ #     os.environ["OPENAI_API_KEY"] = openai_key
+
+ #     # Step 1: Retrieve top passages
+ #     hits = search(user_question, top_k=8)
+
+ #     if not hits:
+ #         return "❌ Sorry, no relevant information found."
+
+ #     # Step 2: Generate synthesized answer
+ #     try:
+ #         final_answer = synth_answer(user_question, hits[:5])
+ #         final_answer = linkify(final_answer, hits[:5])
+ #         final_answer += "\n\n---\n" + render_sources(hits[:5])
+ #     except Exception as e:
+ #         final_answer = f"❌ Error during synthesis: {e}"
+
+ #     return final_answer
+ # def rag_chat(user_question, openai_key):
+ #     if not openai_key:
+ #         yield "❌ Please provide your OpenAI API key."
+ #         return
+
+ #     os.environ["OPENAI_API_KEY"] = openai_key
+
+ #     hits = search(user_question, top_k=8)
+ #     if not hits:
+ #         yield "❌ Sorry, no relevant information found."
+ #         return
+
+ #     acc = ""
+ #     try:
+ #         for piece in synth_answer_stream(user_question, hits[:5]):
+ #             acc += piece or ""
+ #             # stream raw text while typing (no links yet to avoid jumpiness)
+ #             yield acc
+ #     except Exception as e:
+ #         partial = acc if acc.strip() else ""
+ #         yield (partial + ("\n\n" if partial else "") + f"❌ Streaming error: {e}")
+ #         return
+
+ #     final_md = linkify_text_with_sources(acc, hits[:5])
+ #     yield final_md
+
+
+
+ # with gr.Blocks() as demo:
+ #     gr.Markdown("## 🤖 HR Assistant (RAG)\nAsk your question below:")
+
+ #     with gr.Row():
+ #         api_key = gr.Textbox(label="🔑 Your OpenAI API Key", type="password")
+
+ #     question = gr.Textbox(label="❓ Your Question", placeholder="e.g., Quels sont les droits à congés ?")
+
+ #     answer = gr.Markdown(label="💡 Assistant Answer")
+
+ #     submit_btn = gr.Button("Ask")
+
+ #     submit_btn.click(fn=rag_chat, inputs=[question, api_key], outputs=answer)
+
+
+ # if __name__ == "__main__":
+ #     demo.launch()
+
+
+ def rag_chat(user_question: str, openai_key: str):
+     """Generator: streams the draft answer, then yields the final linkified Markdown."""
+     if not openai_key:
+         yield "❌ Please provide your OpenAI API key.", None
          return
+
+     os.environ["OPENAI_API_KEY"] = openai_key.strip()
+
+     # Step 1: retrieve
+     yield "⏳ Recherche des passages pertinents…", None
+     hits = search(user_question, top_k=8)
+     if not hits:
+         yield "❌ Sorry, no relevant information found.", None
+         return
+
+     # Step 2: stream LLM synthesis
+     acc = ""
+     try:
+         for piece in synth_answer_stream(user_question, hits[:5]):
+             acc += piece or ""
+             # Stream into the draft component; keep the final markdown empty during typing
+             yield acc, None
+     except Exception as e:
+         yield f"❌ Error during synthesis: {e}", None
+         return
+
+     # Step 3: finalize + linkify citations in Markdown block
+     md = linkify_text_with_sources(acc, hits[:5])
+     yield acc, md
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## 🤖 HR Assistant (RAG)\nAsk your question below:")
+
+     with gr.Row():
+         api_key = gr.Textbox(label="🔑 Your OpenAI API Key", type="password", placeholder="sk-…")
+         question = gr.Textbox(label="❓ Your Question", placeholder="e.g., Quels sont les droits à congés ?")
+
+     # live streaming target
+     draft_answer = gr.Markdown(label="💬 Réponse")
+     # final pretty markdown with clickable links
+     final_answer = gr.Markdown()
+
+     with gr.Row():
+         submit_btn = gr.Button("Ask", variant="primary")
+         clear_btn = gr.Button("Clear")
+
+     submit_btn.click(
+         fn=rag_chat,
+         inputs=[question, api_key],
+         outputs=[draft_answer, final_answer],
+         show_progress="full",  # shows a loader while the handler runs
      )
+     clear_btn.click(lambda: ("", ""), outputs=[draft_answer, final_answer])
 
  if __name__ == "__main__":
+     demo.queue().launch()
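Side note on the wiring above: because `outputs=[draft_answer, final_answer]`, every `yield` from `rag_chat` must supply one value per output component. A minimal, hypothetical sketch of that two-output streaming contract (not part of this commit, shown only to illustrate the pattern):

    import time
    import gradio as gr

    def fake_stream(question: str):
        # Partial yields update the first component and leave the second empty.
        acc = ""
        for word in ["Réponse", "générée", "mot", "par", "mot…"]:
            acc += word + " "
            time.sleep(0.2)
            yield acc, None
        # The final yield fills both: the raw draft and a linkified Markdown version.
        yield acc, acc.strip() + " [1](https://example.org)"

    with gr.Blocks() as demo:
        q = gr.Textbox(label="Question")
        draft, final = gr.Markdown(), gr.Markdown()
        gr.Button("Ask").click(fake_stream, inputs=q, outputs=[draft, final])

    if __name__ == "__main__":
        demo.queue().launch()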
helpers.py ADDED
@@ -0,0 +1,25 @@
+ import re
+
+ def linkify_text_with_sources(text: str, passages: list[dict]) -> str:
+     """
+     Convert [1], [2]… in `text` to markdown links using the corresponding
+     passage payloads (expects top-5 `hits` from your retriever).
+     """
+     # Build mapping: 1-based index -> (title, url)
+     mapping = {}
+     for i, h in enumerate(passages, start=1):
+         p = h.get("payload", h) or {}
+         title = p.get("title") or p.get("url") or f"Source {i}"
+         url = p.get("url")
+         mapping[i] = (title, url)
+
+     def _sub(m):
+         idx = int(m.group(1))
+         title, url = mapping.get(idx, (None, None))
+         if url:
+             # turn [n] into [n](url "title")
+             return f"[{idx}]({url} \"{title}\")"
+         # leave as plain [n] if no URL
+         return m.group(0)
+
+     return re.sub(r"\[(\d+)\]", _sub, text)
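A quick illustration of the helper with hypothetical hits (the `payload` shape matches what `search` in `rag/retrieval.py` returns):

    from helpers import linkify_text_with_sources

    hits = [
        {"score": 0.83, "payload": {"title": "Congés annuels", "url": "https://example.org/conges"}},
        {"score": 0.71, "payload": {"title": "Sans URL"}},  # no url -> [2] is left as plain text
    ]
    text = "Les agents ont droit à des congés annuels [1], sous conditions [2]."
    print(linkify_text_with_sources(text, hits))
    # [1] becomes [1](https://example.org/conges "Congés annuels"); [2] stays unchanged.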
rag/__init__.py ADDED
File without changes
rag/retrieval.py ADDED
@@ -0,0 +1,115 @@
+ import os, threading, ast
+ from typing import List, Dict, Any, Optional, Tuple
+
+ import numpy as np
+ from datasets import load_dataset
+ from huggingface_hub import InferenceClient
+
+ EMBED_COL = os.getenv("EMBED_COL", "embeddings_bge-m3")
+ DATASETS = [
+     ("edouardfoussier/travail-emploi-clean", "train"),
+     ("edouardfoussier/service-public-filtered", "train"),
+ ]
+ HF_EMBED_MODEL = os.getenv("HF_EMBEDDINGS_MODEL", "BAAI/bge-m3")
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+
+ # Try FAISS; fallback to NumPy if import fails
+ _USE_FAISS = True
+ try:
+     import faiss  # type: ignore
+ except Exception:
+     _USE_FAISS = False
+
+ _embed_client: Optional[InferenceClient] = None
+ _index = None      # faiss index or np.ndarray
+ _payloads = None   # list[dict]
+ _lock = threading.Lock()
+
+ def _client() -> InferenceClient:
+     global _embed_client
+     if _embed_client is None:
+         if not HF_API_TOKEN:
+             raise RuntimeError("HF_API_TOKEN missing (.env)")
+         _embed_client = InferenceClient(model=HF_EMBED_MODEL, token=HF_API_TOKEN)
+     return _embed_client
+
+ def _to_vec(x):
+     if isinstance(x, list): return np.asarray(x, dtype=np.float32)
+     if isinstance(x, str): return np.asarray(ast.literal_eval(x), dtype=np.float32)
+     raise TypeError(f"Bad embedding type: {type(x)}")
+
+ def _norm(v: np.ndarray) -> np.ndarray:
+     v = v.astype(np.float32, copy=False)
+     n = np.linalg.norm(v) + 1e-12
+     return v / n
+
+ def embed(text: str) -> np.ndarray:
+     vec = _client().feature_extraction(text)
+     v = np.asarray(vec, dtype=np.float32)
+     if v.ndim == 2: v = v[0]
+     return _norm(v)
+
+ def _load_corpus() -> Tuple[np.ndarray, List[Dict[str, Any]]]:
+     vecs, payloads = [], []
+     for name, split in DATASETS:
+         ds = load_dataset(name, split=split)
+         for row in ds:
+             v = _norm(_to_vec(row[EMBED_COL]))
+             vecs.append(v)
+             p = dict(row); p.pop(EMBED_COL, None)
+             payloads.append(p)
+     X = np.stack(vecs, axis=0)
+     return X, payloads
+
+ def _build_index():
+     X, payloads = _load_corpus()
+     if _USE_FAISS:
+         dim = X.shape[1]
+         idx = faiss.IndexFlatIP(dim)
+         idx.add(X)
+         return idx, payloads
+     else:
+         return X, payloads  # NumPy fallback
+
+ def _ensure():
+     global _index, _payloads
+     if _index is not None: return
+     with _lock:
+         if _index is None:
+             _index, _payloads = _build_index()
+
+ def _search_numpy(X: np.ndarray, q: np.ndarray, k: int):
+     scores = X @ q
+     k = min(k, len(scores))
+     part = np.argpartition(-scores, k-1)[:k]
+     order = part[np.argsort(-scores[part])]
+     return scores[order], order
+
+ def rerank_cosine(query_vec, hits, top_k=5):
+     # Retrieval scores are already cosine similarities, so a true re-ranking pass
+     # (re-embedding each candidate) would be expensive. As a cheap heuristic,
+     # slightly down-weight very long chunks. query_vec is currently unused and
+     # kept only for signature compatibility.
+     scored = []
+     for h in hits:
+         txt = (h["payload"].get("text") or "")
+         # penalize super-long chunks a bit
+         penalty = 1.0 / (1.0 + len(txt)/1500.0)
+         scored.append((h["score"] * penalty, h))
+     scored.sort(key=lambda x: x[0], reverse=True)
+     return [h for _, h in scored[:top_k]]
+
+ def search(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+     _ensure()
+     q = embed(query)
+     if _USE_FAISS:
+         D, I = _index.search(q[None, :], top_k)
+         scores, idxs = D[0], I[0]
+     else:
+         scores, idxs = _search_numpy(_index, q, top_k)
+     hits = []
+     for i, s in zip(idxs, scores):
+         if i == -1: continue
+         p = _payloads[int(i)]
+         hits.append({"score": float(s), "payload": p})
+     return hits
+
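Rough usage sketch for this module (assumes `HF_API_TOKEN` is set and the two datasets are reachable; the first call builds the in-memory index, so it takes a while):

    from rag.retrieval import search, rerank_cosine, embed

    query = "Quels sont les droits à congés ?"
    hits = search(query, top_k=8)                       # embed the query, then inner-product search
    best = rerank_cosine(embed(query), hits, top_k=5)   # optional length-penalty rerank
    for h in best:
        print(round(h["score"], 3), h["payload"].get("title"))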
rag/synth.py ADDED
@@ -0,0 +1,157 @@
+ import os
+ from openai import OpenAI
+
+ LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
+ LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
+
+ def _build_prompt(query, passages):
+     ctx = "\n\n".join([f"[{i}] " + (p["payload"].get("text") or "") for i, p in enumerate(passages, 1)])
+     return (
+         "Tu es un assistant RH de la fonction publique française.\n"
+         "- Réponds de façon factuelle et concise.\n"
+         "- Cite les sources en fin de phrase avec [1], [2]… basées sur l’ordre des passages.\n"
+         "- Si l’info n’est pas dans les sources, réponds « Je ne sais pas ».\n\n"
+         f"Question: {query}\n\nSources (indexées):\n{ctx}\n\nRéponse:"
+     )
+
+ def synth_answer_stream(query, passages):
+     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
+     prompt = _build_prompt(query, passages)
+     stream = client.chat.completions.create(
+         model=LLM_MODEL,
+         messages=[{"role": "user", "content": prompt}],
+         temperature=0.2,
+         stream=True,  # 👈 IMPORTANT
+     )
+     # The SDK yields events with deltas
+     for event in stream:
+         delta = getattr(getattr(event, "choices", [None])[0], "delta", None)
+         if delta and delta.content:
+             yield delta.content
+
+ # def linkify(text, passages):
+ #     # (optional) keep simple: return text as-is for now
+ #     return text
+
+ def render_sources(passages):
+     lines = []
+     for i, p in enumerate(passages, 1):
+         title = (p["payload"].get("title") or "").strip() or "Sans titre"
+         url = p["payload"].get("url") or ""
+         lines.append(f"[{i}] {title}{' – ' + url if url else ''}")
+     return "\n".join(lines)
+
+ # def linkify_text_with_sources(text: str, passages):
+ #     """
+ #     Replace [1], [2]... with clickable links if the passage has a URL.
+ #     Also append a Sources section as a numbered list.
+ #     """
+ #     # Build a map: 1-based index -> url
+ #     urls = []
+ #     for p in passages:
+ #         url = (p["payload"].get("url") or "").strip()
+ #         urls.append(url if url.startswith("http") else "")
+
+ #     # Inline [n] -> [n](url) when available
+ #     out = text
+ #     for i, url in enumerate(urls, start=1):
+ #         if url:
+ #             out = out.replace(f"[{i}]", f"[{i}]({url})")
+
+ #     # Add a Sources section
+ #     lines = ["\n\n---\n**Sources**"]
+ #     for i, p in enumerate(passages, start=1):
+ #         title = (p["payload"].get("title") or "").strip() or "Sans titre"
+ #         url = (p["payload"].get("url") or "").strip()
+ #         if url.startswith("http"):
+ #             lines.append(f"{i}. [{title}]({url})")
+ #         else:
+ #             lines.append(f"{i}. {title}")
+ #     return out + "\n" + "\n".join(lines)
+ # import os
+ # from openai import OpenAI
+
+ # LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
+ # LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
+
+ # def _first_k_chars(text, k=1200):
+ #     t = text.strip()
+ #     return t[:k] + ("…" if len(t) > k else "")
+
+ # def _build_prompt(query, passages):
+ #     chunks = []
+ #     for i, p in enumerate(passages, 1):
+ #         txt = p["payload"].get("text") or ""
+ #         chunks.append(f"[{i}] {_first_k_chars(txt)}")
+
+ # # def _build_prompt(query, passages):
+ # #     chunks = []
+ # #     for i, p in enumerate(passages, 1):
+ # #         txt = p["payload"].get("text") or ""
+ # #         chunks.append(f"[{i}] {txt}")
+ #     context = "\n\n".join(chunks)
+
+ #     return f"""Tu es un assistant RH de la fonction publique française.
+ # - Réponds de manière factuelle et concise.
+ # - Cite tes sources en fin de phrase avec [n] correspondant aux extraits ci-dessous.
+ # - Si l’information n’est pas dans les sources, réponds : “Je ne sais pas”.
+ # - Ne fabrique pas de liens ni de références.
+
+ # Question: {query}
+
+ # Extraits indexés:
+ # {context}
+
+ # Réponse:"""
+
+ # def synth_answer_stream(query, passages):
+ #     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
+ #     prompt = _build_prompt(query, passages)
+
+ #     # ✅ Correct streaming usage
+ #     stream = client.chat.completions.create(
+ #         model=LLM_MODEL,
+ #         messages=[{"role": "user", "content": prompt}],
+ #         temperature=0.2,
+ #         stream=True,  # <- this is key
+ #     )
+ #     for chunk in stream:
+ #         delta = getattr(chunk.choices[0].delta, "content", None)
+ #         if delta:
+ #             acc.append(delta)
+ #             yield delta  # stream piece by piece
+ # # def synth_answer(query, passages):
+ # #     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=LLM_BASE_URL)
+ # #     prompt = _build_prompt(query, passages)
+
+ # #     resp = client.chat.completions.create(
+ # #         model=LLM_MODEL,
+ # #         messages=[{"role": "user", "content": prompt}],
+ # #         temperature=0.2,
+ # #     )
+ # #     return resp.choices[0].message.content.strip()
+
+ # # --- HELPERS
+
+ # def render_sources(passages):
+ #     lines = []
+ #     for i, p in enumerate(passages, 1):
+ #         pl = p["payload"]
+ #         title = (pl.get("title") or "Source").strip()
+ #         url = pl.get("url") or ""
+ #         lines.append(f"[{i}] {title}" + (f" — {url}" if url else ""))
+ #     return "\n".join(lines)
+
+ # def linkify(text, passages):
+ #     # turn [1] -> markdown link when url exists
+ #     for i, p in enumerate(passages, 1):
+ #         url = p["payload"].get("url")
+ #         if url:
+ #             text = text.replace(f"[{i}]", f"[{i}]({url})")
+ #     return text
+
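Outside the Gradio app, the same stream can be consumed from a plain script; a sketch assuming `OPENAI_API_KEY` is set (plus the retrieval prerequisites above), and since `LLM_BASE_URL` is configurable, any OpenAI-compatible endpoint should behave the same way:

    from rag.retrieval import search
    from rag.synth import synth_answer_stream, render_sources

    question = "Comment déclarer l’embauche d’un salarié ?"
    hits = search(question, top_k=5)
    answer = ""
    for piece in synth_answer_stream(question, hits):
        answer += piece
        print(piece, end="", flush=True)   # pieces arrive as the model generates them
    print("\n\n" + render_sources(hits))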
requirements.txt CHANGED
@@ -3,4 +3,5 @@ datasets>=2.19.0
  huggingface-hub>=0.20
  faiss-cpu==1.7.4
  numpy<2
- python-dotenv
+ python-dotenv
+ openai>=1.0.0