Spaces:
Sleeping
Sleeping
Add streamlit app code v1
Browse files- .gitignore +147 -0
- .streamlit/config.toml +3 -0
- README.md +8 -5
- app.py +181 -0
- pycharm_runner.py +9 -0
- requirements.txt +9 -0
- tools/__init__.py +3 -0
- tools/image_tools.py +43 -0
- tools/mafs.py +5 -0
.gitignore
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Cache and other mess
|
2 |
+
/temp/
|
3 |
+
|
4 |
+
# Byte-compiled / optimized / DLL files
|
5 |
+
__pycache__/
|
6 |
+
*.py[cod]
|
7 |
+
*$py.class
|
8 |
+
|
9 |
+
# IDEs
|
10 |
+
/.idea/
|
11 |
+
/.vscode/
|
12 |
+
|
13 |
+
# C extensions
|
14 |
+
*.so
|
15 |
+
|
16 |
+
# Distribution / packaging
|
17 |
+
.Python
|
18 |
+
build/
|
19 |
+
develop-eggs/
|
20 |
+
dist/
|
21 |
+
downloads/
|
22 |
+
eggs/
|
23 |
+
.eggs/
|
24 |
+
lib64/
|
25 |
+
parts/
|
26 |
+
sdist/
|
27 |
+
var/
|
28 |
+
wheels/
|
29 |
+
share/python-wheels/
|
30 |
+
*.egg-info/
|
31 |
+
.installed.cfg
|
32 |
+
*.egg
|
33 |
+
MANIFEST
|
34 |
+
|
35 |
+
# PyInstaller
|
36 |
+
# Usually these files are written by a python script from a template
|
37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
38 |
+
*.manifest
|
39 |
+
*.spec
|
40 |
+
|
41 |
+
# Installer logs
|
42 |
+
pip-log.txt
|
43 |
+
pip-delete-this-directory.txt
|
44 |
+
|
45 |
+
# Unit test / coverage reports
|
46 |
+
htmlcov/
|
47 |
+
.tox/
|
48 |
+
.nox/
|
49 |
+
.coverage
|
50 |
+
.coverage.*
|
51 |
+
.cache
|
52 |
+
nosetests.xml
|
53 |
+
coverage.xml
|
54 |
+
*.cover
|
55 |
+
*.py,cover
|
56 |
+
.hypothesis/
|
57 |
+
.pytest_cache/
|
58 |
+
cover/
|
59 |
+
|
60 |
+
# Translations
|
61 |
+
*.mo
|
62 |
+
*.pot
|
63 |
+
|
64 |
+
# Django stuff:
|
65 |
+
*.log
|
66 |
+
local_settings.py
|
67 |
+
db.sqlite3
|
68 |
+
db.sqlite3-journal
|
69 |
+
|
70 |
+
# Flask stuff:
|
71 |
+
instance/
|
72 |
+
.webassets-cache
|
73 |
+
|
74 |
+
# Scrapy stuff:
|
75 |
+
.scrapy
|
76 |
+
|
77 |
+
# Sphinx documentation
|
78 |
+
docs/_build/
|
79 |
+
|
80 |
+
# PyBuilder
|
81 |
+
.pybuilder/
|
82 |
+
target/
|
83 |
+
|
84 |
+
# Jupyter Notebook
|
85 |
+
.ipynb_checkpoints
|
86 |
+
|
87 |
+
# IPython
|
88 |
+
profile_default/
|
89 |
+
ipython_config.py
|
90 |
+
|
91 |
+
# pyenv
|
92 |
+
# For a library or package, you might want to ignore these files since the code is
|
93 |
+
# intended to run in multiple environments; otherwise, check them in:
|
94 |
+
# .python-version
|
95 |
+
|
96 |
+
# pipenv
|
97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
100 |
+
# install all needed dependencies.
|
101 |
+
#Pipfile.lock
|
102 |
+
|
103 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
104 |
+
__pypackages__/
|
105 |
+
|
106 |
+
# Celery stuff
|
107 |
+
celerybeat-schedule
|
108 |
+
celerybeat.pid
|
109 |
+
|
110 |
+
# SageMath parsed files
|
111 |
+
*.sage.py
|
112 |
+
|
113 |
+
# Environments
|
114 |
+
.env
|
115 |
+
.venv
|
116 |
+
env/
|
117 |
+
venv/
|
118 |
+
ENV/
|
119 |
+
env.bak/
|
120 |
+
venv.bak/
|
121 |
+
|
122 |
+
# Spyder project settings
|
123 |
+
.spyderproject
|
124 |
+
.spyproject
|
125 |
+
|
126 |
+
# Rope project settings
|
127 |
+
.ropeproject
|
128 |
+
|
129 |
+
# mkdocs documentation
|
130 |
+
/site
|
131 |
+
|
132 |
+
# mypy
|
133 |
+
.mypy_cache/
|
134 |
+
.dmypy.json
|
135 |
+
dmypy.json
|
136 |
+
|
137 |
+
# Pyre type checker
|
138 |
+
.pyre/
|
139 |
+
|
140 |
+
# pytype static type analyzer
|
141 |
+
.pytype/
|
142 |
+
|
143 |
+
# Cython debug symbols
|
144 |
+
cython_debug/
|
145 |
+
|
146 |
+
/wandb/
|
147 |
+
wandb/
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[server]
|
2 |
+
|
3 |
+
maxUploadSize = 64
|
README.md
CHANGED
@@ -1,12 +1,15 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.41.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
1 |
---
|
2 |
+
title: PXDN Line Extractor
|
3 |
+
emoji: ✏️
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: indigo
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.41.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: none
|
11 |
---
|
12 |
|
13 |
+
# Line Extractor v1
|
14 |
+
|
15 |
+
An attention based convolutional neural network that extracts lineart from a given rgb image.
|
app.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
sys.path.append(os.path.join(os.path.dirname(__file__)))
|
6 |
+
|
7 |
+
import blended_tiling
|
8 |
+
import numpy
|
9 |
+
import onnxruntime
|
10 |
+
import streamlit.file_util
|
11 |
+
import torch
|
12 |
+
import torch.cuda
|
13 |
+
from PIL import Image
|
14 |
+
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
15 |
+
from streamlit_image_comparison import image_comparison
|
16 |
+
from torchvision.transforms import functional as TVTF
|
17 |
+
|
18 |
+
from tools import image_tools
|
19 |
+
|
20 |
+
# * Cached/loaded model
|
21 |
+
onnx_session = None # type: onnxruntime.InferenceSession
|
22 |
+
|
23 |
+
# * Streamlit UI / Config
|
24 |
+
streamlit.set_page_config(page_title="🐲 PXDN Line Extractor v1", layout="wide")
|
25 |
+
streamlit.title("🐲 PXDN Line Extractor v1")
|
26 |
+
|
27 |
+
# * Streamlit Containers / Base Layout
|
28 |
+
# Row 1
|
29 |
+
ui_section_status = streamlit.container()
|
30 |
+
|
31 |
+
# Row 2
|
32 |
+
ui_col1, ui_col2 = streamlit.columns(2, gap="medium")
|
33 |
+
streamlit.html("<hr>")
|
34 |
+
|
35 |
+
# Row 3
|
36 |
+
ui_section_compare = streamlit.container()
|
37 |
+
|
38 |
+
# * Streamlit Session
|
39 |
+
# Nothing yet
|
40 |
+
|
41 |
+
with ui_section_status:
|
42 |
+
# Forward declared UI elements
|
43 |
+
ui_status_text = streamlit.empty()
|
44 |
+
ui_progress_bar = streamlit.empty()
|
45 |
+
|
46 |
+
with ui_col1:
|
47 |
+
# Input Area
|
48 |
+
streamlit.markdown("### Input Image")
|
49 |
+
ui_image_input = streamlit.file_uploader("Upload an image", key="fileupload_image", type=[".png", ".jpg", ".jpeg", ".webp"]) # type: UploadedFile
|
50 |
+
|
51 |
+
with ui_col2:
|
52 |
+
# Output Area
|
53 |
+
streamlit.markdown("### Output Image")
|
54 |
+
# Preallocate image spot and download button
|
55 |
+
ui_image_output = streamlit.empty()
|
56 |
+
ui_image_download = streamlit.empty()
|
57 |
+
|
58 |
+
def fetch_model_to_cache(huggingface_repo: str, file_path: str, access_token: str) -> str:
|
59 |
+
import huggingface_hub
|
60 |
+
return huggingface_hub.hf_hub_download(huggingface_repo, file_path, token=access_token)
|
61 |
+
|
62 |
+
def bootstrap_model():
|
63 |
+
global onnx_session
|
64 |
+
if onnx_session is None:
|
65 |
+
|
66 |
+
# Environment-level configuration
|
67 |
+
huggingface_repo = os.getenv("HF_REPO_NAME", "")
|
68 |
+
file_path = os.getenv("HF_FILE_PATH", "")
|
69 |
+
access_token = os.getenv("HF_TOKEN", "")
|
70 |
+
allow_cuda = bool(os.getenv("ALLOW_CUDA", "false").lower())
|
71 |
+
|
72 |
+
model_file_path = fetch_model_to_cache(huggingface_repo, file_path, access_token)
|
73 |
+
|
74 |
+
# * Enable CUDA if available and allowed
|
75 |
+
model_providers = ['CPUExecutionProvider']
|
76 |
+
if torch.cuda.is_available() and allow_cuda:
|
77 |
+
model_providers.insert(0, 'CUDAExecutionProvider')
|
78 |
+
|
79 |
+
onnx_session = onnxruntime.InferenceSession(model_file_path, sess_options=None, providers=model_providers)
|
80 |
+
|
81 |
+
def evaluate_tiled(image_pt: torch.Tensor, tile_size: int = 128, batch_size: int = 1) -> Image.Image:
|
82 |
+
image_pt_orig = image_pt
|
83 |
+
orig_h, orig_w = image_pt_orig.shape[1], image_pt_orig.shape[2]
|
84 |
+
|
85 |
+
# ? Padding
|
86 |
+
image_pt_padded, place_x, place_y = image_tools.pad_to_divisible(image_pt_orig, tile_size)
|
87 |
+
|
88 |
+
_, im_h_padded, im_w_padded = image_pt_padded.shape
|
89 |
+
|
90 |
+
# ? Tiling
|
91 |
+
image_tiler = blended_tiling.TilingModule(tile_size=tile_size, tile_overlap=[0.18, 0.18], base_size=(im_w_padded, im_h_padded)).eval()
|
92 |
+
# * Add batch dim for the tiler which expects (1, C, H, W)
|
93 |
+
image_tiles = image_tiler.split_into_tiles(image_pt_padded.unsqueeze(0))
|
94 |
+
|
95 |
+
# ? Pull the input and output names from the model so we're not hardcoding them.
|
96 |
+
onnx_session.get_modelmeta()
|
97 |
+
input_name = onnx_session.get_inputs()[0].name
|
98 |
+
output_name = onnx_session.get_outputs()[0].name
|
99 |
+
|
100 |
+
# ? Inference ==================================================================================================
|
101 |
+
complete_tiles = []
|
102 |
+
|
103 |
+
max_evals = image_tiles.size(0) // batch_size
|
104 |
+
image_tiles = image_tiles.numpy()
|
105 |
+
|
106 |
+
ui_status_text.markdown("### Processing...")
|
107 |
+
active_progress = ui_progress_bar.progress(0, "Progress")
|
108 |
+
|
109 |
+
for i in range(max_evals):
|
110 |
+
tile_batch = image_tiles[i * batch_size:(i + 1) * batch_size]
|
111 |
+
if len(tile_batch) == 0:
|
112 |
+
break
|
113 |
+
|
114 |
+
pct_complete = round((i + 1) / max_evals, 2)
|
115 |
+
active_progress.progress(pct_complete)
|
116 |
+
|
117 |
+
eval_output = onnx_session.run([output_name], {input_name: tile_batch})
|
118 |
+
output_batch = eval_output[0]
|
119 |
+
|
120 |
+
complete_tiles.extend(output_batch)
|
121 |
+
|
122 |
+
# ? /Inference
|
123 |
+
ui_status_text.empty()
|
124 |
+
ui_progress_bar.empty()
|
125 |
+
|
126 |
+
# ? Rehydrate the tiles into a full image.
|
127 |
+
complete_tiles_tensor = torch.from_numpy(numpy.stack(complete_tiles))
|
128 |
+
complete_image = image_tiler.rebuild_with_masks(complete_tiles_tensor)
|
129 |
+
|
130 |
+
# ? Unpad the image, a simple crop.
|
131 |
+
if place_x > 0 or place_y > 0:
|
132 |
+
complete_image = complete_image[:, :, place_y:place_y + orig_h, place_x:place_x + orig_w]
|
133 |
+
|
134 |
+
# ? Clamp and convert to PIL.
|
135 |
+
complete_image = complete_image.squeeze(0)
|
136 |
+
complete_image = complete_image.clamp(0, 1.0)
|
137 |
+
final_image_pil = TVTF.to_pil_image(complete_image)
|
138 |
+
|
139 |
+
return final_image_pil
|
140 |
+
|
141 |
+
def streamlit_to_pil_image(streamlit_file: UploadedFile):
|
142 |
+
image = Image.open(io.BytesIO(streamlit_file.read()))
|
143 |
+
return image
|
144 |
+
|
145 |
+
def pil_to_buffered_png(image: Image.Image) -> io.BytesIO:
|
146 |
+
buffer = io.BytesIO()
|
147 |
+
image.save(buffer, format="PNG", compression=3)
|
148 |
+
buffer.seek(0)
|
149 |
+
return buffer
|
150 |
+
|
151 |
+
# ! Image Inference
|
152 |
+
if ui_image_input is not None and ui_image_input.name is not None:
|
153 |
+
bootstrap_model()
|
154 |
+
ui_status_text.empty()
|
155 |
+
ui_progress_bar.empty()
|
156 |
+
|
157 |
+
onnx_session.get_modelmeta()
|
158 |
+
onnx_input_metadata = onnx_session.get_inputs()[0]
|
159 |
+
b, c, h, w = onnx_input_metadata.shape
|
160 |
+
|
161 |
+
target_batch_size = b
|
162 |
+
# This is always square, if H and W are different for ONNX input you screwed up, so I don't want to hear it.
|
163 |
+
target_tile_size = h
|
164 |
+
|
165 |
+
input_image = streamlit_to_pil_image(ui_image_input)
|
166 |
+
loaded_image_pt = image_tools.prepare_image_for_inference(input_image)
|
167 |
+
finished_image = evaluate_tiled(loaded_image_pt, tile_size=target_tile_size, batch_size=target_batch_size)
|
168 |
+
|
169 |
+
with ui_col2:
|
170 |
+
ui_image_output.image(finished_image, use_container_width=True, caption="Output Image")
|
171 |
+
complete_file_name = f"{ui_image_input.name.rsplit('.', 1)[0]}_output.png"
|
172 |
+
|
173 |
+
@streamlit.fragment
|
174 |
+
def download_button():
|
175 |
+
# ui_image_download.download_button("Download Image", image_to_bytesio(finished_image), complete_file_name, type="primary", on_click=lambda: setattr(streamlit.session_state, 'download_click', True))
|
176 |
+
streamlit.download_button("Download Image", pil_to_buffered_png(finished_image), complete_file_name, type="primary")
|
177 |
+
|
178 |
+
download_button()
|
179 |
+
|
180 |
+
with ui_section_compare:
|
181 |
+
image_comparison(img1=input_image, img2=finished_image, make_responsive=True, label1="Input Image", label2="Output Image", width=1024)
|
pycharm_runner.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://discuss.streamlit.io/t/cannot-debug-streamlit-in-pycharm-2023-3-3/61581/2
|
2 |
+
|
3 |
+
try:
|
4 |
+
from streamlit.web import bootstrap
|
5 |
+
except ImportError:
|
6 |
+
from streamlit import bootstrap
|
7 |
+
|
8 |
+
real_script = 'app.py'
|
9 |
+
bootstrap.run(real_script, is_hello=False, args=[], flag_options={})
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
blended_tiling==0.0.1.dev7
|
2 |
+
huggingface-hub==0.27.1
|
3 |
+
onnx==1.17.0
|
4 |
+
onnxruntime==1.17.0
|
5 |
+
pillow==11.0.0
|
6 |
+
streamlit==1.41.1
|
7 |
+
streamlit-image-comparison==0.0.4
|
8 |
+
torch
|
9 |
+
torchvision
|
tools/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
if __name__ == '__main__':
|
3 |
+
print('__main__ is not supported in modules')
|
tools/image_tools.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from torchvision.transforms import functional as TVTF
|
4 |
+
|
5 |
+
from .mafs import round_up_to_multiple
|
6 |
+
|
7 |
+
def prepare_image_for_inference(image_pil: Image.Image) -> torch.Tensor:
|
8 |
+
if image_pil.mode != 'RGB':
|
9 |
+
image_pil = image_pil.convert('RGB')
|
10 |
+
|
11 |
+
# Just being explicit, in case of environmental oddities or something.
|
12 |
+
image_pt = TVTF.to_tensor(image_pil)
|
13 |
+
image_pt = image_pt.to(device='cpu', dtype=torch.float32)
|
14 |
+
|
15 |
+
return image_pt
|
16 |
+
|
17 |
+
def pad_to_divisible(image_tensor: torch.Tensor, tile_size: int = 128):
|
18 |
+
c, h, w = image_tensor.shape
|
19 |
+
|
20 |
+
# If the dims are already divisible by the tile size, we're good.
|
21 |
+
if h % tile_size == 0 and w % tile_size == 0:
|
22 |
+
return image_tensor, 0, 0
|
23 |
+
|
24 |
+
expanded_w = round_up_to_multiple(w, tile_size)
|
25 |
+
expanded_h = round_up_to_multiple(h, tile_size)
|
26 |
+
l, t, r, b = 0, 0, 0, 0
|
27 |
+
|
28 |
+
# Distribute the padding evenly on all sides.
|
29 |
+
if expanded_w > w:
|
30 |
+
diff = expanded_w - w
|
31 |
+
l = diff // 2
|
32 |
+
r = diff - l
|
33 |
+
if expanded_h > h:
|
34 |
+
diff = expanded_h - h
|
35 |
+
t = diff // 2
|
36 |
+
b = diff - t
|
37 |
+
|
38 |
+
image_tensor = TVTF.pad(image_tensor, padding=[l, t, r, b], padding_mode='reflect')
|
39 |
+
|
40 |
+
return image_tensor, l, t
|
41 |
+
|
42 |
+
if __name__ == '__main__':
|
43 |
+
print('__main__ not supported in modules.')
|
tools/mafs.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def round_up_to_multiple(x: int, multiple: int) -> int:
|
2 |
+
return ((x + multiple - 1) // multiple) * multiple
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
print('__main__ not supported in modules.')
|