0notexist0 commited on
Commit
1dfa7f1
·
verified ·
1 Parent(s): c437df7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -69
app.py CHANGED
@@ -1,74 +1,18 @@
1
- import gradio as gr
2
- from loadimg import load_img
3
- from transformers import AutoModelForImageSegmentation
4
- import torch
5
- from torchvision import transforms
6
- from typing import Union, Tuple
7
  from PIL import Image
 
8
 
9
- torch.set_float32_matmul_precision("high")
10
-
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
-
13
- birefnet = AutoModelForImageSegmentation.from_pretrained(
14
- "ZhengPeng7/BiRefNet", trust_remote_code=True
15
- )
16
- birefnet.to(device)
17
-
18
- transform_image = transforms.Compose([
19
- transforms.Resize((1024, 1024)),
20
- transforms.ToTensor(),
21
- transforms.Normalize([0.485, 0.456, 0.406],
22
- [0.229, 0.224, 0.225]),
23
- ])
24
-
25
- def process(image: Image.Image) -> Image.Image:
26
- image_size = image.size
27
- input_images = transform_image(image).unsqueeze(0).to(device)
28
- with torch.no_grad():
29
- preds = birefnet(input_images)[-1].sigmoid().cpu()
30
- pred = preds[0].squeeze()
31
- pred_pil = transforms.ToPILImage()(pred)
32
- mask = pred_pil.resize(image_size)
33
- image.putalpha(mask)
34
- return image
35
-
36
- def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
37
- im = load_img(image, output_type="pil").convert("RGB")
38
- origin = im.copy()
39
- processed_image = process(im)
40
- return processed_image, origin
41
-
42
- def process_file(f: str) -> str:
43
- name_path = f.rsplit(".", 1)[0] + ".png"
44
- im = load_img(f, output_type="pil").convert("RGB")
45
- transparent = process(im)
46
- transparent.save(name_path)
47
- return name_path
48
-
49
- slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
50
- slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
51
- image_upload = gr.Image(label="Upload an image")
52
- image_file_upload = gr.Image(label="Upload an image", type="filepath")
53
- url_input = gr.Textbox(label="Paste an image URL")
54
- output_file = gr.File(label="Output PNG File")
55
-
56
- chameleon = load_img("butterfly.jpg", output_type="pil")
57
- url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
58
 
59
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1,
60
- examples=[chameleon], api_name="image")
61
- tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2,
62
- examples=[url_example], api_name="text")
63
- tab3 = gr.Interface(process_file, inputs=image_file_upload,
64
- outputs=output_file, examples=["butterfly.jpg"],
65
- api_name="png")
66
 
67
- demo = gr.TabbedInterface(
68
- [tab1, tab2, tab3],
69
- ["Image Upload", "URL Input", "File Output"],
70
- title="Background Removal Tool"
71
- )
72
 
73
- if __name__ == "__main__":
74
- demo.launch(show_error=True, server_name="0.0.0.0", server_port=7860)
 
1
+ from transformers import AutoProcessor, AutoModelForImageSegmentation
 
 
 
 
 
2
  from PIL import Image
3
+ import torch
4
 
5
+ # Carica modello e processor
6
+ processor = AutoProcessor.from_pretrained("BritishWerewolf/U-2-Netp")
7
+ model = AutoModelForImageSegmentation.from_pretrained("BritishWerewolf/U-2-Netp")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Prepara l’immagine
10
+ img = Image.open("input.jpg").convert("RGB")
11
+ inputs = processor(images=img, return_tensors="pt")
 
 
 
 
12
 
13
+ # Inferenzia maschera
14
+ with torch.no_grad():
15
+ outputs = model(**inputs)
16
+ mask = outputs.logits.argmax(dim=1)[0].cpu().numpy()
 
17
 
18
+ # Applica maschera all’immagine originale...