Rohith1112 committed
Commit 28f8c56 · verified · 1 Parent(s): 51a883e

Update app.py

Files changed (1): app.py (+31, -12)
app.py CHANGED
@@ -1,11 +1,13 @@
+import os
 import numpy as np
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
+import matplotlib.pyplot as plt
 
 # Model setup
 device = torch.device('cpu')  # Use 'cuda' if GPU is available
-dtype = torch.float32  # Data type for model processing
+dtype = torch.float32
 model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
 proj_out_num = 256
 
@@ -27,53 +29,70 @@ tokenizer = AutoTokenizer.from_pretrained(
 
 # Chat history storage
 chat_history = []
-current_image = None  # To store the uploaded image
+current_image = None
+
+def extract_and_display_images(image_path):
+    npy_data = np.load(image_path)
+    if npy_data.ndim == 4 and npy_data.shape[1] == 32:
+        npy_data = npy_data[0]
+    elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
+        return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
+
+    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
+    for i, ax in enumerate(axes.flat):
+        ax.imshow(npy_data[i], cmap='gray')
+        ax.axis('off')
+
+    image_output = "extracted_images.png"
+    plt.savefig(image_output, bbox_inches='tight')
+    plt.close()
+    return image_output
+
 
 def process_image(question):
     global current_image
     if current_image is None:
         return "Please upload an image first."
 
-    image_np = np.load(current_image)  # Load the stored .npy image
+    image_np = np.load(current_image)
     image_tokens = "<im_patch>" * proj_out_num
     input_txt = image_tokens + question
     input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
 
-    # Prepare image for model
     image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
-
-    # Generate response
     generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
     return generated_texts[0]
 
-# Function to update chat
+
 def chat_interface(question):
     global chat_history
     response = process_image(question)
     chat_history.append((question, response))
     return chat_history
 
-# Function to handle image upload
+
 def upload_image(image):
     global current_image
     current_image = image.name
-    return "Image uploaded successfully!"
+    extracted_image_path = extract_and_display_images(current_image)
+    return "Image uploaded and processed successfully!", extracted_image_path
 
 # Gradio UI
 with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
-    gr.Markdown("# 🏥 Medical Image Analysis Chatbot")
+    gr.Markdown("# 🏥 AI-Powered Medical Image Analysis Chatbot")
     with gr.Row():
         with gr.Column(scale=1, min_width=200):
             chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
         with gr.Column(scale=4):
             uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
             upload_status = gr.Textbox(label="Status", interactive=False)
+            extracted_image = gr.Image(label="Extracted Images")
             question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
             submit_button = gr.Button("Send")
 
-    uploaded_image.upload(upload_image, uploaded_image, upload_status)
+    uploaded_image.upload(upload_image, uploaded_image, [upload_status, extracted_image])
     submit_button.click(chat_interface, question_input, chat_list)
     question_input.submit(chat_interface, question_input, chat_list)
 
-chat_ui.launch()
+chat_ui.launch()
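
A quick way to exercise the new slice-preview path is to save a dummy volume in the layout the added extract_and_display_images function accepts. A minimal sketch, not part of the commit; the file name and random data are placeholders:

import numpy as np

# Hypothetical smoke test: write a random volume in the
# (1, 32, 256, 256) shape that extract_and_display_images accepts.
# Real inputs would be preprocessed CT/MRI volumes, not noise.
volume = np.random.rand(1, 32, 256, 256).astype(np.float32)
np.save("dummy_volume.npy", volume)

# Uploading dummy_volume.npy in the Gradio UI should then show a
# 4x8 grid of the 32 slices plus the "uploaded and processed" status.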