import io
import os
import sys

sys.path.append(os.path.dirname(__file__))

import blended_tiling
import numpy
import onnxruntime
import streamlit
import torch
import torch.cuda
from PIL import Image
from streamlit.runtime.uploaded_file_manager import UploadedFile
from streamlit_image_comparison import image_comparison
from torchvision.transforms import functional as TVTF

from tools import image_tools
# * Cached/loaded model
onnx_session = None  # type: onnxruntime.InferenceSession

# * Streamlit UI / Config
streamlit.set_page_config(page_title="🐲 PXDN Line Extractor v1", layout="wide")
streamlit.title("🐲 PXDN Line Extractor v1")

# * Streamlit Containers / Base Layout
# Row 1
ui_section_status = streamlit.container()
# Row 2
ui_col1, ui_col2 = streamlit.columns(2, gap="medium")
streamlit.html("<hr>")
# Row 3
ui_section_compare = streamlit.container()

# * Streamlit Session
# Nothing yet
with ui_section_status:
    # Forward declared UI elements
    ui_status_text = streamlit.empty()
    ui_progress_bar = streamlit.empty()

with ui_col1:
    # Input Area
    streamlit.markdown("### Input Image")
    ui_image_input = streamlit.file_uploader("Upload an image", key="fileupload_image", type=[".png", ".jpg", ".jpeg", ".webp"])  # type: UploadedFile

with ui_col2:
    # Output Area
    streamlit.markdown("### Output Image")
    # Preallocate image spot and download button
    ui_image_output = streamlit.empty()
    ui_image_download = streamlit.empty()
def fetch_model_to_cache(huggingface_repo: str, file_path: str, access_token: str) -> str:
    import huggingface_hub

    return huggingface_hub.hf_hub_download(huggingface_repo, file_path, token=access_token)
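
# ? hf_hub_download caches the file in the local HF cache and returns the resolved path,
# ? so calling fetch_model_to_cache on every bootstrap is a cheap cache hit, not a re-download.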
def bootstrap_model():
    global onnx_session
    if onnx_session is None:
        # Environment-level configuration
        huggingface_repo = os.getenv("HF_REPO_NAME", "")
        file_path = os.getenv("HF_FILE_PATH", "")
        access_token = os.getenv("HF_TOKEN", "")
        allow_cuda = os.getenv("ALLOW_CUDA", "false").lower() in {'true', 'yes', '1', 'y'}
        model_file_path = fetch_model_to_cache(huggingface_repo, file_path, access_token)
        # * Enable CUDA if available and allowed
        model_providers = ['CPUExecutionProvider']
        if torch.cuda.is_available() and allow_cuda:
            model_providers.insert(0, 'CUDAExecutionProvider')
        onnx_session = onnxruntime.InferenceSession(model_file_path, sess_options=None, providers=model_providers)
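
# ? evaluate_tiled: pad the image so it divides evenly into tiles, split it into
# ? overlapping tiles, run each batch through the ONNX model, blend the results back
# ? together with feathered masks, then crop away the padding.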
def evaluate_tiled(image_pt: torch.Tensor, tile_size: int = 128, batch_size: int = 1) -> Image.Image:
    image_pt_orig = image_pt
    orig_h, orig_w = image_pt_orig.shape[1], image_pt_orig.shape[2]
    # ? Padding
    image_pt_padded, place_x, place_y = image_tools.pad_to_divisible(image_pt_orig, tile_size)
    _, im_h_padded, im_w_padded = image_pt_padded.shape
    # ? Tiling
    image_tiler = blended_tiling.TilingModule(tile_size=tile_size, tile_overlap=[0.18, 0.18], base_size=(im_w_padded, im_h_padded)).eval()
    # * Add batch dim for the tiler which expects (1, C, H, W)
    image_tiles = image_tiler.split_into_tiles(image_pt_padded.unsqueeze(0))
    # ? Pull the input and output names from the model so we're not hardcoding them.
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
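    # ? split_into_tiles yields a (num_tiles, C, tile_h, tile_w) tensor; the loop below
    # ? feeds it to onnxruntime in fixed-size batches as float32 NumPy arrays.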
    # ? Inference ==================================================================================================
    complete_tiles = []
    # Ceiling division so a final partial batch is still evaluated instead of silently dropped
    max_evals = (image_tiles.size(0) + batch_size - 1) // batch_size
    image_tiles = image_tiles.numpy()
    ui_status_text.markdown("### Processing...")
    active_progress = ui_progress_bar.progress(0, "Progress")
    for i in range(max_evals):
        tile_batch = image_tiles[i * batch_size:(i + 1) * batch_size]
        if len(tile_batch) == 0:
            break
        pct_complete = round((i + 1) / max_evals, 2)
        active_progress.progress(pct_complete)
        eval_output = onnx_session.run([output_name], {input_name: tile_batch})
        output_batch = eval_output[0]
        complete_tiles.extend(output_batch)
    # ? /Inference
    ui_status_text.empty()
    ui_progress_bar.empty()
    # ? Rehydrate the tiles into a full image.
    complete_tiles_tensor = torch.from_numpy(numpy.stack(complete_tiles))
    complete_image = image_tiler.rebuild_with_masks(complete_tiles_tensor)
    # ? Unpad the image, a simple crop. Compare against the padded size so right/bottom-only padding is also removed.
    if im_h_padded != orig_h or im_w_padded != orig_w:
        complete_image = complete_image[:, :, place_y:place_y + orig_h, place_x:place_x + orig_w]
    # ? Clamp and convert to PIL.
    complete_image = complete_image.squeeze(0)
    complete_image = complete_image.clamp(0, 1.0)
    final_image_pil = TVTF.to_pil_image(complete_image)
    return final_image_pil
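
# * IO shims: Streamlit's UploadedFile -> PIL, and PIL -> an in-memory PNG for the download button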
def streamlit_to_pil_image(streamlit_file: UploadedFile):
    image = Image.open(io.BytesIO(streamlit_file.read()))
    return image

def pil_to_buffered_png(image: Image.Image) -> io.BytesIO:
    buffer = io.BytesIO()
    # Pillow's PNG writer takes compress_level (0-9); an unrecognized "compression" kwarg is silently ignored
    image.save(buffer, format="PNG", compress_level=3)
    buffer.seek(0)
    return buffer
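
# ? Streamlit reruns this script top-to-bottom on every interaction, so the block below
# ? executes (and re-runs inference) whenever the uploader holds a file.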
# ! Image Inference
if ui_image_input is not None and ui_image_input.name is not None:
    bootstrap_model()
    ui_status_text.empty()
    ui_progress_bar.empty()
    onnx_input_metadata = onnx_session.get_inputs()[0]
    b, c, h, w = onnx_input_metadata.shape
    # Dynamic ONNX dims come through as strings/None rather than ints, so fall back to batches of 1
    target_batch_size = b if isinstance(b, int) else 1
    # This is always square, if H and W are different for ONNX input you screwed up, so I don't want to hear it.
    target_tile_size = h
    input_image = streamlit_to_pil_image(ui_image_input)
    loaded_image_pt = image_tools.prepare_image_for_inference(input_image)
    finished_image = evaluate_tiled(loaded_image_pt, tile_size=target_tile_size, batch_size=target_batch_size)
    with ui_col2:
        ui_image_output.image(finished_image, use_container_width=True, caption="Output Image")
        complete_file_name = f"{ui_image_input.name.rsplit('.', 1)[0]}_output.png"
        # Fill the preallocated placeholder so reruns overwrite the button in place
        ui_image_download.download_button("Download Image", pil_to_buffered_png(finished_image), complete_file_name, type="primary")
    with ui_section_compare:
        image_comparison(img1=input_image, img2=finished_image, make_responsive=True, label1="Input Image", label2="Output Image", width=1024)