import io
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__)))

import blended_tiling
import numpy
import onnxruntime
import streamlit
import torch
import torch.cuda
from PIL import Image
from streamlit.runtime.uploaded_file_manager import UploadedFile
from streamlit_image_comparison import image_comparison
from torchvision.transforms import functional as TVTF

from tools import image_tools

# * Cached/loaded model
onnx_session = None  # type: onnxruntime.InferenceSession

# * Streamlit UI / Config
streamlit.set_page_config(page_title="🐲 PXDN Line Extractor v1", layout="wide")
streamlit.title("🐲 PXDN Line Extractor v1")

# * Streamlit Containers / Base Layout
# Row 1
ui_section_status = streamlit.container()
# Row 2
ui_col1, ui_col2 = streamlit.columns(2, gap="medium")
streamlit.html("<hr>")  # Row divider; a plain <hr> is assumed here
") # Row 3 ui_section_compare = streamlit.container() # * Streamlit Session # Nothing yet with ui_section_status: # Forward declared UI elements ui_status_text = streamlit.empty() ui_progress_bar = streamlit.empty() with ui_col1: # Input Area streamlit.markdown("### Input Image") ui_image_input = streamlit.file_uploader("Upload an image", key="fileupload_image", type=[".png", ".jpg", ".jpeg", ".webp"]) # type: UploadedFile with ui_col2: # Output Area streamlit.markdown("### Output Image") # Preallocate image spot and download button ui_image_output = streamlit.empty() ui_image_download = streamlit.empty() def fetch_model_to_cache(huggingface_repo: str, file_path: str, access_token: str) -> str: import huggingface_hub return huggingface_hub.hf_hub_download(huggingface_repo, file_path, token=access_token) def bootstrap_model(): global onnx_session if onnx_session is None: # Environment-level configuration huggingface_repo = os.getenv("HF_REPO_NAME", "") file_path = os.getenv("HF_FILE_PATH", "") access_token = os.getenv("HF_TOKEN", "") allow_cuda = os.getenv("ALLOW_CUDA", "false").lower() in {'true', 'yes', '1', 'y'} model_file_path = fetch_model_to_cache(huggingface_repo, file_path, access_token) # * Enable CUDA if available and allowed model_providers = ['CPUExecutionProvider'] if torch.cuda.is_available() and allow_cuda: model_providers.insert(0, 'CUDAExecutionProvider') onnx_session = onnxruntime.InferenceSession(model_file_path, sess_options=None, providers=model_providers) def evaluate_tiled(image_pt: torch.Tensor, tile_size: int = 128, batch_size: int = 1) -> Image.Image: image_pt_orig = image_pt orig_h, orig_w = image_pt_orig.shape[1], image_pt_orig.shape[2] # ? Padding image_pt_padded, place_x, place_y = image_tools.pad_to_divisible(image_pt_orig, tile_size) _, im_h_padded, im_w_padded = image_pt_padded.shape # ? Tiling image_tiler = blended_tiling.TilingModule(tile_size=tile_size, tile_overlap=[0.18, 0.18], base_size=(im_w_padded, im_h_padded)).eval() # * Add batch dim for the tiler which expects (1, C, H, W) image_tiles = image_tiler.split_into_tiles(image_pt_padded.unsqueeze(0)) # ? Pull the input and output names from the model so we're not hardcoding them. onnx_session.get_modelmeta() input_name = onnx_session.get_inputs()[0].name output_name = onnx_session.get_outputs()[0].name # ? Inference ================================================================================================== complete_tiles = [] max_evals = image_tiles.size(0) // batch_size image_tiles = image_tiles.numpy() ui_status_text.markdown("### Processing...") active_progress = ui_progress_bar.progress(0, "Progress") for i in range(max_evals): tile_batch = image_tiles[i * batch_size:(i + 1) * batch_size] if len(tile_batch) == 0: break pct_complete = round((i + 1) / max_evals, 2) active_progress.progress(pct_complete) eval_output = onnx_session.run([output_name], {input_name: tile_batch}) output_batch = eval_output[0] complete_tiles.extend(output_batch) # ? /Inference ui_status_text.empty() ui_progress_bar.empty() # ? Rehydrate the tiles into a full image. complete_tiles_tensor = torch.from_numpy(numpy.stack(complete_tiles)) complete_image = image_tiler.rebuild_with_masks(complete_tiles_tensor) # ? Unpad the image, a simple crop. if place_x > 0 or place_y > 0: complete_image = complete_image[:, :, place_y:place_y + orig_h, place_x:place_x + orig_w] # ? Clamp and convert to PIL. 
    complete_image = complete_image.squeeze(0)
    complete_image = complete_image.clamp(0, 1.0)
    final_image_pil = TVTF.to_pil_image(complete_image)

    return final_image_pil


def streamlit_to_pil_image(streamlit_file: UploadedFile) -> Image.Image:
    image = Image.open(io.BytesIO(streamlit_file.read()))
    return image


def pil_to_buffered_png(image: Image.Image) -> io.BytesIO:
    buffer = io.BytesIO()
    image.save(buffer, format="PNG", compress_level=3)
    buffer.seek(0)
    return buffer


# ! Image Inference
if ui_image_input is not None and ui_image_input.name is not None:
    bootstrap_model()

    ui_status_text.empty()
    ui_progress_bar.empty()

    onnx_session.get_modelmeta()
    onnx_input_metadata = onnx_session.get_inputs()[0]
    b, c, h, w = onnx_input_metadata.shape
    target_batch_size = b
    # This is always square; if H and W are different for the ONNX input you screwed up, so I don't want to hear it.
    target_tile_size = h

    input_image = streamlit_to_pil_image(ui_image_input)
    loaded_image_pt = image_tools.prepare_image_for_inference(input_image)

    finished_image = evaluate_tiled(loaded_image_pt, tile_size=target_tile_size, batch_size=target_batch_size)

    with ui_col2:
        ui_image_output.image(finished_image, use_container_width=True, caption="Output Image")

        complete_file_name = f"{ui_image_input.name.rsplit('.', 1)[0]}_output.png"

        @streamlit.fragment
        def download_button():
            # ui_image_download.download_button("Download Image", image_to_bytesio(finished_image), complete_file_name, type="primary", on_click=lambda: setattr(streamlit.session_state, 'download_click', True))
            streamlit.download_button("Download Image", pil_to_buffered_png(finished_image), complete_file_name, type="primary")

        download_button()

    with ui_section_compare:
        image_comparison(img1=input_image, img2=finished_image, make_responsive=True, label1="Input Image", label2="Output Image", width=1024)
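
# * Runtime configuration, as read by bootstrap_model() above (the entrypoint filename below is an
# * assumption for illustration, not something defined in this file):
#     HF_REPO_NAME / HF_FILE_PATH / HF_TOKEN  -> which ONNX model file is pulled from the Hugging Face Hub
#     ALLOW_CUDA ("true"/"yes"/"1"/"y")       -> opt in to CUDAExecutionProvider when torch.cuda.is_available()
# Example launch:
#     HF_REPO_NAME=... HF_FILE_PATH=... HF_TOKEN=... ALLOW_CUDA=true streamlit run app.py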