jcsagar commited on
Commit
3da894e
·
verified ·
1 Parent(s): 2ceeadd

Update CXR_LLAVA_HF.py

Browse files
Files changed (1) hide show
  1. CXR_LLAVA_HF.py +5 -1
CXR_LLAVA_HF.py CHANGED
@@ -10,6 +10,7 @@ from threading import Thread
10
  from dataclasses import dataclass
11
  import numpy as np
12
  from PIL import Image
 
13
  # Model Constants
14
  IGNORE_INDEX = -100
15
  IMAGE_TOKEN_INDEX = -200
@@ -597,7 +598,7 @@ class CXRLLAVAModel(PreTrainedModel):
597
 
598
  def generate_cxr_repsonse(self, chat, image, temperature=0.2, top_p=0.8):
599
  with torch.no_grad():
600
- streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=90)
601
 
602
  if np.array(image).max()>255:
603
  raise Exception("16-bit image is not supported.")
@@ -636,8 +637,11 @@ class CXRLLAVAModel(PreTrainedModel):
636
  ))
637
  thread.start()
638
  generated_text = ""
 
639
  for new_text in streamer:
640
  generated_text += new_text
 
 
641
 
642
  return generated_text
643
 
 
10
  from dataclasses import dataclass
11
  import numpy as np
12
  from PIL import Image
13
+
14
  # Model Constants
15
  IGNORE_INDEX = -100
16
  IMAGE_TOKEN_INDEX = -200
 
598
 
599
  def generate_cxr_repsonse(self, chat, image, temperature=0.2, top_p=0.8):
600
  with torch.no_grad():
601
+ streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=180)
602
 
603
  if np.array(image).max()>255:
604
  raise Exception("16-bit image is not supported.")
 
637
  ))
638
  thread.start()
639
  generated_text = ""
640
+ text_len = 0
641
  for new_text in streamer:
642
  generated_text += new_text
643
+ text_len += 1
644
+ if text_len > 200: break
645
 
646
  return generated_text
647