Rohith1112 commited on
Commit
4836e10
·
verified ·
1 Parent(s): 9c21d4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -16
app.py CHANGED
@@ -2,6 +2,9 @@ import numpy as np
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import gradio as gr
 
 
 
5
 
6
  # Model setup
7
  device = torch.device("cpu") # Use 'cuda' if GPU is available
@@ -29,29 +32,37 @@ tokenizer = AutoTokenizer.from_pretrained(
29
  chat_history = []
30
  current_image = None
31
 
 
 
 
 
 
 
 
 
32
  def process_image(question):
33
  global current_image
34
  if current_image is None:
35
- return "Please upload an image first."
36
 
37
- image_np = np.load(current_image) # Load the stored .npy image
38
  image_tokens = "<im_patch>" * proj_out_num
39
  input_txt = image_tokens + question
40
  input_id = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device=device)
41
 
42
- # Prepare image for model
43
- image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
44
-
45
  # Generate response
46
  generation = model.generate(input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
47
  generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
48
- return generated_texts[0]
49
 
50
  # Function to update chat
51
  def chat_interface(question):
52
  global chat_history
53
- response = process_image(question)
54
- chat_history.append((question, response))
 
 
 
55
  return chat_history
56
 
57
  # Function to handle image upload
@@ -63,14 +74,11 @@ def upload_image(image):
63
  # Gradio UI
64
  with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
65
  gr.Markdown("# 🏥 Medical Image Analysis Chatbot")
66
- with gr.Row():
67
- with gr.Column(scale=1, min_width=200):
68
- chat_list = gr.Chatbot(label="Chat History", elem_id="chat-history")
69
- with gr.Column(scale=4):
70
- uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
71
- upload_status = gr.Textbox(label="Status", interactive=False)
72
- question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
73
- submit_button = gr.Button("Send")
74
 
75
  uploaded_image.upload(upload_image, uploaded_image, upload_status)
76
  submit_button.click(chat_interface, question_input, chat_list)
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import io
8
 
9
  # Model setup
10
  device = torch.device("cpu") # Use 'cuda' if GPU is available
 
32
  chat_history = []
33
  current_image = None
34
 
35
+ # Convert .npy to JPEG
36
+ def npy_to_jpeg(npy_file):
37
+ image_np = np.load(npy_file)
38
+ image = Image.fromarray((image_np * 255).astype(np.uint8)) # Normalize and convert to uint8
39
+ img_bytes = io.BytesIO()
40
+ image.save(img_bytes, format='JPEG')
41
+ return img_bytes.getvalue()
42
+
43
  def process_image(question):
44
  global current_image
45
  if current_image is None:
46
+ return "Please upload an image first.", None
47
 
48
+ image_bytes = npy_to_jpeg(current_image) # Convert image to JPEG
49
  image_tokens = "<im_patch>" * proj_out_num
50
  input_txt = image_tokens + question
51
  input_id = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device=device)
52
 
 
 
 
53
  # Generate response
54
  generation = model.generate(input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
55
  generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
56
+ return generated_texts[0], image_bytes
57
 
58
  # Function to update chat
59
  def chat_interface(question):
60
  global chat_history
61
+ response, image_bytes = process_image(question)
62
+ if image_bytes:
63
+ chat_history.append((question, response, image_bytes))
64
+ else:
65
+ chat_history.append((question, response, None))
66
  return chat_history
67
 
68
  # Function to handle image upload
 
74
  # Gradio UI
75
  with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
76
  gr.Markdown("# 🏥 Medical Image Analysis Chatbot")
77
+ chat_list = gr.Chatbot(label="Chat History", elem_id="chat-history")
78
+ uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
79
+ upload_status = gr.Textbox(label="Status", interactive=False)
80
+ question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
81
+ submit_button = gr.Button("Send")
 
 
 
82
 
83
  uploaded_image.upload(upload_image, uploaded_image, upload_status)
84
  submit_button.click(chat_interface, question_input, chat_list)