faelfernandes committed
Commit 58a5713 · verified · 1 Parent(s): dafe74e

Create app.py

Files changed (1)
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple


# Load the RMBG-1.4 model and move it to the GPU if one is available.
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
net.eval()


def resize_image(image):
    # The model expects a 1024x1024 RGB input.
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image


def process(image):
    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_im_size = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    # inference
    with torch.no_grad():
        result = net(im_tensor)

    # post process: resize the predicted mask back to the original image size
    # and rescale it to [0, 1]
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)

    # mask to PIL
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the original image onto a transparent canvas, using the mask as alpha
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    # new_orig_image = orig_image.convert('RGBA')

    return new_im
    # return [new_orig_image, new_im]


# block = gr.Blocks().queue()

# with block:
#     gr.Markdown("## BRIA RMBG 1.4")
#     gr.HTML('''
#         <p style="margin-bottom: 10px; font-size: 94%">
#             This is a demo for BRIA RMBG 1.4, which uses the
#             <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as its backbone.
#         </p>
#     ''')
#     with gr.Row():
#         with gr.Column():
#             input_image = gr.Image(sources=None, type="pil")  # None for upload, ctrl+v and webcam
#             # input_image = gr.Image(sources=None, type="numpy")  # None for upload, ctrl+v and webcam
#             run_button = gr.Button(value="Run")
#
#         with gr.Column():
#             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
#     ips = [input_image]
#     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

# block.launch(debug=True)

# block = gr.Blocks().queue()

# Note: these two components are created outside any Blocks/Interface context,
# so they are not rendered in the app.
gr.Markdown("## BRIA RMBG 1.4")
gr.HTML('''
    <p style="margin-bottom: 10px; font-size: 94%">
        This is a demo for BRIA RMBG 1.4, which uses the
        <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as its backbone.
    </p>
''')
title = "Background Removal"
description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and available as an open-source model for non-commercial use.<br>
To test it, upload your image and wait. Read more on the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
"""
examples = [['./input.jpg'],]
# output = ImageSlider(position=0.5, label='Image without background', type="pil", show_download_button=True)
# demo = gr.Interface(fn=process, inputs="image", outputs=output, examples=examples, title=title, description=description)
demo = gr.Interface(fn=process, inputs="image", outputs="image", examples=examples, title=title, description=description)

if __name__ == "__main__":
    demo.launch(share=False)
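
For reference, a minimal sketch of calling process() directly, outside the Gradio UI; it assumes a local file named ./input.jpg (the same path used in the examples list above) and the output filename is hypothetical.

# Hypothetical usage sketch: run the background-removal function on a local file
# (assumes ./input.jpg exists, as in the examples list) and save the RGBA cutout.
import numpy as np
from PIL import Image

img = np.array(Image.open("./input.jpg").convert("RGB"))  # process() expects a numpy array
cutout = process(img)                                      # returns an RGBA PIL image
cutout.save("output.png")                                  # PNG preserves the alpha channel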