rkonan commited on
Commit
2004481
·
1 Parent(s): 180c827

version ollama

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. app_ollama.py +123 -0
  3. rag_model_ollama.py +284 -0
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import streamlit as st
2
  from llama_cpp import Llama
3
  import os
4
- #from rag_model import RAGEngine
5
 
6
- from rag_model_optimise import RAGEngine
7
  import logging
8
  from huggingface_hub import hf_hub_download
9
  import time
 
1
  import streamlit as st
2
  from llama_cpp import Llama
3
  import os
4
+ from rag_model import RAGEngine
5
 
6
+ #from rag_model_optimise import RAGEngine
7
  import logging
8
  from huggingface_hub import hf_hub_download
9
  import time
app_ollama.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import logging
4
+ import streamlit as st
5
+
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ # ✅ Nouveau moteur RAG (Ollama)
9
+ from rag_model_ollama import RAGEngine
10
+
11
+ # --- Config & logs ---
12
+ os.environ.setdefault("NLTK_DATA", "/home/appuser/nltk_data")
13
+
14
+ logger = logging.getLogger("Streamlit")
15
+ logger.setLevel(logging.INFO)
16
+ handler = logging.StreamHandler()
17
+ formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
18
+ handler.setFormatter(formatter)
19
+ if not logger.handlers:
20
+ logger.addHandler(handler)
21
+
22
+ st.set_page_config(page_title="Chatbot RAG (Ollama)", page_icon="🤖")
23
+
24
+ # --- ENV ---
25
+ ENV = os.getenv("ENV", "local") # "local" ou "space"
26
+ logger.info(f"ENV: {ENV}")
27
+
28
+ # --- Chemins FAISS & chunks ---
29
+ if ENV == "local":
30
+ # Adapte ces chemins à ton filesystem local
31
+ faiss_index_path = "chatbot-models/vectordb_docling/index.faiss"
32
+ vectors_path = "chatbot-models/vectordb_docling/chunks.pkl"
33
+ else:
34
+ # Télécharge depuis Hugging Face (dataset privé/public selon tes réglages)
35
+ faiss_index_path = hf_hub_download(
36
+ repo_id="rkonan/chatbot-models",
37
+ filename="chatbot-models/vectordb_docling/index.faiss",
38
+ repo_type="dataset"
39
+ )
40
+ vectors_path = hf_hub_download(
41
+ repo_id="rkonan/chatbot-models",
42
+ filename="chatbot-models/vectordb_docling/chunks.pkl",
43
+ repo_type="dataset"
44
+ )
45
+
46
+ # --- UI Sidebar ---
47
+ st.sidebar.header("⚙️ Paramètres")
48
+ default_host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
49
+ ollama_host = st.sidebar.text_input("Ollama host", value=default_host, help="Ex: http://localhost:11434")
50
+
51
+ # Propose des modèles déjà présents ou courants
52
+ suggested_models = [
53
+ "mistral", # présent chez toi
54
+ "gemma3", # présent chez toi
55
+ "deepseek-r1", # présent chez toi (raisonnement long, plus lent)
56
+ "granite3.3", # présent chez toi
57
+ "llama3.1:8b-instruct-q4_K_M",
58
+ "nous-hermes2:Q4_K_M",
59
+ ]
60
+ model_name = st.sidebar.selectbox("Modèle Ollama", options=suggested_models, index=0)
61
+ num_threads = st.sidebar.slider("Threads (hint)", min_value=2, max_value=16, value=6, step=1)
62
+ temperature = st.sidebar.slider("Température", min_value=0.0, max_value=1.5, value=0.1, step=0.1)
63
+
64
+ st.title("🤖 Chatbot RAG Local (Ollama)")
65
+
66
+ # --- Cache du moteur ---
67
+ @st.cache_resource(show_spinner=True)
68
+ def load_rag_engine(_model_name: str, _host: str, _threads: int, _temp: float):
69
+ # Options pour Ollama
70
+ ollama_opts = {
71
+ "num_thread": int(_threads),
72
+ "temperature": float(_temp),
73
+ }
74
+
75
+ rag = RAGEngine(
76
+ model_name=_model_name,
77
+ vector_path=vectors_path,
78
+ index_path=faiss_index_path,
79
+ model_threads=_threads,
80
+ ollama_host=_host,
81
+ ollama_opts=ollama_opts
82
+ )
83
+
84
+ # Warmup léger (évite la latence au 1er token)
85
+ try:
86
+ _ = rag._complete("Bonjour", max_tokens=1)
87
+ except Exception as e:
88
+ logger.warning(f"Warmup Ollama échoué: {e}")
89
+ return rag
90
+
91
+ rag = load_rag_engine(model_name, ollama_host, num_threads, temperature)
92
+
93
+ # --- Chat simple ---
94
+ user_input = st.text_area("Posez votre question :", height=120, placeholder="Ex: Quels sont les traitements appliqués aux images ?")
95
+ col1, col2 = st.columns([1,1])
96
+
97
+ if col1.button("Envoyer"):
98
+ if user_input.strip():
99
+ with st.spinner("Génération en cours..."):
100
+ try:
101
+ response = rag.ask(user_input)
102
+ st.markdown("**Réponse :**")
103
+ st.success(response)
104
+ except Exception as e:
105
+ st.error(f"Erreur pendant la génération: {e}")
106
+ else:
107
+ st.info("Saisissez une question.")
108
+
109
+ if col2.button("Envoyer (stream)"):
110
+ if user_input.strip():
111
+ with st.spinner("Génération en cours (stream)..."):
112
+ try:
113
+ # Affichage token-par-token
114
+ ph = st.empty()
115
+ acc = ""
116
+ for token in rag.ask_stream(user_input):
117
+ acc += token
118
+ ph.markdown(acc)
119
+ st.balloons()
120
+ except Exception as e:
121
+ st.error(f"Erreur pendant la génération (stream): {e}")
122
+ else:
123
+ st.info("Saisissez une question.")
rag_model_ollama.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import pickle
4
+ import textwrap
5
+ import logging
6
+ from typing import List, Optional, Dict, Any, Iterable
7
+
8
+ import requests
9
+ import faiss
10
+ import numpy as np
11
+ from llama_index.core import VectorStoreIndex
12
+ from llama_index.core.schema import TextNode
13
+ from llama_index.vector_stores.faiss import FaissVectorStore
14
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
15
+ from sentence_transformers.util import cos_sim
16
+
17
+
18
+ # === Logger configuration ===
19
+ logger = logging.getLogger("RAGEngine")
20
+ logger.setLevel(logging.INFO)
21
+ handler = logging.StreamHandler()
22
+ formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
23
+ handler.setFormatter(formatter)
24
+ if not logger.handlers:
25
+ logger.addHandler(handler)
26
+
27
+ MAX_TOKENS = 512
28
+
29
+
30
+ class OllamaClient:
31
+ """
32
+ Minimal Ollama client for /api/generate (text completion) with streaming support.
33
+ Docs: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
34
+ """
35
+ def __init__(self, model: str, host: Optional[str] = None, timeout: int = 120):
36
+ self.model = model
37
+ self.host = host or os.getenv("OLLAMA_HOST", "http://localhost:11434")
38
+ self.timeout = timeout
39
+ self._gen_url = self.host.rstrip("/") + "/api/generate"
40
+
41
+ def generate(
42
+ self,
43
+ prompt: str,
44
+ stop: Optional[List[str]] = None,
45
+ max_tokens: Optional[int] = None,
46
+ stream: bool = False,
47
+ options: Optional[Dict[str, Any]] = None,
48
+ ) -> str | Iterable[str]:
49
+ payload = {
50
+ "model": self.model,
51
+ "prompt": prompt,
52
+ "stream": stream,
53
+ }
54
+ if stop:
55
+ payload["stop"] = stop
56
+ if max_tokens is not None:
57
+ # Ollama uses "num_predict" for max new tokens
58
+ payload["num_predict"] = int(max_tokens)
59
+ if options:
60
+ payload["options"] = options
61
+
62
+ logger.debug(f"POST {self._gen_url} (stream={stream})")
63
+
64
+ if stream:
65
+ with requests.post(self._gen_url, json=payload, stream=True, timeout=self.timeout) as r:
66
+ r.raise_for_status()
67
+ for line in r.iter_lines(decode_unicode=True):
68
+ if not line:
69
+ continue
70
+ try:
71
+ data = json.loads(line)
72
+ except Exception:
73
+ # In case a broken line appears
74
+ continue
75
+ if "response" in data and data.get("done") is not True:
76
+ yield data["response"]
77
+ if data.get("done"):
78
+ break
79
+ return
80
+
81
+ # Non-streaming
82
+ r = requests.post(self._gen_url, json=payload, timeout=self.timeout)
83
+ r.raise_for_status()
84
+ data = r.json()
85
+ return data.get("response", "")
86
+
87
+
88
+ # Lazy import json to keep top clean
89
+ import json
90
+
91
+
92
+ class RAGEngine:
93
+ def __init__(
94
+ self,
95
+ model_name: str,
96
+ vector_path: str,
97
+ index_path: str,
98
+ model_threads: int = 4,
99
+ ollama_host: Optional[str] = None,
100
+ ollama_opts: Optional[Dict[str, Any]] = None,
101
+ ):
102
+ """
103
+ Args:
104
+ model_name: e.g. "nous-hermes2:Q4_K_M" or "llama3.1:8b-instruct-q4_K_M"
105
+ vector_path: pickle file with chunk texts list[str]
106
+ index_path: FAISS index path
107
+ model_threads: forwarded to Ollama via options.n_threads (if supported by the model)
108
+ ollama_host: override OLLAMA_HOST (default http://localhost:11434)
109
+ ollama_opts: extra Ollama options (e.g., temperature, top_p, num_gpu, num_thread)
110
+ """
111
+ logger.info("📦 Initialisation du moteur RAG (Ollama)...")
112
+ # Build options
113
+ opts = dict(ollama_opts or {})
114
+ # Common low-latency defaults; user can override via ollama_opts
115
+ opts.setdefault("temperature", 0.1)
116
+ # Try to pass thread hint if supported by the backend
117
+ if "num_thread" not in opts and model_threads:
118
+ opts["num_thread"] = int(model_threads)
119
+
120
+ self.llm = OllamaClient(model=model_name, host=ollama_host)
121
+ self.ollama_opts = opts
122
+
123
+ self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
124
+
125
+ logger.info(f"📂 Chargement des données vectorielles depuis {vector_path}")
126
+ with open(vector_path, "rb") as f:
127
+ chunk_texts = pickle.load(f)
128
+ nodes = [TextNode(text=chunk) for chunk in chunk_texts]
129
+
130
+ faiss_index = faiss.read_index(index_path)
131
+ vector_store = FaissVectorStore(faiss_index=faiss_index)
132
+ self.index = VectorStoreIndex(nodes=nodes, embed_model=self.embed_model, vector_store=vector_store)
133
+
134
+ logger.info("✅ Moteur RAG (Ollama) initialisé avec succès.")
135
+
136
+ # ---------------- LLM helpers (via Ollama) ----------------
137
+
138
+ def _complete(self, prompt: str, stop: Optional[List[str]] = None, max_tokens: int = 128) -> str:
139
+ text = self.llm.generate(
140
+ prompt=prompt,
141
+ stop=stop,
142
+ max_tokens=max_tokens,
143
+ stream=False,
144
+ options=self.ollama_opts,
145
+ )
146
+ # Some Ollama setups may stream even when stream=False. Coerce generators to string.
147
+ try:
148
+ if hasattr(text, "__iter__") and not isinstance(text, (str, bytes)):
149
+ chunks = []
150
+ for t in text:
151
+ if not isinstance(t, (str, bytes)):
152
+ continue
153
+ chunks.append(t)
154
+ text = "".join(chunks)
155
+ except Exception:
156
+ pass
157
+ return (text or "").strip()
158
+
159
+ def _complete_stream(self, prompt: str, stop: Optional[List[str]] = None, max_tokens: int = MAX_TOKENS):
160
+ return self.llm.generate(
161
+ prompt=prompt,
162
+ stop=stop,
163
+ max_tokens=max_tokens,
164
+ stream=True,
165
+ options=self.ollama_opts,
166
+ )
167
+
168
+ # ---------------- Reformulation ----------------
169
+
170
+ def reformulate_question(self, question: str) -> str:
171
+ logger.info("🔁 Reformulation de la question (sans contexte)...")
172
+ prompt = f"""Tu es un assistant expert chargé de clarifier des questions floues.
173
+
174
+ Transforme la question suivante en une question claire, explicite et complète, sans ajouter d'informations extérieures.
175
+
176
+ Question floue : {question}
177
+ Question reformulée :"""
178
+ reformulated = self._complete(prompt, stop=["\n"], max_tokens=128)
179
+ logger.info(f"📝 Reformulée : {reformulated}")
180
+ return reformulated
181
+
182
+ def reformulate_with_context(self, question: str, context_sample: str) -> str:
183
+ logger.info("🔁 Reformulation de la question avec contexte...")
184
+ prompt = f"""Tu es un assistant expert en machine learning. Ton rôle est de reformuler les questions utilisateur en tenant compte du contexte ci-dessous, extrait d’un rapport technique sur un projet de reconnaissance de maladies de plantes.
185
+
186
+ Ta mission est de transformer une question vague ou floue en une question précise et adaptée au contenu du rapport. Ne donne pas une interprétation hors sujet. Ne reformule pas en termes de produits commerciaux.
187
+
188
+ Contexte :
189
+ {context_sample}
190
+
191
+ Question initiale : {question}
192
+ Question reformulée :"""
193
+ reformulated = self._complete(prompt, stop=["\n"], max_tokens=128)
194
+ logger.info(f"📝 Reformulée avec contexte : {reformulated}")
195
+ return reformulated
196
+
197
+ # ---------------- Retrieval ----------------
198
+
199
+ def get_adaptive_top_k(self, question: str) -> int:
200
+ q = question.lower()
201
+ if len(q.split()) <= 7:
202
+ top_k = 8
203
+ elif any(w in q for w in ["liste", "résume", "quels sont", "explique", "comment"]):
204
+ top_k = 10
205
+ else:
206
+ top_k = 8
207
+ logger.info(f"🔢 top_k déterminé automatiquement : {top_k}")
208
+ return top_k
209
+
210
+ def rerank_nodes(self, question: str, retrieved_nodes, top_k: int = 3):
211
+ logger.info(f"🔍 Re-ranking des {len(retrieved_nodes)} chunks pour la question : « {question} »")
212
+ q_emb = self.embed_model.get_query_embedding(question)
213
+ scored_nodes = []
214
+
215
+ for node in retrieved_nodes:
216
+ chunk_text = node.get_content()
217
+ chunk_emb = self.embed_model.get_text_embedding(chunk_text)
218
+ score = cos_sim(q_emb, chunk_emb).item()
219
+ scored_nodes.append((score, node))
220
+
221
+ ranked_nodes = sorted(scored_nodes, key=lambda x: x[0], reverse=True)
222
+
223
+ logger.info("📊 Chunks les plus pertinents :")
224
+ for i, (score, node) in enumerate(ranked_nodes[:top_k]):
225
+ chunk_preview = textwrap.shorten(node.get_content().replace("\n", " "), width=100)
226
+ logger.info(f"#{i+1} | Score: {score:.4f} | {chunk_preview}")
227
+
228
+ return [n for _, n in ranked_nodes[:top_k]]
229
+
230
+ def retrieve_context(self, question: str, top_k: int = 3):
231
+ logger.info(f"📥 Récupération du contexte...")
232
+ retriever = self.index.as_retriever(similarity_top_k=top_k)
233
+ retrieved_nodes = retriever.retrieve(question)
234
+ reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k)
235
+ context = "\n\n".join(n.get_content()[:500] for n in reranked_nodes)
236
+ return context, reranked_nodes
237
+
238
+ # ---------------- Public API ----------------
239
+
240
+ def ask(self, question_raw: str) -> str:
241
+ logger.info(f"💬 Question reçue : {question_raw}")
242
+ if len(question_raw.split()) <= 100:
243
+ context_sample, _ = self.retrieve_context(question_raw, top_k=3)
244
+ reformulated = self.reformulate_with_context(question_raw, context_sample)
245
+ else:
246
+ reformulated = self.reformulate_question(question_raw)
247
+
248
+ logger.info(f"📝 Question reformulée : {reformulated}")
249
+ top_k = self.get_adaptive_top_k(reformulated)
250
+ context, _ = self.retrieve_context(reformulated, top_k)
251
+
252
+ prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
253
+
254
+ Si la réponse ne peut pas être déduite du contexte, indique : "Information non présente dans le contexte."
255
+
256
+ Contexte :
257
+ {context}
258
+
259
+ Question : {reformulated}
260
+ ### Réponse:"""
261
+
262
+ response = self._complete(prompt, stop=["### Instruction:"], max_tokens=MAX_TOKENS)
263
+ response = response.strip().split("###")[0]
264
+ logger.info(f"🧠 Réponse générée : {response[:120]}{{'...' if len(response) > 120 else ''}}")
265
+ return response
266
+
267
+ def ask_stream(self, question: str):
268
+ logger.info(f"💬 [Stream] Question reçue : {question}")
269
+ top_k = self.get_adaptive_top_k(question)
270
+ context, _ = self.retrieve_context(question, top_k)
271
+
272
+ prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
273
+
274
+ Si la réponse ne peut pas être déduite du contexte, indique : "Information non présente dans le contexte."
275
+
276
+ Contexte :
277
+ {context}
278
+
279
+ Question : {question}
280
+ ### Réponse:"""
281
+
282
+ logger.info("📡 Début du streaming de la réponse...")
283
+ for token in self._complete_stream(prompt, stop=["### Instruction:"], max_tokens=MAX_TOKENS):
284
+ print(token, end="", flush=True)