# pose_experiment / app.py
import os
import time

import gradio as gr
import skimage.transform
import torch

from models.hr_net import hr_w32
from tool_utils import heatmaps_to_coords, draw_joints

# Run on GPU when available, otherwise fall back to CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Create the example list from the 'examples/' directory
example_list = [["./examples/" + example] for example in os.listdir("examples")]
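# example_list is a list of single-element lists, e.g. [["./examples/img1.jpg"], ...]
# (filenames here are illustrative); this is the format gr.Interface expects for
# its `examples` argument, one inner list per input component.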
# Build the HRNet-W32 model and load its trained weights once at startup,
# rather than reloading the checkpoint on every prediction call
model = hr_w32().to(device)
model.load_state_dict(torch.load('./weights/HRNet_epoch20_loss0.000474.pth', map_location=device)['model'])
# Set the model to evaluation mode
model.eval()


def predict(numpy_img):
    # Resize the input image to (256, 256); skimage.transform.resize also
    # rescales pixel values to the [0, 1] range
    img_np = skimage.transform.resize(numpy_img, [256, 256])
    # Convert the (H, W, C) numpy image to a (1, C, H, W) float tensor
    img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)
    # Predict the joint heatmaps
    start_time = time.time()
    with torch.no_grad():
        heatmaps_pred = model(img)
    heatmaps_pred = heatmaps_pred.double()
    # Convert the output to a (H, W, num_joints) numpy array
    heatmaps_pred_np = heatmaps_pred.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    # Decode the heatmaps into joint coordinates in the 256x256 image
    coord_joints = heatmaps_to_coords(heatmaps_pred_np, resolu_out=[256, 256], prob_threshold=0.1)
    inference_time = time.time() - start_time
    inference_time_text = "model inference time: {:.4f}s".format(inference_time)
    # Draw the predicted joints on the resized image
    img_rgb = draw_joints(img_np, coord_joints)
    return img_rgb, inference_time_text
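
# For a quick local check outside the Gradio UI, predict() can be called directly.
# This is a minimal sketch; the filename below is hypothetical and assumes at
# least one image exists in the examples/ directory:
#
#     import skimage.io
#     out_img, timing = predict(skimage.io.imread("./examples/sample.jpg"))
#     print(timing)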
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.Image(type='numpy', width=256, height=256), "text"],
    examples=example_list,
)
if __name__ == "__main__":
    demo.launch(show_api=False)