Rohith1112 committed
Commit f94d228 · verified · 1 Parent(s): 287acbd

Update app.py

Files changed (1)
  1. app.py +88 -62
app.py CHANGED
@@ -6,10 +6,8 @@ import gradio as gr
 import matplotlib.pyplot as plt
 
 # Model setup
-device = torch.device('cpu')  # Use 'cuda' if GPU is available
-dtype = torch.float32
+device = torch.device('cpu')  # Use 'cuda' if available
 model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
-proj_out_num = 256
 
 # Load model and tokenizer
 model = AutoModelForCausalLM.from_pretrained(
@@ -18,7 +16,6 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map='cpu',
     trust_remote_code=True
 )
-
 tokenizer = AutoTokenizer.from_pretrained(
     model_name_or_path,
     model_max_length=512,
@@ -27,7 +24,7 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code=True
 )
 
-# Chat history storage
+# Storage
 chat_history = []
 current_image = None
 
@@ -36,86 +33,115 @@ def extract_and_display_images(image_path):
     if npy_data.ndim == 4 and npy_data.shape[1] == 32:
         npy_data = npy_data[0]
     elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
-        return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
+        return "Invalid .npy format. Expected (1, 32, 256, 256) or (32, 256, 256)."
 
     fig, axes = plt.subplots(4, 8, figsize=(12, 6))
     for i, ax in enumerate(axes.flat):
         ax.imshow(npy_data[i], cmap='gray')
         ax.axis('off')
 
-    image_output = "extracted_images.png"
-    plt.savefig(image_output, bbox_inches='tight')
+    output_path = "converted_image_preview.png"
+    plt.savefig(output_path, bbox_inches='tight')
     plt.close()
-    return image_output
+    return output_path
 
-def process_image(question):
+def process_question(question):
     global current_image
     if current_image is None:
         return "Please upload an image first."
 
     image_np = np.load(current_image)
-    image_tokens = "<im_patch>" * proj_out_num
+    image_tokens = "<im_patch>" * 256
     input_txt = image_tokens + question
-    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
+    input_ids = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
 
-    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
-    generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
+    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=torch.float32, device=device)
+    generation = model.generate(image_pt, input_ids, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
     generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
     return generated_texts[0]
 
-def chat_interface(question):
-    global chat_history
-    response = process_image(question)
-    chat_history.append((question, response))
-    return chat_history
-
 def upload_image(image):
     global current_image
     current_image = image.name
-    extracted_image_path = extract_and_display_images(current_image)
-    return "Image uploaded and processed successfully!", extracted_image_path
+    preview_path = extract_and_display_images(current_image)
+    return "Image uploaded successfully!", preview_path
 
-# Gradio UI with Animation and Styling
+def chat_with_model(user_message):
+    global chat_history
+    response = process_question(user_message)
+    chat_history.append((user_message, response))
+    return chat_history
+
+# UI Design
 with gr.Blocks(css="""
-.gr-chatbot-container {
-    transition: all 0.3s ease-in-out;
-    opacity: 0;
-}
-.gr-chatbot-container.show {
-    opacity: 1;
-}
-.gr-chatbot-message {
-    transition: all 0.3s ease-in-out;
-    transform: translateX(-10px);
-}
-.gr-chatbot-message.show {
-    transform: translateX(0);
-}
-.gr-image-container img {
-    transition: transform 0.5s ease;
-}
-.gr-image-container img:hover {
-    transform: scale(1.1);
-}
-.gr-button:hover {
-    background-color: #007bff;
-    transform: scale(1.1);
-    transition: all 0.3s ease-in-out;
-}
-""") as chat_ui:
-    gr.Markdown("ICliniq AI-Powered Medical Image Analysis Workspace")
+body {
+    background: linear-gradient(135deg, #00b4db, #0083b0);
+    font-family: 'Poppins', sans-serif;
+    color: white;
+}
+
+.gr-box {
+    border-radius: 16px;
+    background: rgba(255,255,255,0.1);
+    padding: 20px;
+    backdrop-filter: blur(10px);
+    box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37);
+}
+
+h1 {
+    text-align: center;
+    font-size: 2.5em;
+    margin-bottom: 20px;
+    color: #ffffff;
+}
+
+.gr-chatbot-container {
+    overflow-y: auto;
+    max-height: 500px;
+}
+
+.gr-chatbot-message {
+    margin-bottom: 15px;
+    padding: 10px;
+    border-radius: 10px;
+    background: rgba(0,0,0,0.3);
+    transition: 0.3s;
+}
+
+.gr-chatbot-message:hover {
+    transform: scale(1.02);
+    background: rgba(255,255,255,0.1);
+}
+
+.gr-button {
+    background-color: #ff7e5f;
+    border: none;
+    padding: 10px 20px;
+    border-radius: 20px;
+    color: white;
+    font-weight: bold;
+    transition: 0.3s;
+}
+
+.gr-button:hover {
+    background-color: #feb47b;
+    transform: scale(1.05);
+}
+""") as app:
+    gr.Markdown("# 🚀 AI Powered Medical Image Analysis System 🚀")
+
     with gr.Row():
-        with gr.Column(scale=1, min_width=200):
-            chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
-        with gr.Column(scale=4):
-            uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
+        with gr.Column(scale=1, min_width=250):
+            chat_history_box = gr.Chatbot(value=[], label="🗂 Chat History")
+        with gr.Column(scale=2):
+            uploaded_image = gr.File(label="📤 Upload NPY Image", type="filepath")
             upload_status = gr.Textbox(label="Status", interactive=False)
-            extracted_image = gr.Image(label="Extracted Images")
-            question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
-            submit_button = gr.Button("Send")
-
-    uploaded_image.upload(upload_image, uploaded_image, [upload_status, extracted_image])
-    submit_button.click(chat_interface, question_input, chat_list)
-    question_input.submit(chat_interface, question_input, chat_list)
+            preview_image = gr.Image(label="🖼 Image Preview")
+            user_input = gr.Textbox(label="💬 Ask a question about the image...")
+            send_button = gr.Button("📨 Send")
+
+    uploaded_image.upload(upload_image, uploaded_image, [upload_status, preview_image])
+    send_button.click(chat_with_model, user_input, chat_history_box)
+    user_input.submit(chat_with_model, user_input, chat_history_box)
 
-chat_ui.launch()
+app.launch()
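Note: the new process_question inlines the literal 256 where the old code read the proj_out_num constant; either way, one "<im_patch>" placeholder per projected image token is prepended to the question text. A minimal sketch of that prompt construction (the question string here is illustrative, not from the commit):

proj_out_num = 256  # projector output length; the new code inlines this as a literal
question = "What abnormality is visible in this scan?"  # illustrative only
input_txt = "<im_patch>" * proj_out_num + question
# Tokenizing input_txt yields the input_ids that app.py passes to
# model.generate() alongside the (1, 32, 256, 256) image tensor.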
 
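For a quick local check of the upload path, a synthetic volume matching the shape the app validates can be saved and previewed. A minimal sketch, assuming the functions above are importable; the demo.npy filename is illustrative:

import numpy as np

# Synthetic 32-slice volume in the accepted (32, 256, 256) layout;
# a (1, 32, 256, 256) array would also pass the ndim == 4 branch.
volume = np.random.rand(32, 256, 256).astype(np.float32)
np.save("demo.npy", volume)

# extract_and_display_images("demo.npy") should then return the path of the
# 4x8 slice-preview PNG rather than the invalid-format message.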