Rohith1112 committed
Commit 7e44175 · verified · 1 Parent(s): e19a001
Files changed (1)
  1. app.py +51 -42
app.py CHANGED
@@ -3,68 +3,77 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
 
-# Set device & model details
-device = torch.device('cpu')
-dtype = torch.float32
+# Model setup
+device = torch.device('cpu')  # Use 'cuda' if GPU is available
+dtype = torch.float32  # Data type for model processing
 model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
-proj_out_num = 256  # Number of projection outputs required
+proj_out_num = 256
 
-# Load model & tokenizer
+# Load model and tokenizer
 model = AutoModelForCausalLM.from_pretrained(
-    model_name_or_path, torch_dtype=dtype, device_map='cpu', trust_remote_code=True
+    model_name_or_path,
+    torch_dtype=torch.float32,
+    device_map='cpu',
+    trust_remote_code=True
 )
 
 tokenizer = AutoTokenizer.from_pretrained(
-    model_name_or_path, model_max_length=512, padding_side="right", use_fast=False, trust_remote_code=True
+    model_name_or_path,
+    model_max_length=512,
+    padding_side="right",
+    use_fast=False,
+    trust_remote_code=True
 )
 
-# Image placeholder (to maintain session context)
-uploaded_image = None
+# Chat history storage
+chat_history = []
+current_image = None  # To store the uploaded image
 
-def process_image(question, history):
-    global uploaded_image
-    if uploaded_image is None:
-        return "⚠️ Please upload an image first!"
+def process_image(question):
+    global current_image
+    if current_image is None:
+        return "Please upload an image first."
 
-    # Load the .npy image
-    image_np = np.load(uploaded_image)
+    image_np = np.load(current_image)  # Load the stored .npy image
     image_tokens = "<im_patch>" * proj_out_num
     input_txt = image_tokens + question
     input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
 
-    # Convert image to tensor
+    # Prepare image for model
     image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
 
     # Generate response
     generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
     generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
-
     return generated_texts[0]
 
+# Function to update chat
+def chat_interface(question):
+    global chat_history
+    response = process_image(question)
+    chat_history.append((question, response))
+    return chat_history
+
+# Function to handle image upload
 def upload_image(image):
-    """ Stores the uploaded image path to be used in chat """
-    global uploaded_image
-    uploaded_image = image.name
-    return f"✅ Image uploaded successfully: {image.name}"
+    global current_image
+    current_image = image.name
+    return "Image uploaded successfully!"
 
-# Chat Interface with File Upload
-with gr.Blocks(theme="soft") as chat_ui:
+# Gradio UI
+with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
+    gr.Markdown("# 🏥 Medical Image Analysis Chatbot")
     with gr.Row():
-        with gr.Column(scale=2):
-            gr.Markdown("# 🏥 Medical Image Chatbot")
-            uploaded_file = gr.File(label="Upload .npy Image", type="filepath")
-            upload_button = gr.Button("Upload")
-            status = gr.Markdown("")
-            chat = gr.Chatbot(height=400)
-
-        with gr.Column(scale=3):
-            input_box = gr.Textbox(placeholder="Ask something about the image...")
-            send_button = gr.Button("Send ✉️")
-
-    # Handle image upload
-    upload_button.click(upload_image, inputs=[uploaded_file], outputs=[status])
-
-    # Handle chat interaction
-    send_button.click(process_image, inputs=[input_box, chat], outputs=[chat])
+        with gr.Column(scale=1, min_width=200):
+            chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
+        with gr.Column(scale=4):
+            uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
+            upload_status = gr.Textbox(label="Status", interactive=False)
+            question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
+            submit_button = gr.Button("Send")
+
+    uploaded_image.upload(upload_image, uploaded_image, upload_status)
+    submit_button.click(chat_interface, question_input, chat_list)
+    question_input.submit(chat_interface, question_input, chat_list)
 
 chat_ui.launch()
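
The app expects the uploaded file to be a NumPy .npy volume that np.load can read and torch.from_numpy(...).unsqueeze(0) can batch for model.generate. A minimal local-testing sketch follows, assuming the (1, 32, 256, 256) channel/depth/height/width layout with values in [0, 1] used in the M3D-LaMed examples; make_dummy_volume and the file name are hypothetical helpers, not part of this commit, so verify the expected shape against the model card before relying on it.

    import numpy as np

    def make_dummy_volume(path="dummy_ct.npy", shape=(1, 32, 256, 256)):
        """Write a random float32 volume in [0, 1] to exercise the upload path."""
        volume = np.random.rand(*shape).astype(np.float32)  # assumed layout, see note above
        np.save(path, volume)
        return path

    if __name__ == "__main__":
        print(make_dummy_volume())  # upload the printed .npy path in the Gradio UI

A plain .npy file is used here because 3D medical volumes do not fit standard 2D image upload widgets; storing the raw array lets process_image hand the data to the model without an extra decoding step.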