kjerk committed on
Commit d8c02ee · 1 Parent(s): d2e9922

Add streamlit app code v1

.gitignore ADDED
@@ -0,0 +1,147 @@
+ # Cache and other mess
+ /temp/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # IDEs
+ /.idea/
+ /.vscode/
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ /wandb/
+ wandb/
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
+ [server]
+
+ maxUploadSize = 64
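+ # maxUploadSize is measured in megabytes, so this caps uploads at 64 MB.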
README.md CHANGED
@@ -1,12 +1,15 @@
  ---
- title: Pxdn Line Extractor
- emoji: 🌖
- colorFrom: blue
- colorTo: red
+ title: PXDN Line Extractor
+ emoji: ✏️
+ colorFrom: pink
+ colorTo: indigo
  sdk: streamlit
  sdk_version: 1.41.1
  app_file: app.py
  pinned: false
+ license: none
  ---
  
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Line Extractor v1
+
+ An attention-based convolutional neural network that extracts lineart from a given RGB image.
app.py ADDED
@@ -0,0 +1,181 @@
+ import io
+ import os
+ import sys
+
+ # Make sibling modules (tools/) importable when launched via streamlit.
+ sys.path.append(os.path.dirname(__file__))
+
+ import blended_tiling
+ import numpy
+ import onnxruntime
+ import streamlit
+ import torch
+ import torch.cuda
+ from PIL import Image
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
+ from streamlit_image_comparison import image_comparison
+ from torchvision.transforms import functional as TVTF
+
+ from tools import image_tools
+
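+ # * blended_tiling splits large images into overlapping model-sized tiles and re-blends the
+ # * outputs; onnxruntime executes the exported line-extraction model.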
+ # * Cached/loaded model
+ onnx_session = None  # type: onnxruntime.InferenceSession
+
+ # * Streamlit UI / Config
+ streamlit.set_page_config(page_title="🐲 PXDN Line Extractor v1", layout="wide")
+ streamlit.title("🐲 PXDN Line Extractor v1")
+
+ # * Streamlit Containers / Base Layout
+ # Row 1
+ ui_section_status = streamlit.container()
+
+ # Row 2
+ ui_col1, ui_col2 = streamlit.columns(2, gap="medium")
+ streamlit.html("<hr>")
+
+ # Row 3
+ ui_section_compare = streamlit.container()
+
+ # * Streamlit Session
+ # Nothing yet
+
+ with ui_section_status:
+     # Forward declared UI elements
+     ui_status_text = streamlit.empty()
+     ui_progress_bar = streamlit.empty()
+
+ with ui_col1:
+     # Input Area
+     streamlit.markdown("### Input Image")
+     ui_image_input = streamlit.file_uploader("Upload an image", key="fileupload_image", type=[".png", ".jpg", ".jpeg", ".webp"])  # type: UploadedFile
+
+ with ui_col2:
+     # Output Area
+     streamlit.markdown("### Output Image")
+     # Preallocate image spot and download button
+     ui_image_output = streamlit.empty()
+     ui_image_download = streamlit.empty()
+
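+ # hf_hub_download caches the file in the local HF cache and returns its path, so repeat runs skip the network fetch.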
+ def fetch_model_to_cache(huggingface_repo: str, file_path: str, access_token: str) -> str:
+     import huggingface_hub
+     return huggingface_hub.hf_hub_download(huggingface_repo, file_path, token=access_token)
+
+ def bootstrap_model():
+     global onnx_session
+     if onnx_session is None:
+
+         # Environment-level configuration
+         huggingface_repo = os.getenv("HF_REPO_NAME", "")
+         file_path = os.getenv("HF_FILE_PATH", "")
+         access_token = os.getenv("HF_TOKEN", "")
+         # bool("false") would be True, so compare the lowered string instead.
+         allow_cuda = os.getenv("ALLOW_CUDA", "false").lower() == "true"
+
+         model_file_path = fetch_model_to_cache(huggingface_repo, file_path, access_token)
+
+         # * Enable CUDA if available and allowed
+         model_providers = ['CPUExecutionProvider']
+         if torch.cuda.is_available() and allow_cuda:
+             model_providers.insert(0, 'CUDAExecutionProvider')
+
+         onnx_session = onnxruntime.InferenceSession(model_file_path, sess_options=None, providers=model_providers)
+
+ def evaluate_tiled(image_pt: torch.Tensor, tile_size: int = 128, batch_size: int = 1) -> Image.Image:
+     image_pt_orig = image_pt
+     orig_h, orig_w = image_pt_orig.shape[1], image_pt_orig.shape[2]
+
+     # ? Padding
+     image_pt_padded, place_x, place_y = image_tools.pad_to_divisible(image_pt_orig, tile_size)
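+     # place_x/place_y are the left/top padding offsets, kept so the output can be cropped back to the original size.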
+
+     _, im_h_padded, im_w_padded = image_pt_padded.shape
+
+     # ? Tiling
+     image_tiler = blended_tiling.TilingModule(tile_size=tile_size, tile_overlap=[0.18, 0.18], base_size=(im_w_padded, im_h_padded)).eval()
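+     # Tiles overlap by 18% on each axis; the tiler later re-blends them with feathered masks so seams don't show.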
+     # * Add batch dim for the tiler which expects (1, C, H, W)
+     image_tiles = image_tiler.split_into_tiles(image_pt_padded.unsqueeze(0))
+
+     # ? Pull the input and output names from the model so we're not hardcoding them.
+     input_name = onnx_session.get_inputs()[0].name
+     output_name = onnx_session.get_outputs()[0].name
+
+     # ? Inference ==================================================================================================
+     complete_tiles = []
+
+     # Round up so a final partial batch of tiles is still processed rather than silently dropped.
+     max_evals = (image_tiles.size(0) + batch_size - 1) // batch_size
+     image_tiles = image_tiles.numpy()
+
+     ui_status_text.markdown("### Processing...")
+     active_progress = ui_progress_bar.progress(0, "Progress")
+
+     for i in range(max_evals):
+         tile_batch = image_tiles[i * batch_size:(i + 1) * batch_size]
+         if len(tile_batch) == 0:
+             break
+
+         pct_complete = round((i + 1) / max_evals, 2)
+         active_progress.progress(pct_complete)
+
+         eval_output = onnx_session.run([output_name], {input_name: tile_batch})
+         output_batch = eval_output[0]
+
+         complete_tiles.extend(output_batch)
+
+     # ? /Inference
+     ui_status_text.empty()
+     ui_progress_bar.empty()
+
+     # ? Rehydrate the tiles into a full image.
+     complete_tiles_tensor = torch.from_numpy(numpy.stack(complete_tiles))
+     complete_image = image_tiler.rebuild_with_masks(complete_tiles_tensor)
+
+     # ? Unpad the image, a simple crop (a no-op slice when no padding was added).
+     complete_image = complete_image[:, :, place_y:place_y + orig_h, place_x:place_x + orig_w]
+
+     # ? Clamp and convert to PIL.
+     complete_image = complete_image.squeeze(0)
+     complete_image = complete_image.clamp(0, 1.0)
+     final_image_pil = TVTF.to_pil_image(complete_image)
+
+     return final_image_pil
+
+ def streamlit_to_pil_image(streamlit_file: UploadedFile):
+     image = Image.open(io.BytesIO(streamlit_file.read()))
+     return image
+
+ def pil_to_buffered_png(image: Image.Image) -> io.BytesIO:
+     buffer = io.BytesIO()
+     # PIL's PNG writer takes compress_level (0-9); an unknown "compression" kwarg is silently ignored.
+     image.save(buffer, format="PNG", compress_level=3)
+     buffer.seek(0)
+     return buffer
+
+ # ! Image Inference
+ if ui_image_input is not None and ui_image_input.name is not None:
+     bootstrap_model()
+     ui_status_text.empty()
+     ui_progress_bar.empty()
+
+     onnx_input_metadata = onnx_session.get_inputs()[0]
+     b, c, h, w = onnx_input_metadata.shape
+
+     target_batch_size = b
+     # This is always square, if H and W are different for ONNX input you screwed up, so I don't want to hear it.
+     target_tile_size = h
+
+     input_image = streamlit_to_pil_image(ui_image_input)
+     loaded_image_pt = image_tools.prepare_image_for_inference(input_image)
+     finished_image = evaluate_tiled(loaded_image_pt, tile_size=target_tile_size, batch_size=target_batch_size)
+
+     with ui_col2:
+         ui_image_output.image(finished_image, use_container_width=True, caption="Output Image")
+         complete_file_name = f"{ui_image_input.name.rsplit('.', 1)[0]}_output.png"
+
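+         # * streamlit.fragment scopes reruns to this function, so clicking Download doesn't rerun the whole script and redo inference.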
+         @streamlit.fragment
+         def download_button():
+             # ui_image_download.download_button("Download Image", image_to_bytesio(finished_image), complete_file_name, type="primary", on_click=lambda: setattr(streamlit.session_state, 'download_click', True))
+             streamlit.download_button("Download Image", pil_to_buffered_png(finished_image), complete_file_name, type="primary")
+
+         download_button()
+
+     with ui_section_compare:
+         image_comparison(img1=input_image, img2=finished_image, make_responsive=True, label1="Input Image", label2="Output Image", width=1024)
pycharm_runner.py ADDED
@@ -0,0 +1,9 @@
+ # https://discuss.streamlit.io/t/cannot-debug-streamlit-in-pycharm-2023-3-3/61581/2
+
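+ # Newer Streamlit exposes bootstrap under streamlit.web; older releases had it at the top level.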
+ try:
+     from streamlit.web import bootstrap
+ except ImportError:
+     from streamlit import bootstrap
+
+ real_script = 'app.py'
+ bootstrap.run(real_script, is_hello=False, args=[], flag_options={})
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ blended_tiling==0.0.1.dev7
+ huggingface-hub==0.27.1
+ onnx==1.17.0
+ onnxruntime==1.17.0
+ pillow==11.0.0
+ streamlit==1.41.1
+ streamlit-image-comparison==0.0.4
+ torch
+ torchvision
tools/__init__.py ADDED
@@ -0,0 +1,3 @@
+
+ if __name__ == '__main__':
+     print('__main__ is not supported in modules')
tools/image_tools.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ from PIL import Image
+ from torchvision.transforms import functional as TVTF
+
+ from .mafs import round_up_to_multiple
+
+ def prepare_image_for_inference(image_pil: Image.Image) -> torch.Tensor:
+     if image_pil.mode != 'RGB':
+         image_pil = image_pil.convert('RGB')
+
+     # Just being explicit, in case of environmental oddities or something.
+     image_pt = TVTF.to_tensor(image_pil)
+     image_pt = image_pt.to(device='cpu', dtype=torch.float32)
+
+     return image_pt
+
+ def pad_to_divisible(image_tensor: torch.Tensor, tile_size: int = 128):
+     c, h, w = image_tensor.shape
+
+     # If the dims are already divisible by the tile size, we're good.
+     if h % tile_size == 0 and w % tile_size == 0:
+         return image_tensor, 0, 0
+
+     expanded_w = round_up_to_multiple(w, tile_size)
+     expanded_h = round_up_to_multiple(h, tile_size)
+     l, t, r, b = 0, 0, 0, 0
+
+     # Distribute the padding evenly on all sides.
+     if expanded_w > w:
+         diff = expanded_w - w
+         l = diff // 2
+         r = diff - l
+     if expanded_h > h:
+         diff = expanded_h - h
+         t = diff // 2
+         b = diff - t
+
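+     # Reflect padding mirrors the image at its edges instead of adding flat borders, so the model doesn't see artificial hard edges.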
+     image_tensor = TVTF.pad(image_tensor, padding=[l, t, r, b], padding_mode='reflect')
+
+     return image_tensor, l, t
+
+ if __name__ == '__main__':
+     print('__main__ not supported in modules.')
tools/mafs.py ADDED
@@ -0,0 +1,5 @@
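+ # Ceiling-round to a multiple, e.g. round_up_to_multiple(600, 128) == 640.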
+ def round_up_to_multiple(x: int, multiple: int) -> int:
+     return ((x + multiple - 1) // multiple) * multiple
+
+ if __name__ == '__main__':
+     print('__main__ not supported in modules.')