Commit: 1d1ab79
Parent(s): 2f08951

Fix Phi-3 generate fallback on ZeroGPU

app.py (CHANGED)

@@ -211,16 +211,38 @@ class GeneratorWrapper:
     def generate_stream(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
         pipe = self.ensure()
         streamer = TextIteratorStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True)
+        inputs = pipe.tokenizer(prompt, return_tensors="pt")
+        device = getattr(pipe.model, "device", torch.device("cpu"))
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+        gen_kwargs = {
             "max_new_tokens": max_new_tokens,
             "do_sample": True,
             "temperature": temperature,
             "top_p": top_p,
             "streamer": streamer,
+            "return_dict_in_generate": True,
+            "output_scores": False,
+            "use_cache": False,  # avoid DynamicCache issues on Phi-3 CPU
         }
+
+        def _run():
+            try:
+                pipe.model.generate(**inputs, **gen_kwargs)
+            except Exception as exc:
+                if self._fallback:
+                    print(f"[Generator:{self.name}] generate failed: {exc}; falling back to {self._fallback.name}")
+                    self._pipe = self._fallback.ensure()
+                    note = self._fallback_msg or f"Falling back to {self._fallback.name}."
+                    if note:
+                        streamer.put(note + " ")
+                    fb_stream = self._fallback.generate_stream(prompt, max_new_tokens, temperature, top_p)
+                    for tok in fb_stream:
+                        streamer.put(tok)
+                else:
+                    print(f"[Generator:{self.name}] generate failed: {exc}")
+            streamer.end()
+
+        Thread(target=_run, daemon=True).start()
         if self._note:
             yield self._note + " "
             self._note = None
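
For reference, the new code follows the standard transformers streaming pattern: generate() runs on a worker thread while the caller iterates a TextIteratorStreamer. A minimal standalone sketch of that pattern, separate from the Space's GeneratorWrapper (model id taken from this commit; the prompt and sampling values are illustrative):

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"  # model id as used elsewhere in this commit

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
)

streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
inputs = tokenizer("What causes iron-deficiency anemia?", return_tensors="pt")

# generate() blocks until decoding finishes, so it runs on a daemon thread
# while the main thread consumes decoded text from the streamer.
Thread(
    target=model.generate,
    kwargs={
        **inputs,
        "max_new_tokens": 64,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "use_cache": False,  # mirrors the commit's workaround for Phi-3 on CPU
        "streamer": streamer,
    },
    daemon=True,
).start()

for chunk in streamer:  # yields decoded text pieces as tokens are generated
    print(chunk, end="", flush=True)
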
@@ -243,13 +265,25 @@ def load_tinyllama():


 def load_phi3_mini():
+    pipe = pipeline(
         "text-generation",
         model="microsoft/Phi-3-mini-4k-instruct",
         device_map="cpu",
         torch_dtype=torch.float32,
         trust_remote_code=True,
+        model_kwargs={
+            "use_cache": False,
+            "attn_implementation": "eager",
+        },
     )
+    # Disable cache to avoid DynamicCache.seen_tokens errors on ZeroGPU/CPU.
+    try:
+        pipe.model.config.use_cache = False
+        pipe.model.generation_config.use_cache = False
+        pipe.model.generation_config.cache_implementation = "static"
+    except Exception:
+        pass
+    return pipe


 _tiny_wrapper = GeneratorWrapper("tinyllama-1.1b-chat", load_tinyllama)
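
As the new comment in load_phi3_mini() notes, the KV cache is disabled to sidestep DynamicCache.seen_tokens errors that Phi-3's remote code can hit on ZeroGPU/CPU; the trade-off is slower decoding, since attention is recomputed over the full sequence at every step. A self-contained sketch that mirrors the loader and smoke-tests it (the prompt and token budget are illustrative):

import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="microsoft/Phi-3-mini-4k-instruct",
    device_map="cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    model_kwargs={"use_cache": False, "attn_implementation": "eager"},
)
pipe.model.config.use_cache = False
pipe.model.generation_config.use_cache = False

# One short greedy generation to confirm the cache-free path works end to end.
out = pipe(
    "Define tachycardia in one sentence.",
    max_new_tokens=32,
    do_sample=False,
    use_cache=False,
    return_full_text=False,
)
print(out[0]["generated_text"])
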
@@ -438,7 +472,7 @@ with gr.Blocks(title="MedDiscover") as demo:
     model_dd = gr.Dropdown(
         label="Generator Model",
         choices=list(GENERATORS.keys()),
+        value="tinyllama-1.1b-chat",
         interactive=True,
     )
     k_slider = gr.Slider(1, 10, value=3, step=1, label="Top-k chunks")
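
The UI change is just the dropdown default, which now starts on the lightweight TinyLlama generator. A minimal Gradio sketch of that control in isolation (the choice list stands in for GENERATORS, whose contents are not shown in this diff):

import gradio as gr

CHOICES = ["tinyllama-1.1b-chat", "phi-3-mini-4k-instruct"]  # illustrative names

with gr.Blocks(title="MedDiscover") as demo:
    model_dd = gr.Dropdown(
        label="Generator Model",
        choices=CHOICES,
        value="tinyllama-1.1b-chat",  # default generator after this commit
        interactive=True,
    )
    k_slider = gr.Slider(1, 10, value=3, step=1, label="Top-k chunks")

demo.launch()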