VatsalPatel18 committed
Commit 1d1ab79 · Parent: 2f08951

Fix Phi-3 generate fallback on ZeroGPU

Files changed (1): app.py (+40 -6)
app.py CHANGED
@@ -211,16 +211,38 @@ class GeneratorWrapper:
     def generate_stream(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
         pipe = self.ensure()
         streamer = TextIteratorStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True)
-        kwargs = {
+        inputs = pipe.tokenizer(prompt, return_tensors="pt")
+        device = getattr(pipe.model, "device", torch.device("cpu"))
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+        gen_kwargs = {
             "max_new_tokens": max_new_tokens,
             "do_sample": True,
             "temperature": temperature,
             "top_p": top_p,
             "streamer": streamer,
-            "return_full_text": False,
+            "return_dict_in_generate": True,
+            "output_scores": False,
+            "use_cache": False,  # avoid DynamicCache issues on Phi-3 CPU
         }
-        thread = Thread(target=pipe, args=(prompt,), kwargs=kwargs)
-        thread.start()
+
+        def _run():
+            try:
+                pipe.model.generate(**inputs, **gen_kwargs)
+            except Exception as exc:
+                if self._fallback:
+                    print(f"[Generator:{self.name}] generate failed: {exc}; falling back to {self._fallback.name}")
+                    self._pipe = self._fallback.ensure()
+                    note = self._fallback_msg or f"Falling back to {self._fallback.name}."
+                    if note:
+                        streamer.put(note + " ")
+                    fb_stream = self._fallback.generate_stream(prompt, max_new_tokens, temperature, top_p)
+                    for tok in fb_stream:
+                        streamer.put(tok)
+                else:
+                    print(f"[Generator:{self.name}] generate failed: {exc}")
+            streamer.end()
+
+        Thread(target=_run, daemon=True).start()
         if self._note:
             yield self._note + " "
             self._note = None
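
The new generate_stream body follows the standard TextIteratorStreamer pattern from transformers: generation runs on a background thread while the caller iterates the streamer and yields text as it arrives. The consumption loop itself is outside this hunk; a minimal, self-contained sketch of the same pattern, not part of this commit and using the tiny sshleifer/tiny-gpt2 checkpoint as a stand-in model, would look roughly like:

    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Stand-in model so the sketch runs quickly on CPU; the Space itself uses
    # Phi-3-mini / TinyLlama behind GeneratorWrapper.
    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    inputs = tok("Hello", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_special_tokens=True, skip_prompt=True)

    # generate() feeds decoded text into the streamer and calls streamer.end()
    # when it finishes, so the consumer loop below terminates on its own.
    Thread(
        target=model.generate,
        kwargs={**inputs, "max_new_tokens": 20, "do_sample": False,
                "streamer": streamer, "use_cache": False},
        daemon=True,
    ).start()

    for text in streamer:
        print(text, end="", flush=True)
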
@@ -243,13 +265,25 @@ def load_tinyllama():
 
 
 def load_phi3_mini():
-    return pipeline(
+    pipe = pipeline(
         "text-generation",
         model="microsoft/Phi-3-mini-4k-instruct",
         device_map="cpu",
         torch_dtype=torch.float32,
         trust_remote_code=True,
+        model_kwargs={
+            "use_cache": False,
+            "attn_implementation": "eager",
+        },
     )
+    # Disable cache to avoid DynamicCache.seen_tokens errors on ZeroGPU/CPU.
+    try:
+        pipe.model.config.use_cache = False
+        pipe.model.generation_config.use_cache = False
+        pipe.model.generation_config.cache_implementation = "static"
+    except Exception:
+        pass
+    return pipe
 
 
 _tiny_wrapper = GeneratorWrapper("tinyllama-1.1b-chat", load_tinyllama)
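
A quick way to check the cache-disabled loader outside the Space is to build the same pipeline and run one short completion. A local smoke-test sketch, not part of the commit (it downloads the full Phi-3-mini-4k-instruct checkpoint), could look like:

    import torch
    from transformers import pipeline

    pipe = pipeline(
        "text-generation",
        model="microsoft/Phi-3-mini-4k-instruct",
        device_map="cpu",
        torch_dtype=torch.float32,
        trust_remote_code=True,
        model_kwargs={"use_cache": False, "attn_implementation": "eager"},
    )
    pipe.model.config.use_cache = False
    pipe.model.generation_config.use_cache = False

    # One short greedy completion is enough to confirm generate() no longer
    # hits the DynamicCache error on a CPU-only / ZeroGPU runtime.
    out = pipe("Explain what a KV cache is in one sentence.",
               max_new_tokens=32, do_sample=False, return_full_text=False)
    print(out[0]["generated_text"])
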
@@ -438,7 +472,7 @@ with gr.Blocks(title="MedDiscover") as demo:
     model_dd = gr.Dropdown(
         label="Generator Model",
         choices=list(GENERATORS.keys()),
-        value="phi-3-mini-4k",
+        value="tinyllama-1.1b-chat",
         interactive=True,
     )
     k_slider = gr.Slider(1, 10, value=3, step=1, label="Top-k chunks")
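
The dropdown's choices and the new default come from GENERATORS, which this diff does not show; presumably it maps each dropdown key to one of the wrappers defined above. A hypothetical sketch of that wiring (the _phi_wrapper name and the exact keys are assumptions, not taken from app.py):

    # Hypothetical wiring, for illustration only -- the real GENERATORS dict is
    # defined elsewhere in app.py.
    _phi_wrapper = GeneratorWrapper("phi-3-mini-4k", load_phi3_mini)  # assumed name/key

    GENERATORS = {
        "tinyllama-1.1b-chat": _tiny_wrapper,  # new dropdown default
        "phi-3-mini-4k": _phi_wrapper,         # selectable; falls back on failure
    }
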
 