Spaces: Build error
Update app.py

app.py
CHANGED
@@ -1,15 +1,7 @@
-
-
-
-
-# Override tie_weights with a no-op on every class mapped as an image-segmentation model
-from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
-for model_class in MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.values():
-    model_class.tie_weights = lambda self: None
-# --- end of patch ---
-
-from transformers import AutoModelForImageSegmentation
-from transformers import PreTrainedModel  # (for reference)
+##########################################################
+# 0. Environment setup and library imports
+##########################################################
+
 import os
 import cv2
 import numpy as np
@@ -23,18 +15,51 @@ from typing import Tuple, Optional
 from PIL import Image
 from gradio_imageslider import ImageSlider
 from torchvision import transforms
-
 import requests
 from io import BytesIO
 import zipfile
 import random
 
-
-
+# Transformers
+from transformers import (
+    AutoConfig,
+    AutoModelForImageSegmentation,
+)
+from safetensors.torch import load_file  # .safetensors needs this loader; torch.load cannot parse it
+
+# 1) Load the config first to avoid the tie_weights conflict
+config = AutoConfig.from_pretrained(
+    "zhengpeng7/BiRefNet",  # 👉 the Hugging Face model repo you want
+    trust_remote_code=True
+)
+
+# 2) Give config.get_text_config a dummy method (tie_word_embeddings=False)
+def dummy_get_text_config(decoder=True):
+    return type("DummyTextConfig", (), {"tie_word_embeddings": False})()
 
+config.get_text_config = dummy_get_text_config
+
+# 3) Build the model structure only (from_config) -> tie_weights is not called automatically
+birefnet = AutoModelForImageSegmentation.from_config(config, trust_remote_code=True)
+birefnet.eval()
 device = "cuda" if torch.cuda.is_available() else "cpu"
+birefnet.to(device)
+birefnet.half()
+
+# 4) Load the state_dict (weights) - example using a local file
+# In practice, fetch "model.safetensors" beforehand via hf_hub_download / snapshot_download
+print("Loading BiRefNet weights from local file: model.safetensors")
+state_dict = load_file("model.safetensors")  # example
+missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)
+print("[Info] Missing keys:", missing)
+print("[Info] Unexpected keys:", unexpected)
+torch.cuda.empty_cache()
+
+
+##########################################################
+# 1. Image post-processing functions
+##########################################################
 
-### Image post-processing functions ###
 def refine_foreground(image, mask, r=90):
     if mask.size != image.size:
         mask = mask.resize(image.size)
@@ -61,6 +85,7 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
     F = np.clip(F, 0, 1)
     return F, blurred_B
 
+
 class ImagePreprocessor():
     def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
         self.transform_image = transforms.Compose([
@@ -72,6 +97,11 @@ class ImagePreprocessor():
         image = self.transform_image(image)
         return image
 
+
+##########################################################
+# 2. Example settings and utilities
+##########################################################
+
 usage_to_weights_file = {
     'General': 'BiRefNet',
     'General-HR': 'BiRefNet_HR',
@@ -86,105 +116,113 @@ usage_to_weights_file = {
     'General-legacy': 'BiRefNet-legacy'
 }
 
-
-
-
-
+examples_image = [[path, "1024x1024", "General"] for path in glob('examples/*')]
+examples_text = [[url, "1024x1024", "General"] for url in [
+    "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
+]]
+examples_batch = [[file, "1024x1024", "General"] for file in glob('examples/*')]
+
+descriptions = (
+    "Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n"
+    "The resolution used in our training was `1024x1024`, which is suggested for good results! "
+    "`2048x2048` is suggested for BiRefNet_HR.\n"
+    "Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n"
+    "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
 )
-
-
+
+
+##########################################################
+# 3. Inference function (uses the already-loaded birefnet model)
+##########################################################
 
 @spaces.GPU
 def predict(images, resolution, weights_file):
+    """
+    Only a single birefnet model is kept here; even if weights_file changes,
+    the already-loaded 'birefnet' model is what actually runs.
+    (To load different weights, a local state_dict swap like the one shown above could be added.)
+    """
     assert images is not None, 'Images cannot be None.'
-    global birefnet
-    # Reload the model with the selected weights
-    _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
-    print('Using weights: {}.'.format(_weights_file))
-    birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
-    birefnet.to(device)
-    birefnet.eval(); birefnet.half()
 
+    # Resolution parse
     try:
-
+        w, h = resolution.strip().split('x')
+        w, h = int(int(w)//32*32), int(int(h)//32*32)
+        resolution_list = (w, h)
     except:
-
-
-    elif weights_file == 'General-Lite-2K':
-        resolution_list = [2560, 1440]
-    else:
-        resolution_list = [1024, 1024]
-        print('Invalid resolution input. Automatically changed to default.')
+        print('[WARN] Invalid resolution input. Fallback to 1024x1024.')
+        resolution_list = (1024, 1024)
 
-    # Images
+    # There may be multiple images, so handle them as a list
    if isinstance(images, list):
-
+        is_batch = True
+        outputs, save_paths = [], []
+        save_dir = 'preds-BiRefNet'
+        os.makedirs(save_dir, exist_ok=True)
    else:
        images = [images]
-
-
-    save_paths = []
-    save_dir = 'preds-BiRefNet'
-    if tab_is_batch and not os.path.exists(save_dir):
-        os.makedirs(save_dir)
-
-    outputs = []
+        is_batch = False
+
    for idx, image_src in enumerate(images):
+        # A str is either a file path or a URL
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
-
-
-
+                resp = requests.get(image_src)
+                image_ori = Image.open(BytesIO(resp.content))
+        # A numpy array gets converted via Pillow
+        elif isinstance(image_src, np.ndarray):
+            image_ori = Image.fromarray(image_src)
        else:
-
-
-        else:
-            image_ori = image_src.convert('RGB')
+            image_ori = image_src.convert('RGB')
+
        image = image_ori.convert('RGB')
-
-        image_proc =
-
-
-
-
+        preproc = ImagePreprocessor(resolution_list)
+        image_proc = preproc.proc(image).unsqueeze(0).to(device).half()
+
+        # Actual inference
+        with torch.inference_mode():
+            # preds from the final output layer
+            preds = birefnet(image_proc)[-1].sigmoid().cpu()
+        pred_mask = preds[0].squeeze()
+
+        # Post-processing
+        pred_pil = transforms.ToPILImage()(pred_mask)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))
-
-        if
-
-            os.path.splitext(os.path.basename(image_src))[0]
-
-
-
+
+        if is_batch:
+            file_name = (
+                os.path.splitext(os.path.basename(image_src))[0]
+                if isinstance(image_src, str)
+                else f"img_{idx}"
+            )
+            out_path = os.path.join(save_dir, f"{file_name}.png")
+            image_masked.save(out_path)
+            save_paths.append(out_path)
            outputs.append(image_masked)
        else:
            outputs = [image_masked, image_ori]
-
-
-
-
-
-
-
+
+    torch.cuda.empty_cache()
+
+    # For a batch, return the gallery + a ZIP
+    if is_batch:
+        zip_path = os.path.join(save_dir, f"{save_dir}.zip")
+        with zipfile.ZipFile(zip_path, 'w') as zipf:
+            for fpath in save_paths:
+                zipf.write(fpath, os.path.basename(fpath))
+        return (save_paths, zip_path)
    else:
        return outputs
 
-# Example data (image, URL, batch)
-examples_image = [[path, "1024x1024", "General"] for path in glob('examples/*')]
-examples_text = [[url, "1024x1024", "General"] for url in ["https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"]]
-examples_batch = [[file, "1024x1024", "General"] for file in glob('examples/*')]
 
-
-
-
-    "`2048x2048` is suggested for BiRefNet_HR.\n"
-    "Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n"
-    "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
-)
+##########################################################
+# 4. Gradio UI
+##########################################################
 
-#
+# Custom CSS
 css = """
 body {
     background: linear-gradient(135deg, #667eea, #764ba2);
@@ -239,16 +277,17 @@ button:hover, .btn:hover {
 }
 """
 
-
-<h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo
+title_html = """
+<h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo (No Tie-Weights Crash)</h1>
 <p align="center" style="font-size:1.1em; color:#555;">
-
+  Using <code>from_config()</code> + local <code>state_dict</code> to bypass tie_weights issues
 </p>
 """
 
 with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
-    gr.Markdown(
+    gr.Markdown(title_html)
     with gr.Tabs():
+        # Tab 1: Image
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):
@@ -257,8 +296,14 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                    weights_radio = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=2):
-                    output_slider = ImageSlider(label="
-                    gr.Examples(
+                    output_slider = ImageSlider(label="Result", type="pil")
+                    gr.Examples(
+                        examples=examples_image,
+                        inputs=[image_input, resolution_input, weights_radio],
+                        label="Examples"
+                    )
+
+        # Tab 2: Text (URL)
        with gr.Tab("Text"):
            with gr.Row():
                with gr.Column(scale=1):
@@ -267,23 +312,37 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                    weights_radio_text = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn_text = gr.Button("Predict")
                with gr.Column(scale=2):
-                    output_slider_text = ImageSlider(label="
-                    gr.Examples(
+                    output_slider_text = ImageSlider(label="Result", type="pil")
+                    gr.Examples(
+                        examples=examples_text,
+                        inputs=[image_url, resolution_input_text, weights_radio_text],
+                        label="Examples"
+                    )
+
+        # Tab 3: Batch
        with gr.Tab("Batch"):
            with gr.Row():
                with gr.Column(scale=1):
-                    file_input = gr.File(
+                    file_input = gr.File(
+                        label="Upload Multiple Images",
+                        type="filepath",
+                        file_count="multiple"
+                    )
                    resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
                    weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn_batch = gr.Button("Predict")
                with gr.Column(scale=2):
-                    output_gallery = gr.Gallery(label="
-                    zip_output = gr.File(label="Download
-                    gr.Examples(
-
-
+                    output_gallery = gr.Gallery(label="Results", scale=1)
+                    zip_output = gr.File(label="Zip Download")
+                    gr.Examples(
+                        examples=examples_batch,
+                        inputs=[file_input, resolution_input_batch, weights_radio_batch],
+                        label="Examples"
+                    )
+
+    gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")
 
-    #
+    # Wire up button events
    predict_btn.click(
        fn=predict,
        inputs=[image_input, resolution_input, weights_radio],