Rohith1112 committed
Commit e19a001 · verified · 1 Parent(s): cecc48c
Files changed (1)
  1. app.py +46 -37
app.py CHANGED
@@ -1,61 +1,70 @@
  import numpy as np
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
- import simple_slice_viewer as ssv
- import SimpleITK as sikt
  import gradio as gr

- device = torch.device('cpu') # Set to 'cuda' if using a GPU
- dtype = torch.float32 # Data type for model processing
-
+ # Set device & model details
+ device = torch.device('cpu')
+ dtype = torch.float32
  model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
- proj_out_num = 256 # Number of projection outputs required for the image
+ proj_out_num = 256 # Number of projection outputs required

- # Load model and tokenizer
+ # Load model & tokenizer
  model = AutoModelForCausalLM.from_pretrained(
-     model_name_or_path,
-     torch_dtype=torch.float32,
-     device_map='cpu',
-     trust_remote_code=True
+     model_name_or_path, torch_dtype=dtype, device_map='cpu', trust_remote_code=True
  )

  tokenizer = AutoTokenizer.from_pretrained(
-     model_name_or_path,
-     model_max_length=512,
-     padding_side="right",
-     use_fast=False,
-     trust_remote_code=True
+     model_name_or_path, model_max_length=512, padding_side="right", use_fast=False, trust_remote_code=True
  )

- def process_image(image_path, question):
-     # Load the image
-     image_np = np.load(image_path) # Load the .npy image
+ # Image placeholder (to maintain session context)
+ uploaded_image = None
+
+ def process_image(question, history):
+     global uploaded_image
+     if uploaded_image is None:
+         return "⚠️ Please upload an image first!"
+
+     # Load the .npy image
+     image_np = np.load(uploaded_image)
      image_tokens = "<im_patch>" * proj_out_num
      input_txt = image_tokens + question
      input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)

-     # Prepare image for model
+     # Convert image to tensor
      image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)

-     # Generate model response
+     # Generate response
      generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
      generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)

      return generated_texts[0]

- # Gradio Interface
- def gradio_interface(image, question):
-     response = process_image(image.name, question)
-     return response
-
- # Gradio App
- gr.Interface(
-     fn=gradio_interface,
-     inputs=[
-         gr.File(label="Upload .npy Image", type="filepath"), # For uploading .npy image
-         gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
-     ],
-     outputs=gr.Textbox(label="Model Response"),
-     title="Medical Image Analysis",
-     description="Upload a .npy image and ask a question to analyze it using the model."
- ).launch()
+ def upload_image(image):
+     """ Stores the uploaded image path to be used in chat """
+     global uploaded_image
+     uploaded_image = image.name
+     return f"✅ Image uploaded successfully: {image.name}"
+
+ # Chat Interface with File Upload
+ with gr.Blocks(theme="soft") as chat_ui:
+     with gr.Row():
+         with gr.Column(scale=2):
+             gr.Markdown("# 🏥 Medical Image Chatbot")
+             uploaded_file = gr.File(label="Upload .npy Image", type="filepath")
+             upload_button = gr.Button("Upload")
+             status = gr.Markdown("")
+             chat = gr.Chatbot(height=400)
+
+         with gr.Column(scale=3):
+             input_box = gr.Textbox(placeholder="Ask something about the image...")
+             send_button = gr.Button("Send ✉️")
+
+     # Handle image upload
+     upload_button.click(upload_image, inputs=[uploaded_file], outputs=[status])
+
+     # Handle chat interaction
+     send_button.click(process_image, inputs=[input_box, chat], outputs=[chat])
+
+ chat_ui.launch()
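Review note: the old version's inline comment ("Set to 'cuda' if using a GPU") hinted at a GPU path that the rewrite hard-codes away. A minimal sketch of an auto-selecting variant, assuming stock PyTorch; the float16-on-GPU choice is an assumption for memory savings, not something this commit specifies:

import torch

# Pick the GPU when one is visible, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Assumption: half precision on GPU to save memory; float32 stays the safe CPU default.
dtype = torch.float16 if device.type == 'cuda' else torch.float32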
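Review note: process_image loads a raw NumPy volume and adds a batch dimension with unsqueeze(0) before calling model.generate. A minimal sketch for producing a synthetic test input is below; the (1, 32, 256, 256) channel/depth/height/width shape and the [0, 1] normalization are assumptions taken from the GoodBaiBai88/M3D-LaMed-Phi-3-4B model card, not pinned down by this commit:

import numpy as np

# Dummy volume to exercise the app without real patient data.
# Shape and value range are assumptions; check the M3D-LaMed model card
# for the exact preprocessing it expects.
volume = np.random.rand(1, 32, 256, 256).astype(np.float32)  # (C, D, H, W) in [0, 1]
np.save("test_ct.npy", volume)

Launching the app and uploading test_ct.npy should then exercise the full generate path end to end.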
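Review note: as committed, the chat wiring has two likely runtime issues. process_image returns a plain string while its output component is gr.Chatbot, which (in tuple-style Gradio versions) expects a list of (user, assistant) pairs, and with type="filepath" the upload handler receives a str, so image.name will raise AttributeError. A minimal patch sketch, assuming a Gradio version with the tuple-style Chatbot format; chat_step is a hypothetical wrapper name introduced here for illustration:

# Hypothetical wrapper: adapts process_image's string return to Chatbot history.
def chat_step(question, history):
    history = history or []
    answer = process_image(question, history)  # still returns a string, as committed
    history.append((question, answer))         # (user, assistant) pair for gr.Chatbot
    return history, ""                         # second value clears the textbox

# With type="filepath", Gradio passes the path itself, not a file object.
def upload_image(path):
    global uploaded_image
    uploaded_image = path
    return f"✅ Image uploaded successfully: {path}"

# Rewired inside the gr.Blocks context:
# send_button.click(chat_step, inputs=[input_box, chat], outputs=[chat, input_box])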