Chantal committed 8fdae9d (verified, parent 30d9f56): Update README.md

Files changed (1): README.md (+133, -0)

<div align="center">
<h1>
RaDialog
</h1>
</div>

<p align="center">
📝 <a href="https://arxiv.org/abs/2311.18681" target="_blank">Paper</a> • 🤗 <a href="https://huggingface.co/Chantal/RaDialog-interactive-radiology-report-generation/" target="_blank">Hugging Face</a> • <a href="https://github.com/ChantalMP/RaDialog" target="_blank">GitHub</a> • <a href="https://physionet.org/content/radialog-instruct-dataset/1.0.0/" target="_blank">Dataset</a>
</p>

## Get Started

Clone the repository:
```bash
git clone https://huggingface.co/Chantal/RaDialog-interactive-radiology-report-generation
```

Install the requirements:
```bash
pip install -r requirements.txt
```
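
The inference example below runs the model on a GPU in bfloat16. As an optional sanity check (a minimal sketch, assuming the requirements install PyTorch with CUDA support), you can verify that a suitable device is visible:

```python
import torch

# The inference example casts tensors to bfloat16 on the model's CUDA device,
# so both a visible GPU and bfloat16 support are needed.
print(torch.__version__)
print(torch.cuda.is_available() and torch.cuda.is_bf16_supported())
```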

Run RaDialog inference from the repository root (the script imports the `LLAVA_Biovil` and `utils` modules shipped with this repo):
```python
from pathlib import Path
import io

import numpy as np
import requests
import torch
from PIL import Image
from huggingface_hub import snapshot_download

from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1
from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, remap_to_uint8
from LLAVA_Biovil.llava.model.builder import load_pretrained_model
from utils import create_chest_xray_transform_for_inference


def load_model_from_huggingface(repo_id):
    # Download the model files and load the RaDialog LoRA checkpoint
    # on top of the LLaVA v1.5 7B base model.
    model_path = snapshot_download(repo_id=repo_id, revision="main", force_download=True)
    model_path = Path(model_path)

    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path,
        model_base='liuhaotian/llava-v1.5-7b',
        model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000",
        load_8bit=False, load_4bit=False)

    return tokenizer, model, image_processor, context_len


if __name__ == '__main__':
    tokenizer, model, image_processor, context_len = load_model_from_huggingface(
        repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
    model.config.tokenizer_padding_side = "left"

    # In the full RaDialog pipeline, the structured findings are predicted by a
    # CheXpert classifier; they are hardcoded here to keep the example self-contained.
    findings = "edema, pleural effusion"

    # Build the report-generation prompt.
    conv = conv_vicuna_v1.copy()
    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
    print("USER: ", REPORT_GEN_PROMPT)
    conv.append_message("USER", REPORT_GEN_PROMPT)
    conv.append_message("ASSISTANT", None)
    text_input = conv.get_prompt()

    # Download a sample chest X-ray (Open-i collection) and preprocess it.
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma"

    response = requests.get(sample_img_path)
    image = Image.open(io.BytesIO(response.content))
    image = remap_to_uint8(np.array(image))
    image = Image.fromarray(image).convert("L")
    image_tensor = vis_transforms_biovil(image).unsqueeze(0)

    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # Generate the report.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=300,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )

    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
    print("ASSISTANT: ", pred)

    # Add the prediction to the conversation and ask a follow-up instruction.
    conv.messages.pop()
    conv.append_message("ASSISTANT", pred)
    conv.append_message("USER", "Translate this report to easy language for a patient to understand.")
    conv.append_message("ASSISTANT", None)
    text_input = conv.get_prompt()
    print("USER: ", "Translate this report to easy language for a patient to understand.")

    # Generate the easy-language version of the report.
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=300,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )

    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
    print("ASSISTANT: ", pred)
```
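
RaDialog is conversational, so further dialog turns follow the same append-and-generate pattern. A minimal sketch continuing the session above (the question text is illustrative; it reuses `conv`, `pred`, `tokenizer`, `model`, `image_tensor`, and `stopping_criteria` from the script):

```python
# Continue the dialog: commit the last answer to the conversation history,
# then ask a free-form question about the image/report (illustrative example).
question = "Is the pleural effusion more likely on the left or the right side?"
conv.messages.pop()                     # drop the open ASSISTANT slot
conv.append_message("ASSISTANT", pred)  # store the previous answer
conv.append_message("USER", question)
conv.append_message("ASSISTANT", None)
text_input = conv.get_prompt()
print("USER: ", question)

input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=False,
        use_cache=True,
        max_new_tokens=300,
        stopping_criteria=[stopping_criteria],
        pad_token_id=tokenizer.pad_token_id
    )

answer = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
print("ASSISTANT: ", answer)
```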

## ✏️ Citation

```bibtex
@article{pellegrini2023radialog,
  title={RaDialog: A Large Vision-Language Model for Radiology Report Generation and Conversational Assistance},
  author={Pellegrini, Chantal and {\"O}zsoy, Ege and Busam, Benjamin and Navab, Nassir and Keicher, Matthias},
  journal={arXiv preprint arXiv:2311.18681},
  year={2023}
}
```