Rohith1112 committed (verified)
Commit aff94f6 · 1 Parent(s): 28f8c56
Files changed (1):
  1. app.py +79 -32
app.py CHANGED
@@ -4,9 +4,11 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
 import matplotlib.pyplot as plt
+from datasets import load_dataset
+from evaluate import load  # For evaluation metrics

 # Model setup
-device = torch.device('cpu')  # Use 'cuda' if GPU is available
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Use GPU if available
 dtype = torch.float32
 model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
 proj_out_num = 256
@@ -14,8 +16,8 @@ proj_out_num = 256
 # Load model and tokenizer
 model = AutoModelForCausalLM.from_pretrained(
     model_name_or_path,
-    torch_dtype=torch.float32,
-    device_map='cpu',
+    torch_dtype=dtype,
+    device_map=device.type,
     trust_remote_code=True
 )

@@ -27,43 +29,50 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code=True
 )

+# Load the M3D-Cap dataset
+dataset = load_dataset("GoodBaiBai88/M3D-Cap")
+
 # Chat history storage
 chat_history = []
 current_image = None

 def extract_and_display_images(image_path):
-    npy_data = np.load(image_path)
-    if npy_data.ndim == 4 and npy_data.shape[1] == 32:
-        npy_data = npy_data[0]
-    elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
-        return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
-
-    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
-    for i, ax in enumerate(axes.flat):
-        ax.imshow(npy_data[i], cmap='gray')
-        ax.axis('off')
-
-    image_output = "extracted_images.png"
-    plt.savefig(image_output, bbox_inches='tight')
-    plt.close()
-    return image_output
-
+    try:
+        npy_data = np.load(image_path)
+        if npy_data.ndim == 4 and npy_data.shape[1] == 32:
+            npy_data = npy_data[0]
+        elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
+            return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
+
+        fig, axes = plt.subplots(4, 8, figsize=(12, 6))
+        for i, ax in enumerate(axes.flat):
+            ax.imshow(npy_data[i], cmap='gray')
+            ax.axis('off')
+
+        image_output = "extracted_images.png"
+        plt.savefig(image_output, bbox_inches='tight')
+        plt.close()
+        return image_output
+    except Exception as e:
+        return f"Error processing image: {str(e)}"

 def process_image(question):
     global current_image
     if current_image is None:
         return "Please upload an image first."

-    image_np = np.load(current_image)
-    image_tokens = "<im_patch>" * proj_out_num
-    input_txt = image_tokens + question
-    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
-
-    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
-    generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
-    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
-    return generated_texts[0]
-
+    try:
+        image_np = np.load(current_image)
+        image_tokens = "<im_patch>" * proj_out_num
+        input_txt = image_tokens + question
+        input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
+
+        image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
+        generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
+        generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
+        return generated_texts[0]
+    except Exception as e:
+        return f"Error generating response: {str(e)}"

 def chat_interface(question):
     global chat_history
@@ -71,16 +80,51 @@ def chat_interface(question):
     chat_history.append((question, response))
     return chat_history

-
 def upload_image(image):
     global current_image
     current_image = image.name
     extracted_image_path = extract_and_display_images(current_image)
     return "Image uploaded and processed successfully!", extracted_image_path

+def test_model_with_dataset():
+    # Load evaluation metrics
+    bleu = load("bleu")
+    rouge = load("rouge")
+
+    # Initialize lists to store predictions and references
+    predictions = []
+    references = []
+
+    # Iterate over the dataset
+    for example in dataset['train']:  # Use 'train', 'validation', or 'test' split
+        image_path = example['image']  # Assuming 'image' contains the path to the .npy file
+        question = example['caption']  # Assuming 'caption' contains the question or caption
+
+        # Upload the image
+        upload_image({"name": image_path})
+
+        # Get the model's response
+        response = process_image(question)
+
+        # Store predictions and references
+        predictions.append(response)
+        references.append(question)
+
+        # Print results for debugging
+        print(f"Question: {question}")
+        print(f"Model Response: {response}")
+        print("---")
+
+    # Compute evaluation metrics
+    bleu_score = bleu.compute(predictions=predictions, references=references)
+    rouge_score = rouge.compute(predictions=predictions, references=references)
+
+    print(f"BLEU Score: {bleu_score}")
+    print(f"ROUGE Score: {rouge_score}")
+
 # Gradio UI
 with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
-    gr.Markdown("# 🏥 AI-Powered Medical Image Analysis Chatbot")
+    gr.Markdown("ICliniq AI-Powered Medical Image Analysis Workspace")
     with gr.Row():
         with gr.Column(scale=1, min_width=200):
             chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
@@ -95,4 +139,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
     submit_button.click(chat_interface, question_input, chat_list)
     question_input.submit(chat_interface, question_input, chat_list)

-chat_ui.launch()
+# Uncomment to test the model with the dataset
+# test_model_with_dataset()
+
+chat_ui.launch()
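
Note on the new test_model_with_dataset helper: upload_image reads image.name as an attribute, so passing the plain dict {"name": image_path} would raise an AttributeError. A minimal sketch of a wrapper that satisfies that interface, assuming the same 'image' and 'caption' field names the commit itself assumes for M3D-Cap (run_single_example is a hypothetical helper, not part of the commit):

    from types import SimpleNamespace

    def run_single_example(example):
        # upload_image() accesses .name, so wrap the raw path in an object
        # exposing that attribute instead of passing a dict.
        upload_image(SimpleNamespace(name=example['image']))
        # Ask the model about the uploaded volume using the example's caption.
        return process_image(example['caption'])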