File size: 1,447 Bytes
64e7562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from datetime import date
import jittor as jt
from numpy.core.fromnumeric import shape
from numpy.lib.type_check import imag
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import numpy as np
import cv2
import os


def ImageClassification(img_path, model, show=True):
    """Classify a single digit image with ``model`` and report the result.

    The image is read with OpenCV, resized to 28x28, scaled to [0, 1],
    converted from HWC to CHW layout and given a leading batch axis before
    being passed to ``model``. The predicted class index (argmax over axis 1
    of the model output) is printed and returned.

    Args:
        img_path: Path of the image file to classify.
        model: Callable (e.g. a Jittor module) that accepts a float32
            ``[1, C, 28, 28]`` Var. Its output is assumed to be per-class
            scores shaped ``[1, num_classes]`` (implied by the argmax over
            axis 1) — TODO confirm against ``Model``.
        show: If True (default, matching the original behaviour), display
            the original image in an OpenCV window and wait for a key press.

    Returns:
        The predicted class index as a Python int.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface as an obscure error inside cv2.resize.
    image_start = cv2.imread(img_path)
    if image_start is None:
        raise FileNotFoundError(f'Cannot read image: {img_path}')
    # imread with default flags gives an HxWx3 array whose H and W depend on
    # the input file; shrink it to the 28x28 input size -> (28, 28, 3).
    image = cv2.resize(image_start, (28, 28))
    image = image / 255.0                 # scale pixel values from [0, 255] to [0, 1]
    image = image.transpose(2, 0, 1)      # HWC -> CHW
    image = jt.float32(image)             # convert to a float32 Jittor Var
    image_end = image.unsqueeze(dim=0)    # add batch axis -> [1, C, H, W]
    outputs = model(image_end)
    prediction = int(np.argmax(outputs.data, axis=1)[0])
    # Report the result before blocking on the display window.
    print('图片识别结果:' + str(prediction))
    if show:
        cv2.imshow('MNISt', image_start)
        cv2.waitKey(0)
        # Close the window once the user has dismissed it.
        cv2.destroyAllWindows()
    return prediction


def main(img_path='0.jpg'):
    """Load the trained MNIST model and classify one local image.

    Weights are loaded from ``model/mnist_model.pkl`` in this script's
    directory.

    Args:
        img_path: Image to classify. Defaults to ``'0.jpg'``. A relative path
            is looked up in the current working directory first and, if it
            does not exist there, in this script's directory (where the model
            file is also looked up), so the demo works from any CWD.
    """
    pwd_path = os.path.abspath(os.path.dirname(__file__))
    save_model_path = os.path.join(pwd_path, 'model', 'mnist_model.pkl')
    # Build the model, load the trained weights, and switch to inference mode
    # so any train-only behaviour in Model (e.g. dropout) is disabled.
    model = Model()
    model.load_parameters(jt.load(save_model_path))
    model.eval()
    # Keep the original CWD-relative lookup, but fall back to the script
    # directory, consistent with how the model path is resolved above.
    if not os.path.isabs(img_path) and not os.path.exists(img_path):
        candidate = os.path.join(pwd_path, img_path)
        if os.path.exists(candidate):
            img_path = candidate
    ImageClassification(img_path, model)


if __name__ == "__main__":
    # Run the single-image inference demo only when executed as a script,
    # not when this module is imported.
    main()