# MRI-Image / app.py
# Uploaded by TharunSiva — "application file" (commit 8413e92, verified)
import gradio as gr
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.backend as K
from keras.preprocessing import image
from ResUNet import *
from eff import *
from vit import *
# Preprocessing shared by the classifiers: convert the input image to a
# tensor (torchvision ToTensor) and resize it to 224x224.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
    ]
)

# Example inputs shown under the classification tab.
examples1 = [f"examples/Eff_ViT/Classification_{idx}.jpg" for idx in range(4)]
def _logits_to_probs(logits, labels):
    """Softmax row 0 of *logits* over dim 1 and map each probability to its label."""
    probs = torch.softmax(logits, dim=1)[0]
    return {label: float(p) for label, p in zip(labels, probs)}


def classification(image):
    """Classify an MRI image and return per-class probabilities for each tab label.

    Args:
        image: Value of the ``gr.Image`` input (a numpy image under Gradio's
            defaults), or ``None`` when the user submits without an image.

    Returns:
        A 3-tuple of ``{class_name: probability}`` dicts for the
        "EfficientNet-B3", "EfficientNet-V2" and "ViT" ``gr.Label`` outputs,
        or ``(None, None, None)`` (which clears the labels) when *image* is
        ``None``.
    """
    if image is None:
        return None, None, None

    # Add a batch dimension and move the tensor to the models' device.
    input_batch = transform(image).unsqueeze(0).to(CFG.DEVICE)
    with torch.no_grad():
        # TODO(review): the "EfficientNet-V2" output is produced by the same
        # `efficientnet_model` as EfficientNet-B3, so both labels always show
        # the same scores. Plug the V2 model in here once `eff` exports it.
        # Until then the B3 logits are computed once and reused rather than
        # running an identical forward pass twice.
        eff_logits = efficientnet_model(input_batch)
        vit_logits = vit_model(input_batch)

    eff_probs = _logits_to_probs(eff_logits, class_names)
    return eff_probs, dict(eff_probs), _logits_to_probs(vit_logits, class_names)
# Classification tab: one uploaded image in, one probability label per model out.
classify = gr.Interface(
    fn=classification,
    inputs=[gr.Image(label="Image")],
    outputs=[
        gr.Label(num_top_classes=3, label=title)
        for title in ("EfficientNet-B3", "EfficientNet-V2", "ViT")
    ],
    examples=examples1,
    cache_examples=True,
)
# ---------------------------------------------------------
# Segmentation model: `load_model` comes from the ResUNet module; its trained
# weights are restored from the .hdf5 checkpoint shipped with the app.
seg_model = load_model()
seg_model.load_weights("ResUNet-segModel-weights.hdf5")

# Example inputs for the detection tab: examples/ResUNet/0.jpg ... 4.jpg.
examples2 = list(map("examples/ResUNet/{}.jpg".format, range(5)))
def _prepare_seg_input(img):
    """Scale to [0, 1], resize to 256x256 and standardise into a (1, 256, 256, 3) batch."""
    scaled = cv2.resize(img * 1. / 255., (256, 256))
    batch = np.asarray(scaled, dtype=np.float64).reshape(1, 256, 256, 3)
    batch -= batch.mean()
    std = batch.std()
    # A constant image has zero spread; skip the division instead of
    # feeding NaNs to the model.
    if std > 0:
        batch /= std
    return batch


def _overlay_mask(rgb, mask, color=(0, 255, 150)):
    """Return a copy of *rgb* with every pixel where *mask* is True set to *color*."""
    out = rgb.copy()
    out[mask] = color
    return out


def _save_figure(rgb, path):
    """Draw *rgb* on a fresh axis-less matplotlib figure and save it to *path*."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(rgb)
        ax.axis("off")
        fig.savefig(path)
    finally:
        # pyplot keeps every open figure alive. Drawing onto the implicit
        # current figure (as before) stacked a new image on it each request.
        plt.close(fig)


def detection(img):
    """Segment the tumour region with the ResUNet model and render an overlay.

    Args:
        img: Value of the ``gr.Image`` input. Under Gradio's defaults
            (``type="numpy"``, ``image_mode="RGB"``) this is an RGB array of
            shape (H, W, 3). It is ``None`` when the user submits without an image.

    Returns:
        A ``gr.update`` for the output ``gr.Image``. Normally it points at
        ``plot.png``, which shows the image resized to 256x256 with the
        predicted-mask pixels painted ``(0, 255, 150)``. When *img* is ``None``
        it clears the output instead.
    """
    if img is None:
        return gr.update(value=None, visible=True)

    prediction = seg_model.predict(_prepare_seg_input(img))
    # Single-item batch: take item 0, drop singleton axes, and round the
    # scores (presumably per-pixel probabilities in [0, 1]) to a 0/1 mask.
    mask = np.array(prediction[0]).squeeze().round() == 1

    # Gradio already supplies RGB, so the image is drawn as-is. The former
    # cv2.COLOR_BGR2RGB call swapped red and blue in the displayed picture.
    # NOTE(review): the model input is also RGB here; confirm it matches the
    # channel order used in training (this makes no difference for grey MRIs).
    display = cv2.resize(img, (256, 256))
    image_path = "plot.png"
    _save_figure(_overlay_mask(display, mask), image_path)
    return gr.update(value=image_path, visible=True)
# Detection tab: an uploaded image in, the segmentation overlay out.
detect = gr.Interface(
    detection,
    [gr.Image(label="Image")],
    [gr.Image(label="Output")],
    examples=examples2,
    cache_examples=True,
)
# ##########################################
# def data_viewer(label="Pituitary", count=10):
# results = []
# if(label == "Segmentation"):
# for i in range((count//2)+1):
# results.append(f"Images/{label}/original_image_{i}.png")
# results.append(f"Images/{label}/image_with_mask_{i}.png")
# else:
# for i in range(count):
# results.append(f"Images/{label}/{i}.jpg")
# return results
# view_data = gr.Interface(
# fn = data_viewer,
# inputs = [
# gr.Dropdown(
# ["Glioma", "Meningioma", "Pituitary", "Segmentation"], label="Category"
# ),
# gr.Slider(0, 12, value=4, step=2)
# ],
# outputs = [
# gr.Gallery(columns=2),
# ]
# )
# ##########################
from huggingface_hub import InferenceClient
# Hugging Face Inference client bound to the Mixtral instruct model used by the chatbot.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt(message, history):
    """Build a Mixtral-instruct prompt from prior turns plus the new message.

    Each ``(user, bot)`` pair in *history* becomes
    ``"[INST] user [/INST] bot</s> "``. The whole prompt starts with ``<s>``
    and ends with ``"[INST] message [/INST]"``.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
def generate(
    prompt, history, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a chat reply from the Mixtral endpoint.

    Yields the accumulated reply text after each streamed token, and returns
    the full reply when the stream ends. ``temperature`` is floored at 0.01.
    Sampling is always on, with a fixed seed of 42.
    """
    sampling_params = {
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }
    token_stream = client.text_generation(
        format_prompt(prompt, history),
        **sampling_params,
        stream=True,
        details=True,
        return_full_text=False,
    )
    reply = ""
    for chunk in token_stream:
        reply += chunk.token.text
        yield reply
    return reply
# Chat widget with custom avatars (user first, bot second), copy and like buttons.
mychatbot = gr.Chatbot(
    avatar_images=["Chatbot/user.png", "Chatbot/botm.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

# ChatBot tab: streams replies from `generate`, with canned starter questions.
chatbot = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    examples=[
        "What is Brain Tumor and its types?",
        "What is a tumor's grade? What does this mean?",
        "What are some of the treatment options for Brain Tumor?",
        "What causes brain tumors?",
        "If I have a brain tumor, can I pass it on to my children?",
    ],
)
# Combine the three tabs into a single app.
demo = gr.TabbedInterface([classify, detect, chatbot], ["Classification", "Detection", "ChatBot"])

# Start the server only when this file is run as a script (e.g. `python app.py`),
# not as a side effect of importing the module.
if __name__ == "__main__":
    demo.launch()