Rohith1112 committed
Commit 9c21d4a · verified · 1 Parent(s): 7529a73

Update app.py

Files changed (1)
  1. app.py +40 -115
app.py CHANGED
@@ -2,113 +2,55 @@ import numpy as np
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import gradio as gr
- from PIL import Image
- import io

  # Model setup
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Use GPU if available
- dtype = torch.float32  # Adjust based on model requirements
- model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
+ device = torch.device("cpu")  # Use 'cuda' if GPU is available
+ dtype = torch.float32
+ model_name_or_path = "GoodBaiBai88/M3D-LaMed-Phi-3-4B"
  proj_out_num = 256

  # Load model and tokenizer
- try:
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name_or_path,
-         torch_dtype=dtype,
-         device_map=device,
-         trust_remote_code=True
-     )
-     print("Model loaded successfully!")
- except Exception as e:
-     print(f"Error loading model: {e}")
-     raise
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name_or_path,
+     torch_dtype=torch.float32,
+     device_map='cpu',
+     trust_remote_code=True
+ )

- try:
-     tokenizer = AutoTokenizer.from_pretrained(
-         model_name_or_path,
-         model_max_length=512,
-         padding_side="right",
-         use_fast=False,
-         trust_remote_code=True
-     )
-     print("Tokenizer loaded successfully!")
- except Exception as e:
-     print(f"Error loading tokenizer: {e}")
-     raise
+ tokenizer = AutoTokenizer.from_pretrained(
+     model_name_or_path,
+     model_max_length=512,
+     padding_side="right",
+     use_fast=False,
+     trust_remote_code=True
+ )

  # Chat history storage
  chat_history = []
- current_image = None  # To store the uploaded image
+ current_image = None

- # Function to convert .npy to JPEG
- def npy_to_jpeg(image_np):
-     # Handle multi-dimensional .npy files
-     if image_np.ndim == 4:  # If batch dimension is present (e.g., (1, 256, 256, 3))
-         image_np = image_np.squeeze(0)  # Remove batch dimension
-     elif image_np.ndim == 2:  # Grayscale image (e.g., (256, 256))
-         image_np = np.expand_dims(image_np, axis=-1)  # Add channel dimension
-
-     # Normalize and convert to uint8
-     image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min()) * 255
-     image_np = image_np.astype(np.uint8)
-
-     # Convert to PIL Image
-     if image_np.shape[-1] == 1:  # Grayscale
-         image_np = image_np.squeeze()
-         image = Image.fromarray(image_np, mode='L')
-     else:  # RGB
-         image = Image.fromarray(image_np, mode='RGB')
-
-     # Save to bytes
-     buf = io.BytesIO()
-     image.save(buf, format='JPEG')
-     buf.seek(0)
-     return buf
-
- # Function to process image and generate response
  def process_image(question):
      global current_image
      if current_image is None:
-         return "Please upload an image first.", None
-
-     try:
-         # Load the stored .npy image
-         image_np = np.load(current_image)
-         if image_np.shape[-1] != proj_out_num:  # Ensure image matches expected dimensions
-             return f"Invalid image dimensions. Expected {proj_out_num} patches, got {image_np.shape[-1]}.", None
-
-         # Convert .npy to JPEG
-         jpeg_image = npy_to_jpeg(image_np)
-
-         # Prepare image tokens and input text
-         image_tokens = "<im_patch>" * proj_out_num
-         input_txt = image_tokens + question
-         input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device, dtype=torch.long)  # Ensure input_id is LongTensor
-
-         # Prepare image tensor
-         image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
-
-         # Generate response
-         generation = model.generate(
-             inputs=image_pt,
-             input_ids=input_id,
-             max_new_tokens=256,
-             do_sample=True,
-             top_p=0.9,
-             temperature=1.0
-         )
-         generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
-         return generated_texts[0], jpeg_image
-     except Exception as e:
-         return f"Error processing image: {str(e)}", None
+         return "Please upload an image first."
+
+     image_np = np.load(current_image)  # Load the stored .npy image
+     image_tokens = "<im_patch>" * proj_out_num
+     input_txt = image_tokens + question
+     input_id = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device=device)
+
+     # Prepare image for model
+     image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)

+     # Generate response
+     generation = model.generate(input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
+     generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
+     return generated_texts[0]

  # Function to update chat
  def chat_interface(question):
      global chat_history
-     response, jpeg_image = process_image(question)
-     if jpeg_image:
-         chat_history.append((None, (jpeg_image,)))  # Display image in chat
+     response = process_image(question)
      chat_history.append((question, response))
      return chat_history

@@ -118,37 +60,20 @@ def upload_image(image):
      current_image = image.name
      return "Image uploaded successfully!"

- # Function to clear chat history
- def clear_chat():
-     global chat_history, current_image
-     chat_history = []
-     current_image = None
-     return [], "Chat history cleared."
-
  # Gradio UI
  with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
      gr.Markdown("# 🏥 Medical Image Analysis Chatbot")
-
-     # File upload section
-     with gr.Row():
-         upload_button = gr.UploadButton(label="📁 Upload .npy Image", file_types=[".npy"])
-         upload_status = gr.Textbox(label="Status", interactive=False)
-
-     # Chat interface
-     with gr.Row():
-         chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
-
-     # Question input and buttons
      with gr.Row():
-         question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...", lines=2)
-         submit_button = gr.Button("Send")
-         clear_button = gr.Button("Clear Chat")
+         with gr.Column(scale=1, min_width=200):
+             chat_list = gr.Chatbot(label="Chat History", elem_id="chat-history")
+         with gr.Column(scale=4):
+             uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
+             upload_status = gr.Textbox(label="Status", interactive=False)
+             question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
+             submit_button = gr.Button("Send")

-     # Event handlers
-     upload_button.upload(upload_image, upload_button, upload_status)
+     uploaded_image.upload(upload_image, uploaded_image, upload_status)
      submit_button.click(chat_interface, question_input, chat_list)
      question_input.submit(chat_interface, question_input, chat_list)
-     clear_button.click(clear_chat, outputs=[chat_list, upload_status])

- # Launch the app
- chat_ui.launch()
+     chat_ui.launch()
 
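A related detail in the new UI: gr.File(type="filepath") passes the event handler a plain path string, while upload_image still reads image.name, the attribute the older file-object style (and the removed gr.UploadButton) exposed. A hedged sketch of a handler matching the new component; exact callback types vary across Gradio versions, so treat this as an assumption:

    def upload_image(image_path):
        # With gr.File(type="filepath") the callback receives the uploaded
        # file's path as a str, so store it directly rather than image.name.
        global current_image
        current_image = image_path
        return "Image uploaded successfully!"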
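For a quick local smoke test, save a dummy volume first; upload_image only records the file path, and process_image np.load()s it on each question. The shape and value range below are assumptions borrowed from the upstream M3D-LaMed examples (a CT volume of shape (1, 32, 256, 256) normalized to [0, 1]), not something this app enforces:

    import numpy as np

    # Hypothetical test input; adjust the shape and value range to match
    # your data's actual preprocessing.
    volume = np.random.rand(1, 32, 256, 256).astype(np.float32)
    np.save("dummy_ct.npy", volume)

With that file uploaded through the UI, a question submitted in the textbox exercises the full tokenize-and-generate path on CPU.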