Commit: Update CXR_LLAVA_HF.py
File changed: CXR_LLAVA_HF.py (+5 −1)
@@ -10,6 +10,7 @@ from threading import Thread
 from dataclasses import dataclass
 import numpy as np
 from PIL import Image
+
 # Model Constants
 IGNORE_INDEX = -100
 IMAGE_TOKEN_INDEX = -200
@@ -597,7 +598,7 @@ class CXRLLAVAModel(PreTrainedModel):
 
     def generate_cxr_repsonse(self, chat, image, temperature=0.2, top_p=0.8):
         with torch.no_grad():
-            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=[old value truncated in the scraped page]
+            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=180)
 
             if np.array(image).max()>255:
                 raise Exception("16-bit image is not supported.")
@@ -636,8 +637,11 @@ class CXRLLAVAModel(PreTrainedModel):
             ))
             thread.start()
             generated_text = ""
+            text_len = 0
             for new_text in streamer:
                 generated_text += new_text
+                text_len += 1
+                if text_len > 200: break
 
             return generated_text
 

(NOTE: reconstructed from a flattened side-by-side diff rendering; Python indentation was stripped by the extraction and has been restored to conventional depth — verify against the actual repository file. The pre-change `timeout=` value on the removed line was cut off in the source and is marked rather than guessed.)