ga89tiy committed
Commit 285e0fb · 1 Parent(s): ca57734
Files changed (1)
  1. README.md +25 -4
README.md CHANGED
@@ -37,7 +37,7 @@ Install requirements:
 conda create -n llava_hf python=3.10
 conda activate llava_hf
 conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install requirements.txt
+pip install -r requirements.txt
 ```
 
 Run RaDialog inference:
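A note on the fix above: `pip install requirements.txt` asks pip to install a PyPI package literally named `requirements.txt`; the `-r` flag makes pip read the file as a list of requirements to install.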
@@ -66,7 +66,7 @@ def load_model_from_huggingface(repo_id):
     model_path = Path(model_path)
 
     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
-                                                                           model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)
+                                                                           model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)
 
 
     return tokenizer, model, image_processor, context_len
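The hunk shows only the tail of `load_model_from_huggingface`. For orientation, here is a minimal sketch of how such a helper is typically assembled, assuming the checkpoint is fetched with `huggingface_hub.snapshot_download`; the lines not visible in the hunk are reconstructed, not the commit's verbatim code:

```python
from pathlib import Path

from huggingface_hub import snapshot_download
from llava.model.builder import load_pretrained_model


def load_model_from_huggingface(repo_id):
    # Download the repo from the Hub (or reuse the local cache).
    model_path = snapshot_download(repo_id=repo_id, revision="main")
    model_path = Path(model_path)

    # Load the RaDialog LoRA checkpoint on top of the llava-v1.5-7b base,
    # mirroring the call shown in the hunk above.
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base='liuhaotian/llava-v1.5-7b',
        model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000",
        load_8bit=False, load_4bit=False)

    return tokenizer, model, image_processor, context_len
```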
@@ -74,7 +74,7 @@ def load_model_from_huggingface(repo_id):
 
 
 if __name__ == '__main__':
-    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma" #TODO find good image
+    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm"
 
     response = requests.get(sample_img_path)
     image = Image.open(io.BytesIO(response.content))
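The new sample image can be sanity-checked on its own with the same download code; in this sketch the timeout, `raise_for_status()`, and RGB conversion are robustness additions, not part of the README:

```python
import io

import requests
from PIL import Image

sample_img_path = ("https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png"
                   "?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm")

response = requests.get(sample_img_path, timeout=30)
response.raise_for_status()  # fail loudly on a bad download
image = Image.open(io.BytesIO(response.content)).convert("RGB")
print(image.size)  # PIL reports (width, height)
```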
@@ -95,7 +95,7 @@ if __name__ == '__main__':
     findings = ', '.join(findings).lower().strip()
 
     conv = conv_vicuna_v1.copy()
-    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
+    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
     print("USER: ", REPORT_GEN_PROMPT)
     conv.append_message("USER", REPORT_GEN_PROMPT)
     conv.append_message("ASSISTANT", None)
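For readers unfamiliar with the Vicuna template used here: `conv_vicuna_v1` collects (role, message) turns, and `get_prompt()` serializes them into the single string the tokenizer sees. A small sketch, assuming LLaVA's `llava.conversation` API (the question text is a made-up example):

```python
from llava.conversation import conv_vicuna_v1

conv = conv_vicuna_v1.copy()
conv.append_message("USER", "<image>. Describe the chest X-ray.")  # example turn
conv.append_message("ASSISTANT", None)  # open slot for the model's reply

# One Vicuna-v1-formatted string ending at the open ASSISTANT turn.
print(conv.get_prompt())
```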
@@ -126,6 +126,27 @@ if __name__ == '__main__':
     pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
     print("ASSISTANT: ", pred)
 
+    # add prediction to conversation
+    conv.messages.pop()
+    conv.append_message("ASSISTANT", pred)
+    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
+
+    # generate a report
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids,
+            images=image_tensor,
+            do_sample=False,
+            use_cache=True,
+            max_new_tokens=300,
+            stopping_criteria=[stopping_criteria],
+            pad_token_id=tokenizer.pad_token_id
+        )
+
+    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
+    print("ASSISTANT: ", pred)
+
     # add prediction to conversation
     conv.messages.pop()
     conv.append_message("ASSISTANT", pred)
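The block added by this commit stops decoding once the conversation separator appears in the output. `KeywordsStoppingCriteria` is LLaVA's own helper; purely to illustrate the idea, a minimal keyword-based criterion against the `transformers` `StoppingCriteria` interface could look like this (class name and details are illustrative, not LLaVA's implementation):

```python
from transformers import StoppingCriteria


class KeywordStop(StoppingCriteria):  # hypothetical minimal variant
    def __init__(self, keyword, tokenizer, prompt_len):
        self.keyword = keyword        # e.g. the "</s>" separator
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len  # number of prompt tokens to skip

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the freshly generated tokens and stop as soon as
        # the keyword appears in the decoded text.
        generated = self.tokenizer.decode(input_ids[0, self.prompt_len:])
        return self.keyword in generated


# usage: stopping_criteria = KeywordStop(stop_str, tokenizer, input_ids.shape[1])
```

Note that the added block reuses the original `input_ids` for the second `generate` call; for a genuine follow-up turn one would normally re-serialize the updated `conv` with `get_prompt()` and re-tokenize before generating again.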