heuue commited on
Commit
62ed246
·
verified ·
1 Parent(s): fe4812a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -4
app.py CHANGED
@@ -1,7 +1,51 @@
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+
2
  import gradio as gr
3
+ import torch
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import torch.nn.functional as F
7
+
8
+ # 加载模型
9
+ class SimpleModel(torch.nn.Module):
10
+ def __init__(self, num_classes=3):
11
+ super(SimpleModel, self).__init__()
12
+ self.fc = torch.nn.Linear(3 * 224 * 224, num_classes)
13
+
14
+ def forward(self, x):
15
+ x = x.view(x.size(0), -1)
16
+ return self.fc(x)
17
+
18
+ # 初始化模型并加载权重
19
+ model = SimpleModel(num_classes=3)
20
+ model_path = "model.pth"
21
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
22
+ model.eval()
23
+
24
+
25
+ def preprocess(image: Image.Image):
26
+ transform = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
30
+ ])
31
+ return transform(image).unsqueeze(0)
32
+
33
+
34
+ def classify_image(image: Image.Image):
35
+ input_tensor = preprocess(image)
36
+ with torch.no_grad():
37
+ output = model(input_tensor)
38
+ probabilities = F.softmax(output[0], dim=0)
39
+ top_class = probabilities.argmax().item()
40
+ return f"Predicted: {class_names[top_class]} (Confidence: {probabilities[top_class]:.2f})"
41
+
42
 
43
+ app = gr.Interface(
44
+ fn=classify_image, # 推理函数
45
+ inputs=gr.Image(type="pil"), # 接收输入图片,返回 PIL 格式
46
+ outputs="text", # 输出分类结果
47
+ title="Image Classification with PyTorch",
48
+ description="Upload an image to classify it using a pre-trained PyTorch model."
49
+ )
50
 
51
+ app.launch()