|
import gradio as gr |
|
import torch |
|
from carvekit.api.interface import Interface |
|
from carvekit.ml.wrap.basnet import BASNET |
|
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 |
|
from carvekit.ml.wrap.fba_matting import FBAMatting |
|
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 |
|
from carvekit.ml.wrap.u2net import U2NET |
|
from carvekit.pipelines.postprocessing import MattingMethod |
|
from carvekit.pipelines.preprocessing import PreprocessingStub |
|
from carvekit.trimap.generator import TrimapGenerator |
|
|
|
# Select the torch device once at import time: use CUDA when torch reports
# it as available, otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
# Registry of segmentation models, keyed by the name shown in the UI
# dropdowns. Each model is constructed once, at import time, on `device`
# with a batch size of 1; the handlers look models up by key.
segment_net = {
    name: model_cls(device=device, batch_size=1)
    for name, model_cls in (
        ("U2NET", U2NET),
        ("BASNET", BASNET),
        ("DeepLabV3", DeepLabV3),
        ("TracerUniversalB7", TracerUniversalB7),
    )
}
|
|
|
# Matting network handed to the post-processing stage below.
fba = FBAMatting(
    batch_size=1,
    input_tensor_size=2048,
    device=device,
)

# Trimap generator with library-default settings; also called directly by
# the "Trimap generator" tab handler.
trimap = TrimapGenerator()

# Pre-processing stage passed to Interface as `pre_pipe`.
preprocessing = PreprocessingStub()

# Post-processing stage passed to Interface as `post_pipe`; it combines the
# matting network and the trimap generator defined above.
postprocessing = MattingMethod(
    device=device,
    matting_module=fba,
    trimap_generator=trimap,
)
|
|
|
# Model names offered in the dropdowns: the keys of `segment_net`, in
# insertion order. Iterating a dict yields its keys, so there is no need to
# unpack (and discard) the values via `.items()`.
method_choices = list(segment_net)
|
|
|
|
|
def generate_trimap(method, original):
    """Build a trimap for ``original`` using the selected segmentation model.

    Args:
        method: Key into ``segment_net`` naming the segmentation model.
        original: Input image. The Gradio component wired to this handler is
            declared with ``type="pil"``; it is ``None`` when no image has
            been provided.

    Returns:
        Whatever the module-level ``trimap`` generator returns for the image
        and the first mask produced by the model.

    Raises:
        gr.Error: If no image was supplied or ``method`` is not a key of
            ``segment_net``. Gradio shows the message to the user.
    """
    # Fail early with a readable message instead of letting the model choke
    # on a missing image deep inside the library.
    if original is None:
        raise gr.Error("Please provide an input image first.")

    segmentor = segment_net.get(method)
    if segmentor is None:
        raise gr.Error(f"Unknown segmentor model: {method!r}")

    # The model is called with a one-element batch; only the first result
    # is used as the mask for the trimap.
    masks = segmentor([original])
    return trimap(original_image=original, mask=masks[0])
|
|
|
|
|
def predict(method, image):
    """Remove the background from ``image`` with the selected model.

    Args:
        method: Key into ``segment_net`` naming the segmentation model.
        image: Input image. The Gradio component wired to this handler is
            declared with ``type="pil"``; it is ``None`` when no image has
            been provided.

    Returns:
        The first element of the result produced by the CarveKit
        ``Interface`` for a one-image batch.

    Raises:
        gr.Error: If no image was supplied or ``method`` is not a key of
            ``segment_net``. Gradio shows the message to the user.
    """
    # Fail early with a readable message instead of an opaque library error.
    if image is None:
        raise gr.Error("Please provide an input image first.")

    # Keep the looked-up model in its own name rather than rebinding the
    # `method` parameter from a string key to a model object.
    segmentor = segment_net.get(method)
    if segmentor is None:
        raise gr.Error(f"Unknown segmentor model: {method!r}")

    pipeline = Interface(
        pre_pipe=preprocessing,
        post_pipe=postprocessing,
        seg_pipe=segmentor,
    )
    # The pipeline takes a list of images; a single image goes in and the
    # single result is unwrapped.
    return pipeline([image])[0]
|
|
|
|
|
# HTML snippet rendered at the bottom of the page: project logo plus a link
# to the upstream CarveKit repository. `<br>` is a void element, so the line
# break is written as `<br>` rather than the invalid closing tag `</br>`.
footer = r"""
<center>
<img src='https://raw.githubusercontent.com/leonelhs/image-background-remove-tool/master/docs/imgs/logo.png' alt='CarveKit' width="200" height="80">
<br>
<b>
Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
</b>
</center>
"""
|
|
|
# Page layout: a title, two tabs (background removal and trimap generation)
# and a footer. Each tab has an input column (image, model dropdown, button)
# and an output column, and its button is wired to the matching handler.
with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")

    with gr.Tabs() as tabs:
        with gr.TabItem("Remove background", id=0):
            with gr.Row(equal_height=False):
                with gr.Column():
                    input_img = gr.Image(type="pil", label="Input image")
                    drp_itf = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices,
                    )
                    run_btn = gr.Button(variant="primary")
                with gr.Column():
                    output_img = gr.Image(type="pil", label="result")

            # Inputs are passed positionally: predict(method, image).
            run_btn.click(predict, [drp_itf, input_img], [output_img])

        with gr.TabItem("Trimap generator", id=1):
            with gr.Row(equal_height=False):
                with gr.Column():
                    trimap_input = gr.Image(type="pil", label="Input image")
                    # Named separately from the first tab's dropdown so the
                    # two widgets do not share (and rebind) one variable.
                    trimap_drp_itf = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices,
                    )
                    trimap_btn = gr.Button(variant="primary")
                with gr.Column():
                    trimap_output = gr.Image(type="pil", label="result")

            # Inputs are passed positionally: generate_trimap(method, original).
            trimap_btn.click(
                generate_trimap,
                [trimap_drp_itf, trimap_input],
                [trimap_output],
            )

    with gr.Row():
        gr.HTML(footer)
|
|
|
# Enable request queuing, then start the server: no public share link,
# debug mode on, and handler errors surfaced in the UI.
app.queue().launch(
    show_error=True,
    debug=True,
    share=False,
)
|
|