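# Gradio demo for StyleMatte: predicts an alpha matte for the input image
# and returns it together with an RGBA foreground cutout.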
import gradio as gr
import numpy as np
import torch
from PIL import Image

from models import *
from test import inference_img
# Load the pretrained StyleMatte weights on CPU.
device = 'cpu'
model = StyleMatte().to(device)
checkpoint = "stylematte.pth"
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)
model.eval()
def predict(inp):
    """Run matting on an input image and return the mask plus an RGBA cutout."""
    print("***********Inference****************")
    mask = inference_img(model, inp)  # predicted alpha matte, expected in [0, 1] with shape (H, W)
    inp_np = np.array(inp)
    fg = np.uint8(mask[:, :, None] * inp_np)        # foreground RGB weighted by the matte
    alpha_channel = (mask * 255).astype(np.uint8)   # matte scaled to an 8-bit alpha channel
    print(fg.max(), alpha_channel.max(), fg.shape, alpha_channel.shape)
    print("***********Inference finish****************")
    fg = np.dstack((fg, alpha_channel))             # stack alpha onto RGB -> RGBA array
    fg_pil = Image.fromarray(fg, 'RGBA')
    return [mask, fg_pil]
print("MODEL LOADED")
print("************************************")
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="numpy"), gr.Image(type="pil", image_mode='RGBA')],
    examples=["./logo.jpeg"],
)
print("****************Interface created******************")
iface.launch()