File size: 1,751 Bytes
c5890a7
93a6fff
 
 
 
 
 
 
 
 
c5890a7
93a6fff
 
 
c5890a7
93a6fff
 
 
 
 
 
 
 
44207e0
93a6fff
 
 
 
 
 
 
 
 
383ca9f
93a6fff
 
 
383ca9f
93a6fff
 
 
 
383ca9f
93a6fff
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import functools
import os
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy
import skimage
import torch
from PIL import Image

from models.hr_net import hr_w32
from tool_utils import heatmaps_to_coords,draw_joints

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Build the Gradio example list from the 'examples/' directory.
# Sorted so the gallery order is stable (os.listdir order is arbitrary);
# hidden entries such as '.DS_Store' and subdirectories are skipped because
# they are not images the demo can run on.
example_list = [
    ["./examples/" + example]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
    and os.path.isfile(os.path.join("examples", example))
]

# Checkpoint whose 'model' entry holds the hr_w32 state dict.
WEIGHTS_PATH = './weights/HRNet_epoch20_loss0.000474.pth'


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build hr_w32 on `device`, load WEIGHTS_PATH into it, set eval mode (once per process)."""
    model = hr_w32().to(device)
    # Map to CPU first so the checkpoint loads on CPU-only hosts too;
    # load_state_dict then copies the tensors into the already-placed model.
    checkpoint = torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model


def _as_rgb(numpy_img):
    """Return numpy_img as H x W x 3: grayscale is replicated, an alpha channel is dropped."""
    if numpy_img.ndim == 2:
        return numpy.stack([numpy_img] * 3, axis=-1)
    if numpy_img.ndim == 3 and numpy_img.shape[2] == 1:
        return numpy.repeat(numpy_img, 3, axis=2)
    if numpy_img.ndim == 3 and numpy_img.shape[2] == 4:
        return numpy_img[:, :, :3]
    return numpy_img


def predict(numpy_img):
    """Predict body-joint locations in an image and draw them on it.

    Args:
        numpy_img: image array from the Gradio ``gr.Image`` input, or None
            when no image was supplied. Grayscale and RGBA images are
            converted to 3 channels first (the pipeline feeds the network an
            H x W x 3 image; assumed RGB — TODO confirm against training).

    Returns:
        Tuple ``(annotated_image, inference_time_text)`` where the image is
        whatever ``draw_joints`` returns and the text reports the time spent
        in the forward pass plus heatmap decoding. When ``numpy_img`` is None
        the image is None and the text says so.
    """
    if numpy_img is None:
        return None, "no image provided"
    numpy_img = _as_rgb(numpy_img)

    # The network is fed a fixed 256x256 input.
    img_np = skimage.transform.resize(numpy_img, [256, 256])
    # HWC array -> 1xCxHxW float tensor on the inference device.
    img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)

    # Cached: the model is built and the checkpoint read only on the first call.
    model = _load_model()

    start_time = time.time()
    # Pure inference: no_grad skips building the autograd graph.
    with torch.no_grad():
        heatmaps_pred = model(img).double()
    # 1xJxHxW -> HxWxJ numpy array for the post-processing helper.
    heatmaps_pred_np = heatmaps_pred.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Decode joint coordinates at the original image resolution.
    coord_joints = heatmaps_to_coords(
        heatmaps_pred_np,
        resolu_out=[numpy_img.shape[0], numpy_img.shape[1]],
        prob_threshold=0.1,
    )
    inference_time = time.time() - start_time
    inference_time_text = "model inference time:{:.4f}s".format(inference_time)

    img_rgb = draw_joints(numpy_img, coord_joints)
    return img_rgb, inference_time_text
    
    

# Gradio UI: one image in; the annotated image and a timing string out.
# The entries of example_list are offered as clickable sample inputs.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.Image(type='numpy'), "text"],
    examples=example_list,
)

# Start the server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch(show_api=False)