import io
import os
import sys
# Make the app's own directory importable (for the local `tools` package), regardless of the working directory.
sys.path.append(os.path.dirname(__file__))
import blended_tiling
import numpy
import onnxruntime
import streamlit
import torch
import torch.cuda
from PIL import Image
from streamlit.runtime.uploaded_file_manager import UploadedFile
from streamlit_image_comparison import image_comparison
from torchvision.transforms import functional as TVTF
from tools import image_tools
# * Cached/loaded model
onnx_session = None  # type: onnxruntime.InferenceSession | None
# * Streamlit UI / Config
streamlit.set_page_config(page_title="🐲 PXDN Line Extractor v1", layout="wide")
streamlit.title("🐲 PXDN Line Extractor v1")
# * Streamlit Containers / Base Layout
# Row 1
ui_section_status = streamlit.container()
# Row 2
ui_col1, ui_col2 = streamlit.columns(2, gap="medium")
streamlit.html("
")
# Row 3
ui_section_compare = streamlit.container()
# * Streamlit Session
# Nothing yet
with ui_section_status:
# Forward declared UI elements
ui_status_text = streamlit.empty()
ui_progress_bar = streamlit.empty()
with ui_col1:
# Input Area
streamlit.markdown("### Input Image")
    ui_image_input = streamlit.file_uploader("Upload an image", key="fileupload_image", type=[".png", ".jpg", ".jpeg", ".webp"])  # type: UploadedFile | None
with ui_col2:
# Output Area
streamlit.markdown("### Output Image")
# Preallocate image spot and download button
ui_image_output = streamlit.empty()
ui_image_download = streamlit.empty()
def fetch_model_to_cache(huggingface_repo: str, file_path: str, access_token: str) -> str:
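    """Download ``file_path`` from ``huggingface_repo`` via the Hugging Face Hub cache and return the local path."""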
import huggingface_hub
return huggingface_hub.hf_hub_download(huggingface_repo, file_path, token=access_token)
def bootstrap_model():
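    """Create the global ONNX Runtime session on first use.

    The Hugging Face repo, file path, token, and CUDA opt-in are read from the
    HF_REPO_NAME, HF_FILE_PATH, HF_TOKEN, and ALLOW_CUDA environment variables.
    """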
global onnx_session
if onnx_session is None:
# Environment-level configuration
huggingface_repo = os.getenv("HF_REPO_NAME", "")
file_path = os.getenv("HF_FILE_PATH", "")
access_token = os.getenv("HF_TOKEN", "")
allow_cuda = os.getenv("ALLOW_CUDA", "false").lower() in {'true', 'yes', '1', 'y'}
model_file_path = fetch_model_to_cache(huggingface_repo, file_path, access_token)
# * Enable CUDA if available and allowed
model_providers = ['CPUExecutionProvider']
if torch.cuda.is_available() and allow_cuda:
model_providers.insert(0, 'CUDAExecutionProvider')
onnx_session = onnxruntime.InferenceSession(model_file_path, sess_options=None, providers=model_providers)
def evaluate_tiled(image_pt: torch.Tensor, tile_size: int = 128, batch_size: int = 1) -> Image.Image:
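    """Run the ONNX model over ``image_pt`` (C, H, W) tile by tile and return the merged result as a PIL image.

    The tensor is padded to a multiple of ``tile_size``, split into overlapping tiles,
    evaluated in batches of ``batch_size``, blended back together, and cropped to the
    original size.
    """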
image_pt_orig = image_pt
orig_h, orig_w = image_pt_orig.shape[1], image_pt_orig.shape[2]
# ? Padding
image_pt_padded, place_x, place_y = image_tools.pad_to_divisible(image_pt_orig, tile_size)
_, im_h_padded, im_w_padded = image_pt_padded.shape
# ? Tiling
image_tiler = blended_tiling.TilingModule(tile_size=tile_size, tile_overlap=[0.18, 0.18], base_size=(im_w_padded, im_h_padded)).eval()
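    # The 18% overlap between neighbouring tiles gives the tiler material to blend away seams when the image is rebuilt.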
# * Add batch dim for the tiler which expects (1, C, H, W)
image_tiles = image_tiler.split_into_tiles(image_pt_padded.unsqueeze(0))
# ? Pull the input and output names from the model so we're not hardcoding them.
input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name
# ? Inference ==================================================================================================
complete_tiles = []
    # Ceiling division so a trailing partial batch of tiles is not silently dropped.
    max_evals = (image_tiles.size(0) + batch_size - 1) // batch_size
image_tiles = image_tiles.numpy()
ui_status_text.markdown("### Processing...")
active_progress = ui_progress_bar.progress(0, "Progress")
for i in range(max_evals):
tile_batch = image_tiles[i * batch_size:(i + 1) * batch_size]
if len(tile_batch) == 0:
break
pct_complete = round((i + 1) / max_evals, 2)
active_progress.progress(pct_complete)
eval_output = onnx_session.run([output_name], {input_name: tile_batch})
output_batch = eval_output[0]
complete_tiles.extend(output_batch)
# ? /Inference
ui_status_text.empty()
ui_progress_bar.empty()
# ? Rehydrate the tiles into a full image.
complete_tiles_tensor = torch.from_numpy(numpy.stack(complete_tiles))
complete_image = image_tiler.rebuild_with_masks(complete_tiles_tensor)
# ? Unpad the image, a simple crop.
    # The crop is a no-op when no padding was added, so apply it unconditionally.
    complete_image = complete_image[:, :, place_y:place_y + orig_h, place_x:place_x + orig_w]
# ? Clamp and convert to PIL.
complete_image = complete_image.squeeze(0)
complete_image = complete_image.clamp(0, 1.0)
final_image_pil = TVTF.to_pil_image(complete_image)
return final_image_pil
def streamlit_to_pil_image(streamlit_file: UploadedFile) -> Image.Image:
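    """Decode an uploaded Streamlit file into a PIL image."""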
image = Image.open(io.BytesIO(streamlit_file.read()))
return image
def pil_to_buffered_png(image: Image.Image) -> io.BytesIO:
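    """Encode a PIL image as a PNG into an in-memory buffer, rewound for reading."""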
buffer = io.BytesIO()
    image.save(buffer, format="PNG", compress_level=3)
buffer.seek(0)
return buffer
# ! Image Inference
if ui_image_input is not None and ui_image_input.name is not None:
bootstrap_model()
ui_status_text.empty()
ui_progress_bar.empty()
onnx_input_metadata = onnx_session.get_inputs()[0]
    b, c, h, w = onnx_input_metadata.shape
    # Symbolic (dynamic) dimensions come back as names or None rather than ints; fall back to a batch of 1 and the default 128px tile in that case.
    target_batch_size = b if isinstance(b, int) else 1
    # Tiles are square, so the model's input height doubles as the tile size (H and W are expected to be equal).
    target_tile_size = h if isinstance(h, int) else 128
input_image = streamlit_to_pil_image(ui_image_input)
loaded_image_pt = image_tools.prepare_image_for_inference(input_image)
finished_image = evaluate_tiled(loaded_image_pt, tile_size=target_tile_size, batch_size=target_batch_size)
with ui_col2:
ui_image_output.image(finished_image, use_container_width=True, caption="Output Image")
complete_file_name = f"{ui_image_input.name.rsplit('.', 1)[0]}_output.png"
@streamlit.fragment
def download_button():
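        # Fragments rerun in isolation, so clicking the download button does not rerun the whole script.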
streamlit.download_button("Download Image", pil_to_buffered_png(finished_image), complete_file_name, type="primary")
download_button()
with ui_section_compare:
image_comparison(img1=input_image, img2=finished_image, make_responsive=True, label1="Input Image", label2="Output Image", width=1024)