ginipick committed · verified · Commit ec38b03 · 1 Parent(s): 6a103fa

Update app.py

Files changed (1):
  1. app.py +160 -101

app.py CHANGED
@@ -1,15 +1,7 @@
- # --- Patch: apply before loading the model ---
- from transformers import PretrainedConfig
- PretrainedConfig.get_text_config = lambda self, decoder=True: type("DummyTextConfig", (), {"tie_word_embeddings": False})()
-
- # Override tie_weights with a no-op on every image segmentation model class
- from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
- for model_class in MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.values():
-     model_class.tie_weights = lambda self: None
- # --- End of patch ---
-
- from transformers import AutoModelForImageSegmentation
- from transformers import PreTrainedModel  # (for reference)
  import os
  import cv2
  import numpy as np
@@ -23,18 +15,50 @@ from typing import Tuple, Optional
  from PIL import Image
  from gradio_imageslider import ImageSlider
  from torchvision import transforms
-
  import requests
  from io import BytesIO
  import zipfile
  import random

- torch.set_float32_matmul_precision('high')
- torch.jit.script = lambda f: f

  device = "cuda" if torch.cuda.is_available() else "cpu"

- ### Image post-processing functions ###
  def refine_foreground(image, mask, r=90):
      if mask.size != image.size:
          mask = mask.resize(image.size)
@@ -61,6 +85,7 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
      F = np.clip(F, 0, 1)
      return F, blurred_B

  class ImagePreprocessor():
      def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
          self.transform_image = transforms.Compose([
@@ -72,6 +97,11 @@ class ImagePreprocessor():
          image = self.transform_image(image)
          return image

  usage_to_weights_file = {
      'General': 'BiRefNet',
      'General-HR': 'BiRefNet_HR',
@@ -86,105 +116,113 @@ usage_to_weights_file = {
      'General-legacy': 'BiRefNet-legacy'
  }

- # Initial model loading (default: General)
- birefnet = AutoModelForImageSegmentation.from_pretrained(
-     '/'.join(('zhengpeng7', usage_to_weights_file['General'])),
-     trust_remote_code=True
  )
- birefnet.to(device)
- birefnet.eval(); birefnet.half()

  @spaces.GPU
  def predict(images, resolution, weights_file):
      assert images is not None, 'Images cannot be None.'
-     global birefnet
-     # Reload the model with the selected weights
-     _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
-     print('Using weights: {}.'.format(_weights_file))
-     birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
-     birefnet.to(device)
-     birefnet.eval(); birefnet.half()

      try:
-         resolution_list = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
      except:
-         if weights_file == 'General-HR':
-             resolution_list = [2048, 2048]
-         elif weights_file == 'General-Lite-2K':
-             resolution_list = [2560, 1440]
-         else:
-             resolution_list = [1024, 1024]
-         print('Invalid resolution input. Automatically changed to default.')

-     # Check whether the input is a single image or a list (a batch)
      if isinstance(images, list):
-         tab_is_batch = True
      else:
          images = [images]
-         tab_is_batch = False
-
-     save_paths = []
-     save_dir = 'preds-BiRefNet'
-     if tab_is_batch and not os.path.exists(save_dir):
-         os.makedirs(save_dir)
-
-     outputs = []
      for idx, image_src in enumerate(images):
          if isinstance(image_src, str):
              if os.path.isfile(image_src):
                  image_ori = Image.open(image_src)
              else:
-                 response = requests.get(image_src)
-                 image_data = BytesIO(response.content)
-                 image_ori = Image.open(image_data)
          else:
-             if isinstance(image_src, np.ndarray):
-                 image_ori = Image.fromarray(image_src)
-             else:
-                 image_ori = image_src.convert('RGB')
          image = image_ori.convert('RGB')
-         preprocessor = ImagePreprocessor(resolution=tuple(resolution_list))
-         image_proc = preprocessor.proc(image).unsqueeze(0)
-         with torch.no_grad():
-             preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
-         pred = preds[0].squeeze()
-         pred_pil = transforms.ToPILImage()(pred)
          image_masked = refine_foreground(image, pred_pil)
          image_masked.putalpha(pred_pil.resize(image.size))
-         torch.cuda.empty_cache()
-         if tab_is_batch:
-             file_path = os.path.join(save_dir, "{}.png".format(
-                 os.path.splitext(os.path.basename(image_src))[0] if isinstance(image_src, str) else f"img_{idx}"
-             ))
-             image_masked.save(file_path)
-             save_paths.append(file_path)
              outputs.append(image_masked)
          else:
              outputs = [image_masked, image_ori]
-
-     if tab_is_batch:
-         zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
-         with zipfile.ZipFile(zip_file_path, 'w') as zipf:
-             for file in save_paths:
-                 zipf.write(file, os.path.basename(file))
-         return save_paths, zip_file_path
      else:
          return outputs

- # Example data (images, URLs, batches)
- examples_image = [[path, "1024x1024", "General"] for path in glob('examples/*')]
- examples_text = [[url, "1024x1024", "General"] for url in ["https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"]]
- examples_batch = [[file, "1024x1024", "General"] for file in glob('examples/*')]

- descriptions = (
-     "Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n"
-     "The resolution used in our training was `1024x1024`, which is suggested for good results! "
-     "`2048x2048` is suggested for BiRefNet_HR.\n"
-     "Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n"
-     "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
- )

- # CSS for a nicer UI
  css = """
  body {
      background: linear-gradient(135deg, #667eea, #764ba2);
@@ -239,16 +277,17 @@ button:hover, .btn:hover {
  }
  """

- title = """
- <h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo for Subject Extraction</h1>
  <p align="center" style="font-size:1.1em; color:#555;">
- Upload an image or provide an image URL to extract the subject with high-precision segmentation.
  </p>
  """

  with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
-     gr.Markdown(title)
      with gr.Tabs():
          with gr.Tab("Image"):
              with gr.Row():
                  with gr.Column(scale=1):
@@ -257,8 +296,14 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                  weights_radio = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                  predict_btn = gr.Button("Predict")
              with gr.Column(scale=2):
-                 output_slider = ImageSlider(label="BiRefNet's Prediction", type="pil")
-             gr.Examples(examples=examples_image, inputs=[image_input, resolution_input, weights_radio], label="Examples")
          with gr.Tab("Text"):
              with gr.Row():
                  with gr.Column(scale=1):
@@ -267,23 +312,37 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                  weights_radio_text = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                  predict_btn_text = gr.Button("Predict")
              with gr.Column(scale=2):
-                 output_slider_text = ImageSlider(label="BiRefNet's Prediction", type="pil")
-             gr.Examples(examples=examples_text, inputs=[image_url, resolution_input_text, weights_radio_text], label="Examples")
          with gr.Tab("Batch"):
              with gr.Row():
                  with gr.Column(scale=1):
-                     file_input = gr.File(label="Upload Multiple Images", type="filepath", file_count="multiple")
                      resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
                      weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                      predict_btn_batch = gr.Button("Predict")
                  with gr.Column(scale=2):
-                     output_gallery = gr.Gallery(label="BiRefNet's Predictions", scale=1)
-                     zip_output = gr.File(label="Download Masked Images")
-             gr.Examples(examples=examples_batch, inputs=[file_input, resolution_input_batch, weights_radio_batch], label="Examples")
-     with gr.Row():
-         gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")

-     # Wire each tab's Predict button to the predict function
      predict_btn.click(
          fn=predict,
          inputs=[image_input, resolution_input, weights_radio],
+ ##########################################################
+ # 0. Environment setup and library imports
+ ##########################################################
+
  import os
  import cv2
  import numpy as np

  from PIL import Image
  from gradio_imageslider import ImageSlider
  from torchvision import transforms

  import requests
  from io import BytesIO
  import zipfile
  import random

+ # Transformers
+ from transformers import (
+     AutoConfig,
+     AutoModelForImageSegmentation,
+ )
+ from safetensors.torch import load_file  # .safetensors files need this; torch.load cannot parse them
+
+ # 1) Load the config first to avoid the tie_weights conflict
+ config = AutoConfig.from_pretrained(
+     "zhengpeng7/BiRefNet",  # 👉 the desired Hugging Face model repo
+     trust_remote_code=True
+ )
+
+ # 2) Attach a dummy get_text_config method to the config (tie_word_embeddings=False)
+ def dummy_get_text_config(decoder=True):
+     return type("DummyTextConfig", (), {"tie_word_embeddings": False})()

+ config.get_text_config = dummy_get_text_config
+
+ # 3) Build only the model structure (from_config) -> tie_weights is not invoked automatically
+ birefnet = AutoModelForImageSegmentation.from_config(config, trust_remote_code=True)
+ birefnet.eval()
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ birefnet.to(device)
+ birefnet.half()
+
+ # 4) Load the state_dict (weights) - example using a local file
+ # In practice, fetch "model.safetensors" first via hf_hub_download / snapshot_download
+ print("Loading BiRefNet weights from local file: model.safetensors")
+ state_dict = load_file("model.safetensors")  # example; safetensors format requires load_file
+ missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)
+ print("[Info] Missing keys:", missing)
+ print("[Info] Unexpected keys:", unexpected)
+ torch.cuda.empty_cache()
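# A minimal sketch (not part of this commit) of how the weights file mentioned in
# the step-4 comment could be fetched first; it assumes "model.safetensors" is
# hosted in the zhengpeng7/BiRefNet repo and that huggingface_hub is installed:
#
#     from huggingface_hub import hf_hub_download
#     weights_path = hf_hub_download(repo_id="zhengpeng7/BiRefNet",
#                                    filename="model.safetensors")
#     state_dict = load_file(weights_path)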
+
+
+ ##########################################################
+ # 1. Image post-processing functions
+ ##########################################################

  def refine_foreground(image, mask, r=90):
      if mask.size != image.size:
          mask = mask.resize(image.size)

      F = np.clip(F, 0, 1)
      return F, blurred_B

+
  class ImagePreprocessor():
      def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
          self.transform_image = transforms.Compose([

          image = self.transform_image(image)
          return image

+
+ ##########################################################
+ # 2. Example setup and utilities
+ ##########################################################
+
  usage_to_weights_file = {
      'General': 'BiRefNet',
      'General-HR': 'BiRefNet_HR',

      'General-legacy': 'BiRefNet-legacy'
  }

+ examples_image = [[path, "1024x1024", "General"] for path in glob('examples/*')]
+ examples_text = [[url, "1024x1024", "General"] for url in [
+     "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
+ ]]
+ examples_batch = [[file, "1024x1024", "General"] for file in glob('examples/*')]
+
+ descriptions = (
+     "Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n"
+     "The resolution used in our training was `1024x1024`, which is suggested for good results! "
+     "`2048x2048` is suggested for BiRefNet_HR.\n"
+     "Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n"
+     "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
  )
+
+
+ ##########################################################
+ # 3. Inference function (uses the already-loaded birefnet model)
+ ##########################################################

  @spaces.GPU
  def predict(images, resolution, weights_file):
+     """
+     Only a single birefnet model is kept here; even when weights_file changes,
+     the already-loaded 'birefnet' model is what actually runs.
+     (To load different weights, a local state_dict swap like the sketch below
+     could be added.)
+     """
      assert images is not None, 'Images cannot be None.'

+     # Parse the resolution; floor each side to a multiple of 32
      try:
+         w, h = resolution.strip().split('x')
+         w, h = int(int(w)//32*32), int(int(h)//32*32)
+         resolution_list = (w, h)
      except:
+         print('[WARN] Invalid resolution input. Fallback to 1024x1024.')
+         resolution_list = (1024, 1024)
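    # (The flooring maps e.g. "1000x600" to (992, 576); "1024x1024" passes through unchanged.)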
 
 
 
 
 
+     # The input may contain multiple images, so handle it as a list
      if isinstance(images, list):
+         is_batch = True
+         outputs, save_paths = [], []
+         save_dir = 'preds-BiRefNet'
+         os.makedirs(save_dir, exist_ok=True)
      else:
          images = [images]
+         is_batch = False
+
      for idx, image_src in enumerate(images):
+         # A str is either a file path or a URL
          if isinstance(image_src, str):
              if os.path.isfile(image_src):
                  image_ori = Image.open(image_src)
              else:
+                 resp = requests.get(image_src)
+                 image_ori = Image.open(BytesIO(resp.content))
+         # A numpy array is converted via Pillow
+         elif isinstance(image_src, np.ndarray):
+             image_ori = Image.fromarray(image_src)
          else:
+             image_ori = image_src.convert('RGB')
+
          image = image_ori.convert('RGB')
+         preproc = ImagePreprocessor(resolution_list)
+         image_proc = preproc.proc(image).unsqueeze(0).to(device).half()
+
+         # Run inference
+         with torch.inference_mode():
+             # preds comes from the model's final output layer
+             preds = birefnet(image_proc)[-1].sigmoid().cpu()
+         pred_mask = preds[0].squeeze()
+
+         # Post-processing
+         pred_pil = transforms.ToPILImage()(pred_mask)
          image_masked = refine_foreground(image, pred_pil)
          image_masked.putalpha(pred_pil.resize(image.size))
+
+         if is_batch:
+             file_name = (
+                 os.path.splitext(os.path.basename(image_src))[0]
+                 if isinstance(image_src, str)
+                 else f"img_{idx}"
+             )
+             out_path = os.path.join(save_dir, f"{file_name}.png")
+             image_masked.save(out_path)
+             save_paths.append(out_path)
              outputs.append(image_masked)
          else:
              outputs = [image_masked, image_ori]
+
+     torch.cuda.empty_cache()
+
+     # For a batch, return the gallery paths plus a ZIP archive
+     if is_batch:
+         zip_path = os.path.join(save_dir, f"{save_dir}.zip")
+         with zipfile.ZipFile(zip_path, 'w') as zipf:
+             for fpath in save_paths:
+                 zipf.write(fpath, os.path.basename(fpath))
+         return (save_paths, zip_path)
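          # (The tuple above feeds the Batch tab's output_gallery and zip_output components below.)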
      else:
          return outputs


+ ##########################################################
+ # 4. Gradio UI
+ ##########################################################

+ # Custom CSS
  css = """
  body {
      background: linear-gradient(135deg, #667eea, #764ba2);

  }
  """

+ title_html = """
+ <h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo (No Tie-Weights Crash)</h1>
  <p align="center" style="font-size:1.1em; color:#555;">
+ Using <code>from_config()</code> + a local <code>state_dict</code> to bypass tie_weights issues
  </p>
  """

  with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
+     gr.Markdown(title_html)
      with gr.Tabs():
+         # Tab 1: Image
          with gr.Tab("Image"):
              with gr.Row():
                  with gr.Column(scale=1):

                  weights_radio = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                  predict_btn = gr.Button("Predict")
              with gr.Column(scale=2):
+                 output_slider = ImageSlider(label="Result", type="pil")
+             gr.Examples(
+                 examples=examples_image,
+                 inputs=[image_input, resolution_input, weights_radio],
+                 label="Examples"
+             )
+
+         # Tab 2: Text (URL)
          with gr.Tab("Text"):
              with gr.Row():
                  with gr.Column(scale=1):

                  weights_radio_text = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                  predict_btn_text = gr.Button("Predict")
              with gr.Column(scale=2):
+                 output_slider_text = ImageSlider(label="Result", type="pil")
+             gr.Examples(
+                 examples=examples_text,
+                 inputs=[image_url, resolution_input_text, weights_radio_text],
+                 label="Examples"
+             )
+
+         # Tab 3: Batch
          with gr.Tab("Batch"):
              with gr.Row():
                  with gr.Column(scale=1):
+                     file_input = gr.File(
+                         label="Upload Multiple Images",
+                         type="filepath",
+                         file_count="multiple"
+                     )
                      resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
                      weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                      predict_btn_batch = gr.Button("Predict")
                  with gr.Column(scale=2):
+                     output_gallery = gr.Gallery(label="Results", scale=1)
+                     zip_output = gr.File(label="Zip Download")
+             gr.Examples(
+                 examples=examples_batch,
+                 inputs=[file_input, resolution_input_batch, weights_radio_batch],
+                 label="Examples"
+             )
+
+     gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")

+     # Wire up the button events
      predict_btn.click(
          fn=predict,
          inputs=[image_input, resolution_input, weights_radio],