Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -493,13 +493,15 @@ from torchvision import transforms, utils
|
|
| 493 |
from matplotlib import pyplot as plt
|
| 494 |
import numpy as np
|
| 495 |
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
def predict_pose(test_image):
|
| 498 |
img = cv2.resize(test_image, (32,32))
|
| 499 |
convert_tensor = transforms.ToTensor()
|
| 500 |
tensor_img = convert_tensor(img)
|
| 501 |
tensor_img = tensor_img[None,:,:,:]
|
| 502 |
-
model.eval()
|
| 503 |
|
| 504 |
outputs = model(tensor_img)
|
| 505 |
|
|
|
|
| 493 |
from matplotlib import pyplot as plt
|
| 494 |
import numpy as np
|
| 495 |
|
| 496 |
+
model = SimpleCNN()
|
| 497 |
+
model.load_state_dict(torch.load("model.pth"))
|
| 498 |
+
model.eval()
|
| 499 |
|
| 500 |
def predict_pose(test_image):
|
| 501 |
img = cv2.resize(test_image, (32,32))
|
| 502 |
convert_tensor = transforms.ToTensor()
|
| 503 |
tensor_img = convert_tensor(img)
|
| 504 |
tensor_img = tensor_img[None,:,:,:]
|
|
|
|
| 505 |
|
| 506 |
outputs = model(tensor_img)
|
| 507 |
|