0notexist0 commited on
Commit
c437df7
Β·
verified Β·
1 Parent(s): 61bdfa8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -56
app.py CHANGED
@@ -1,60 +1,30 @@
1
  import gradio as gr
2
  from loadimg import load_img
3
- import spaces
4
  from transformers import AutoModelForImageSegmentation
5
  import torch
6
  from torchvision import transforms
7
  from typing import Union, Tuple
8
  from PIL import Image
9
 
10
- torch.set_float32_matmul_precision(["high", "highest"][0])
 
 
11
 
12
  birefnet = AutoModelForImageSegmentation.from_pretrained(
13
  "ZhengPeng7/BiRefNet", trust_remote_code=True
14
  )
15
- birefnet.to("cuda")
16
-
17
- transform_image = transforms.Compose(
18
- [
19
- transforms.Resize((1024, 1024)),
20
- transforms.ToTensor(),
21
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
22
- ]
23
- )
24
 
25
- def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
26
- """
27
- Remove the background from an image and return both the transparent version and the original.
28
- This function performs background removal using a BiRefNet segmentation model. It is intended for use
29
- with image input (either uploaded or from a URL). The function returns a transparent PNG version of the image
30
- with the background removed, along with the original RGB version for comparison.
31
- Args:
32
- image (PIL.Image or str): The input image, either as a PIL object or a filepath/URL string.
33
- Returns:
34
- tuple:
35
- - processed_image (PIL.Image): The input image with the background removed and transparency applied.
36
- - origin (PIL.Image): The original RGB image, unchanged.
37
- """
38
- im = load_img(image, output_type="pil")
39
- im = im.convert("RGB")
40
- origin = im.copy()
41
- processed_image = process(im)
42
- return (processed_image, origin)
43
 
44
- @spaces.GPU
45
  def process(image: Image.Image) -> Image.Image:
46
- """
47
- Apply BiRefNet-based image segmentation to remove the background.
48
- This function preprocesses the input image, runs it through a BiRefNet segmentation model to obtain a mask,
49
- and applies the mask as an alpha (transparency) channel to the original image.
50
- Args:
51
- image (PIL.Image): The input RGB image.
52
- Returns:
53
- PIL.Image: The image with the background removed, using the segmentation mask as transparency.
54
- """
55
  image_size = image.size
56
- input_images = transform_image(image).unsqueeze(0).to("cuda")
57
- # Prediction
58
  with torch.no_grad():
59
  preds = birefnet(input_images)[-1].sigmoid().cpu()
60
  pred = preds[0].squeeze()
@@ -63,17 +33,15 @@ def process(image: Image.Image) -> Image.Image:
63
  image.putalpha(mask)
64
  return image
65
 
 
 
 
 
 
 
66
  def process_file(f: str) -> str:
67
- """
68
- Load an image file from disk, remove the background, and save the output as a transparent PNG.
69
- Args:
70
- f (str): Filepath of the image to process.
71
- Returns:
72
- str: Path to the saved PNG image with background removed.
73
- """
74
  name_path = f.rsplit(".", 1)[0] + ".png"
75
- im = load_img(f, output_type="pil")
76
- im = im.convert("RGB")
77
  transparent = process(im)
78
  transparent.save(name_path)
79
  return name_path
@@ -85,17 +53,22 @@ image_file_upload = gr.Image(label="Upload an image", type="filepath")
85
  url_input = gr.Textbox(label="Paste an image URL")
86
  output_file = gr.File(label="Output PNG File")
87
 
88
- # Example images
89
  chameleon = load_img("butterfly.jpg", output_type="pil")
90
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
91
 
92
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
93
- tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
94
- tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
 
 
 
 
95
 
96
  demo = gr.TabbedInterface(
97
- [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
 
 
98
  )
99
 
100
  if __name__ == "__main__":
101
- demo.launch(show_error=True, mcp_server=True)
 
1
  import gradio as gr
2
  from loadimg import load_img
 
3
  from transformers import AutoModelForImageSegmentation
4
  import torch
5
  from torchvision import transforms
6
  from typing import Union, Tuple
7
  from PIL import Image
8
 
9
+ torch.set_float32_matmul_precision("high")
10
+
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  birefnet = AutoModelForImageSegmentation.from_pretrained(
14
  "ZhengPeng7/BiRefNet", trust_remote_code=True
15
  )
16
+ birefnet.to(device)
 
 
 
 
 
 
 
 
17
 
18
+ transform_image = transforms.Compose([
19
+ transforms.Resize((1024, 1024)),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize([0.485, 0.456, 0.406],
22
+ [0.229, 0.224, 0.225]),
23
+ ])
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
25
  def process(image: Image.Image) -> Image.Image:
 
 
 
 
 
 
 
 
 
26
  image_size = image.size
27
+ input_images = transform_image(image).unsqueeze(0).to(device)
 
28
  with torch.no_grad():
29
  preds = birefnet(input_images)[-1].sigmoid().cpu()
30
  pred = preds[0].squeeze()
 
33
  image.putalpha(mask)
34
  return image
35
 
36
+ def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
37
+ im = load_img(image, output_type="pil").convert("RGB")
38
+ origin = im.copy()
39
+ processed_image = process(im)
40
+ return processed_image, origin
41
+
42
  def process_file(f: str) -> str:
 
 
 
 
 
 
 
43
  name_path = f.rsplit(".", 1)[0] + ".png"
44
+ im = load_img(f, output_type="pil").convert("RGB")
 
45
  transparent = process(im)
46
  transparent.save(name_path)
47
  return name_path
 
53
  url_input = gr.Textbox(label="Paste an image URL")
54
  output_file = gr.File(label="Output PNG File")
55
 
 
56
  chameleon = load_img("butterfly.jpg", output_type="pil")
57
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
58
 
59
+ tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1,
60
+ examples=[chameleon], api_name="image")
61
+ tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2,
62
+ examples=[url_example], api_name="text")
63
+ tab3 = gr.Interface(process_file, inputs=image_file_upload,
64
+ outputs=output_file, examples=["butterfly.jpg"],
65
+ api_name="png")
66
 
67
  demo = gr.TabbedInterface(
68
+ [tab1, tab2, tab3],
69
+ ["Image Upload", "URL Input", "File Output"],
70
+ title="Background Removal Tool"
71
  )
72
 
73
  if __name__ == "__main__":
74
+ demo.launch(show_error=True, server_name="0.0.0.0", server_port=7860)