Rohith1112 committed
Commit 51a883e · verified · 1 Parent(s): 4836e10
Files changed (1)
  1. app.py +23 -31
app.py CHANGED
@@ -2,14 +2,11 @@ import numpy as np
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
-import matplotlib.pyplot as plt
-from PIL import Image
-import io
 
 # Model setup
-device = torch.device("cpu")  # Use 'cuda' if GPU is available
-dtype = torch.float32
-model_name_or_path = "GoodBaiBai88/M3D-LaMed-Phi-3-4B"
+device = torch.device('cpu')  # Use 'cuda' if GPU is available
+dtype = torch.float32  # Data type for model processing
+model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
 proj_out_num = 256
 
 # Load model and tokenizer
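Note: the diff elides the actual loading step; only the `tokenizer = AutoTokenizer.from_pretrained(` context line is visible in the next hunk header. For orientation, a minimal sketch of that step following the usage example on the GoodBaiBai88/M3D-LaMed-Phi-3-4B model card; the exact kwargs in this app.py are not shown in the diff, so treat them as assumptions:

```python
# Sketch only: follows the M3D-LaMed model card; app.py's real kwargs are not
# visible in this diff. trust_remote_code is needed for the custom model class.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,   # assumption: reuse the float32 dtype defined above
    device_map='cpu',    # assumption: CPU-only Space, matching `device`
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,
)
```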
@@ -30,39 +27,31 @@ tokenizer = AutoTokenizer.from_pretrained(
 
 # Chat history storage
 chat_history = []
-current_image = None
-
-# Convert .npy to JPEG
-def npy_to_jpeg(npy_file):
-    image_np = np.load(npy_file)
-    image = Image.fromarray((image_np * 255).astype(np.uint8))  # Normalize and convert to uint8
-    img_bytes = io.BytesIO()
-    image.save(img_bytes, format='JPEG')
-    return img_bytes.getvalue()
+current_image = None  # To store the uploaded image
 
 def process_image(question):
     global current_image
     if current_image is None:
-        return "Please upload an image first.", None
+        return "Please upload an image first."
 
-    image_bytes = npy_to_jpeg(current_image)  # Convert image to JPEG
+    image_np = np.load(current_image)  # Load the stored .npy image
     image_tokens = "<im_patch>" * proj_out_num
     input_txt = image_tokens + question
-    input_id = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device=device)
+    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
+
+    # Prepare image for model
+    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
 
     # Generate response
-    generation = model.generate(input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
+    generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
     generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
-    return generated_texts[0], image_bytes
+    return generated_texts[0]
 
 # Function to update chat
 def chat_interface(question):
     global chat_history
-    response, image_bytes = process_image(question)
-    if image_bytes:
-        chat_history.append((question, response, image_bytes))
-    else:
-        chat_history.append((question, response, None))
+    response = process_image(question)
+    chat_history.append((question, response))
     return chat_history
 
 # Function to handle image upload
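The substantive fix in this hunk: the old code converted the volume to a JPEG for display but then called `model.generate(input_id, ...)` with no image input at all, so the model never saw the scan. The new code feeds the volume tensor directly, and M3D-LaMed's custom `generate` takes that image tensor as its first positional argument, ahead of the token ids. A sketch of the input contract this assumes; the (1, 32, 256, 256), [0, 1]-normalized shape comes from the model card's example, not from anything visible in this repo:

```python
# Sketch: the .npy contract the new process_image assumes (per the M3D-LaMed
# model card example; not verified against this Space's actual uploads).
image_np = np.load("example.npy")                    # expected (1, 32, 256, 256), values in [0, 1]
image_pt = torch.from_numpy(image_np).unsqueeze(0)   # add batch dim -> (1, 1, 32, 256, 256)
image_pt = image_pt.to(dtype=dtype, device=device)
assert image_pt.shape[1:] == (1, 32, 256, 256), "unexpected volume shape"
assert 0.0 <= float(image_np.min()) and float(image_np.max()) <= 1.0, "expected [0, 1] values"
```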
@@ -74,14 +63,17 @@ def upload_image(image):
 # Gradio UI
 with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
     gr.Markdown("# 🏥 Medical Image Analysis Chatbot")
-    chat_list = gr.Chatbot(label="Chat History", elem_id="chat-history")
-    uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
-    upload_status = gr.Textbox(label="Status", interactive=False)
-    question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
-    submit_button = gr.Button("Send")
+    with gr.Row():
+        with gr.Column(scale=1, min_width=200):
+            chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
+        with gr.Column(scale=4):
+            uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
+            upload_status = gr.Textbox(label="Status", interactive=False)
+            question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
+            submit_button = gr.Button("Send")
 
     uploaded_image.upload(upload_image, uploaded_image, upload_status)
     submit_button.click(chat_interface, question_input, chat_list)
     question_input.submit(chat_interface, question_input, chat_list)
 
-chat_ui.launch()
+chat_ui.launch()
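`upload_image` is named in the last hunk header but its body sits outside the diff context. A hypothetical sketch of a handler consistent with how the app wires it up (`gr.File(type="filepath")` passes a path string, and the return value lands in the `upload_status` textbox); the real implementation in app.py may differ:

```python
# Hypothetical: not part of this commit's visible diff.
def upload_image(image):
    global current_image
    try:
        arr = np.load(image)  # sanity-check that the upload is a readable .npy file
    except Exception as exc:
        current_image = None
        return f"Could not read .npy file: {exc}"
    current_image = image  # store the filepath for process_image
    return f"Image uploaded successfully (array shape: {arr.shape})."
```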
 