VatsalPatel18 committed
Commit 2f08951 · 1 Parent(s): a39a4c5

Force phi-3 to CPU and fall back to TinyLlama

Files changed (3):
  1. NOTE.md +1 -0
  2. README.md +1 -0
  3. app.py +30 -9
NOTE.md CHANGED
@@ -7,6 +7,7 @@
 - Streaming answers; prompt forces context-grounded responses.
 - Uses `@spaces.GPU()` on heavy steps; caches/models under `/data/.cache` (`HF_HOME`).
 - No reranker; straight FAISS retrieval.
+- ZeroGPU safeguard: Phi-3-mini-4k forced to CPU and will auto-fall back to TinyLlama if CUDA init is blocked in stateless GPU runtime.
 - Dependencies captured in `requirements.txt`; Spaces front matter in `README.md`.
 - Pushed to HF Space `VatsalPatel18/MedDisover-space` (commit f6d2d7d).
 - To do next when resuming:
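For context, the `@spaces.GPU()` / `HF_HOME` pattern the note refers to looks roughly like the sketch below; the function name and body are illustrative, not taken from app.py:

```python
import os

# Keep model and dataset caches on the Space's persistent volume, per the note.
os.environ.setdefault("HF_HOME", "/data/.cache")

import spaces  # ZeroGPU helper available on Hugging Face Spaces


@spaces.GPU()  # a GPU slice is attached only while this function runs
def embed_chunks(chunks):
    # Hypothetical heavy step: encode chunks with the MedCPT encoder.
    ...
```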
README.md CHANGED
@@ -19,3 +19,4 @@ A Hugging Face Spaces-ready Gradio app that runs MedDiscover-style RAG (MedCPT e
 - No external API keys required; optionally use `HF_TOKEN` secret if models need auth.
 - Chunking: 500 words with 50-word overlap; MedCPT encoder for embeddings; FAISS IP index.
 - Generation is context-grounded; drop-down selects generator model.
+- Phi-3-mini-4k is forced to CPU on Spaces; if loading fails (stateless GPU), the app auto-falls back to TinyLlama.
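The chunking bullet above maps to a plain sliding window over whitespace-split words. A minimal sketch, assuming a helper along these lines (the name `chunk_words` and the exact splitting details are not taken from app.py):

```python
def chunk_words(text: str, size: int = 500, overlap: int = 50) -> list[str]:
    """Split text into windows of `size` words, each overlapping the previous by `overlap` words."""
    words = text.split()
    if not words:
        return []
    step = size - overlap
    return [" ".join(words[i:i + size]) for i in range(0, max(len(words) - overlap, 1), step)]
```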
app.py CHANGED
@@ -185,14 +185,27 @@ def search(index: faiss.IndexFlatIP, meta: List[Dict], query_vec: np.ndarray, k:
 # Model registry for generators
 # ----------------------------
 class GeneratorWrapper:
-    def __init__(self, name: str, load_fn):
+    def __init__(self, name: str, load_fn, fallback=None, fallback_msg: str | None = None):
         self.name = name
         self._load_fn = load_fn
         self._pipe = None
+        self._fallback = fallback
+        self._fallback_msg = fallback_msg
+        self._note = None
 
     def ensure(self):
         if self._pipe is None:
-            self._pipe = self._load_fn()
+            try:
+                self._pipe = self._load_fn()
+                self._note = None
+            except Exception as exc:
+                print(f"[Generator:{self.name}] load failed: {exc}")
+                if self._fallback:
+                    print(f"[Generator:{self.name}] falling back to {self._fallback.name}")
+                    self._pipe = self._fallback.ensure()
+                    self._note = self._fallback_msg or f"Falling back to {self._fallback.name}."
+                else:
+                    raise
         return self._pipe
 
     def generate_stream(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
@@ -208,6 +221,9 @@ class GeneratorWrapper:
         }
         thread = Thread(target=pipe, args=(prompt,), kwargs=kwargs)
         thread.start()
+        if self._note:
+            yield self._note + " "
+            self._note = None
         for token in streamer:
             yield token
 
@@ -227,21 +243,26 @@ def load_tinyllama():
 
 
 def load_phi3_mini():
-    use_cuda = torch.cuda.is_available()
-    device_map = "cuda" if use_cuda else "cpu"
-    dtype = torch.float16 if use_cuda else torch.float32
     return pipeline(
         "text-generation",
         model="microsoft/Phi-3-mini-4k-instruct",
-        device_map=device_map,
-        torch_dtype=dtype,
+        device_map="cpu",
+        torch_dtype=torch.float32,
         trust_remote_code=True,
     )
 
 
+_tiny_wrapper = GeneratorWrapper("tinyllama-1.1b-chat", load_tinyllama)
+_phi_wrapper = GeneratorWrapper(
+    "phi-3-mini-4k",
+    load_phi3_mini,
+    fallback=_tiny_wrapper,
+    fallback_msg="Phi-3-mini-4k unavailable on this Space (CUDA blocked); falling back to TinyLlama CPU.",
+)
+
 GENERATORS = {
-    "tinyllama-1.1b-chat": GeneratorWrapper("tinyllama-1.1b-chat", load_tinyllama),
-    "phi-3-mini-4k": GeneratorWrapper("phi-3-mini-4k", load_phi3_mini),
+    "tinyllama-1.1b-chat": _tiny_wrapper,
+    "phi-3-mini-4k": _phi_wrapper,
 }
 
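With this change a caller never has to know whether Phi-3 actually loaded. For illustration, selecting the Phi-3 entry behaves as sketched below; only `GENERATORS`, `ensure`, and `generate_stream` come from app.py, while the prompt and sampling values are arbitrary examples:

```python
gen = GENERATORS["phi-3-mini-4k"]

# ensure() tries to load Phi-3-mini-4k on CPU; if the load raises (e.g. CUDA init
# blocked on stateless ZeroGPU), it reuses the shared TinyLlama wrapper instead
# and remembers a one-off notice to prepend to the next stream.
pipe = gen.ensure()

for token in gen.generate_stream(
    "Summarize the retrieved context.",  # example prompt
    max_new_tokens=64,
    temperature=0.2,
    top_p=0.9,
):
    print(token, end="", flush=True)
```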