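"""Gradio web UI for CT airway/lung segmentation.

Wraps model inference, NIfTI-to-OBJ conversion, and 2D/3D result viewers in a
single Blocks app, intended to run inside a Hugging Face Space.
"""
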
import os

import gradio as gr

from .convert import nifti_to_obj
from .css_style import css
from .inference import run_model
from .logger import flush_logs
from .logger import read_logs
from .logger import setup_logger
from .utils import load_ct_to_numpy
from .utils import load_pred_volume_to_numpy

# setup logging
LOGGER = setup_logger()


class WebUI:
    def __init__(
        self,
        model_name: str = None,
        cwd: str = "/home/user/app/",
        share: int = 1,
    ):
        # global states
        self.images = []
        self.pred_images = []

        # @TODO: This should be dynamically set based on chosen volume size
        self.nb_slider_items = 820

        self.model_name = model_name
        self.cwd = cwd
        self.share = share

        self.filename = None
        self.extension = None

        self.class_name = "airways"  # default
        self.class_names = {
            "airways": "CT_Airways",
            "lungs": "CT_Lungs",
        }
        self.result_names = {
            "airways": "Airways",
            "lungs": "Lungs",
        }
        # define widgets not to be rendered immediately, but later on
        self.slider = gr.Slider(
            minimum=1,
            maximum=self.nb_slider_items,
            value=1,
            step=1,
            label="Which 2D slice to show",
        )
        self.volume_renderer = gr.Model3D(
            clear_color=[0.0, 0.0, 0.0, 0.0],
            label="3D Model",
            show_label=True,
            visible=True,
            elem_id="model-3d",
            camera_position=[90, 180, 768],
            height=512,
        )
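
    # NOTE: the slider and the 3D renderer are created here but only placed in
    # the layout later, via their .render() calls inside run(), so they can be
    # referenced by event handlers that are defined before they appear.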
    def set_class_name(self, value):
        LOGGER.info(f"Changed task to: {value}")
        self.class_name = value

    def combine_ct_and_seg(self, img, pred):
        return (img, [(pred, self.class_name)])

    def upload_file(self, file):
        out = file.name
        LOGGER.info(f"File uploaded: {out}")
        return out
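
    # The full inference pipeline, triggered by the "Run analysis" button:
    # 1) run the selected segmentation model on the uploaded CT,
    # 2) convert the predicted NIfTI segmentation to an OBJ mesh for the 3D viewer,
    # 3) load the CT and the prediction as numpy slices for the 2D viewer.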
    def process(self, mesh_file_name):
        path = mesh_file_name.name
        curr = path.split("/")[-1]
        self.extension = ".".join(curr.split(".")[1:])
        self.filename = (
            curr.split(".")[0] + "-" + self.class_names[self.class_name]
        )
        run_model(
            path,
            model_path=os.path.join(self.cwd, "resources/models/"),
            task=self.class_names[self.class_name],
            name=self.result_names[self.class_name],
            output_filename=self.filename + "." + self.extension,
        )

        LOGGER.info("Converting prediction NIfTI to OBJ...")
        nifti_to_obj(path=self.filename + "." + self.extension)

        LOGGER.info("Loading CT to numpy...")
        self.images = load_ct_to_numpy(path)

        LOGGER.info("Loading prediction volume to numpy...")
        self.pred_images = load_pred_volume_to_numpy(
            self.filename + "." + self.extension
        )

        return "./prediction.obj"
    def download_prediction(self):
        if (self.filename is None) or (self.extension is None):
            LOGGER.error(
                "The prediction is not available or ready to download. "
                "Wait until the result is available in the 3D viewer."
            )
            raise ValueError("Run inference before downloading!")
        return self.filename + "." + self.extension
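
    # Called on slider input: overlays the prediction mask for slice k on the
    # corresponding CT slice as an AnnotatedImage.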
    def get_img_pred_pair(self, k):
        k = int(k)
        out = gr.AnnotatedImage(
            self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
            visible=True,
            elem_id="model-2d",
            color_map={self.class_name: "#ffae00"},
            height=512,
            width=512,
        )
        return out

    def toggle_sidebar(self, state):
        state = not state
        return gr.update(visible=state), state
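
    # Builds and launches the Gradio Blocks app: a collapsible log sidebar, a
    # control row (upload, task dropdown, run/download buttons), a cached
    # example with a usage note, and a results row with the 2D slice viewer
    # and the 3D mesh viewer.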
    def run(self):
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column(visible=True, scale=0.2) as sidebar_left:
                    logs = gr.Textbox(
                        placeholder="\n" * 16,
                        label="Logs",
                        info="Verbose output from inference will be displayed below.",
                        lines=36,
                        max_lines=36,
                        autoscroll=True,
                        elem_id="logs",
                        show_copy_button=True,
                        container=True,
                    )
                    demo.load(read_logs, None, logs, every=1)
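                    # NOTE: demo.load(..., every=1) above polls read_logs once
                    # per second, so the textbox mirrors the log file while
                    # inference is running.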

                with gr.Column():
                    with gr.Row():
                        with gr.Column(scale=1, min_width=150):
                            sidebar_state = gr.State(True)

                            btn_toggle_sidebar = gr.Button(
                                "Toggle Sidebar",
                                elem_id="toggle-button",
                            )
                            btn_toggle_sidebar.click(
                                self.toggle_sidebar,
                                [sidebar_state],
                                [sidebar_left, sidebar_state],
                            )

                            btn_clear_logs = gr.Button(
                                "Clear logs", elem_id="logs-button"
                            )
                            btn_clear_logs.click(flush_logs, [], [])

                        file_output = gr.File(
                            file_count="single",
                            elem_id="upload",
                            scale=3,
                        )
                        file_output.upload(
                            self.upload_file, file_output, file_output
                        )

                        model_selector = gr.Dropdown(
                            list(self.class_names.keys()),
                            label="Task",
                            info="Which structure to segment.",
                            multiselect=False,
                            scale=1,
                        )
                        model_selector.input(
                            fn=lambda x: self.set_class_name(x),
                            inputs=model_selector,
                            outputs=None,
                        )
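                        # NOTE: the dropdown's .input event above keeps
                        # self.class_name in sync, so process() and the 2D
                        # overlay use the selected task.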
                        with gr.Column(scale=1, min_width=150):
                            run_btn = gr.Button(
                                "Run analysis",
                                variant="primary",
                                elem_id="run-button",
                            )
                            run_btn.click(
                                fn=lambda x: self.process(x),
                                inputs=file_output,
                                outputs=self.volume_renderer,
                            )

                            download_btn = gr.DownloadButton(
                                "Download prediction",
                                visible=True,
                                variant="secondary",
                                elem_id="download",
                            )
                            download_btn.click(
                                fn=self.download_prediction,
                                inputs=None,
                                outputs=download_btn,
                            )
                    with gr.Row():
                        gr.Examples(
                            examples=[
                                os.path.join(self.cwd, "test_thorax_CT.nii.gz"),
                            ],
                            inputs=file_output,
                            outputs=file_output,
                            fn=self.upload_file,
                            cache_examples=True,
                        )
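                        # cache_examples=True above pre-runs self.upload_file on
                        # the bundled example CT when the app is built, so
                        # selecting the example loads the cached output instead
                        # of re-running the handler.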
                        gr.Markdown(
                            """
                            **NOTE:** Inference might take several minutes (Airways: ~8 minutes), see logs to the left. \\
                            The segmentation will be available in the 2D and 3D viewers below when finished.
                            """
                        )
                    with gr.Row():
                        with gr.Group():
                            with gr.Column():
                                # create dummy image to be replaced by loaded images
                                t = gr.AnnotatedImage(
                                    visible=True,
                                    elem_id="model-2d",
                                    color_map={self.class_name: "#ffae00"},
                                    # height=512,
                                    # width=512,
                                )

                                self.slider.input(
                                    self.get_img_pred_pair,
                                    self.slider,
                                    t,
                                )
                                self.slider.render()

                        with gr.Group():  # gr.Box():
                            self.volume_renderer.render()
        # sharing app publicly -> share=True:
        # https://gradio.app/sharing-your-app/
        # inference times > 60 seconds -> need queue():
        # https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
        demo.queue().launch(
            server_name="0.0.0.0", server_port=7860, share=self.share
        )
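

# Minimal usage sketch (assumption: this module lives in a package, e.g. `src.gui`,
# and is launched from a small entry script such as app.py; the relative imports
# above mean it cannot be executed directly as a script):
#
#     from src.gui import WebUI
#
#     def main():
#         WebUI(cwd="/home/user/app/", share=1).run()
#
#     if __name__ == "__main__":
#         main()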