lxy1122 commited on
Commit
3812f5d
·
1 Parent(s): 18efee8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -1,16 +1,21 @@
1
  import torch
2
  from PIL import Image
3
- from torchvision import transforms
4
  import gradio as gr
5
  import os
6
 
7
 
8
  os.system("wget https://github.com/liuxiaoyuyuyu/vanGogh-and-Other-Artist/blob/main/artist_classes.txt")
 
9
 
10
- model = torch.hub.load('pytorch/vision:v0.9.0', 'mobilenet_v2', pretrained=True)
11
  #checkpoint = 'https://github.com/liuxiaoyuyuyu/vanGogh-and-Other-Artist/blob/main/model_weights_mobilenet_v2_valp1trainp2.pth'
12
  #model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
13
- model.eval()
 
 
 
 
14
 
15
  #torch.hub.download_url_to_file("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
16
 
 
1
  import torch
2
  from PIL import Image
3
+ from torchvision import datasets, models, transforms
4
  import gradio as gr
5
  import os
6
 
7
 
8
  os.system("wget https://github.com/liuxiaoyuyuyu/vanGogh-and-Other-Artist/blob/main/artist_classes.txt")
9
+ os.system("wget https://github.com/liuxiaoyuyuyu/vanGogh-and-Other-Artist/blob/main/model_weights_mobilenet_v2_valp1trainp2.pth")
10
 
11
+ #model = torch.hub.load('pytorch/vision:v0.9.0', 'mobilenet_v2', pretrained=False)
12
  #checkpoint = 'https://github.com/liuxiaoyuyuyu/vanGogh-and-Other-Artist/blob/main/model_weights_mobilenet_v2_valp1trainp2.pth'
13
  #model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
14
+ model = models.mobilenet_v2()
15
+ num_ftrs = model.classifier[1].in_features
16
+ model.classifier[1] = nn.Linear(num_ftrs, 6)
17
+ model = model.to(device)
18
+ model.load_state_dict(torch.load('model_weights_mobilenet_v2_valp1trainp2.pth'))
19
 
20
  #torch.hub.download_url_to_file("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
21