attilaultzindur commited on
Commit
d373c9b
·
1 Parent(s): 2a01b44

initial model push

Browse files
Files changed (2) hide show
  1. README.md +0 -10
  2. handler.py +3 -1
README.md CHANGED
@@ -1,13 +1,3 @@
1
- ---
2
- library_name: torchvision # hangi framework’ü kullandığı
3
- pipeline_tag: image-classification # görev türü (widget için şart)
4
- tags:
5
- - image-classification
6
- - efficientnet
7
- - garbage
8
- metrics:
9
- - accuracy
10
- ---
11
  # Garbage Classifier · EfficientNet‑V2‑S (torchvision)
12
 
13
  Finetuned model for 10‑class garbage image classification.
 
 
 
 
 
 
 
 
 
 
 
1
  # Garbage Classifier · EfficientNet‑V2‑S (torchvision)
2
 
3
  Finetuned model for 10‑class garbage image classification.
handler.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import Dict, Any
2
  import io, base64, torch, torchvision
 
3
  from PIL import Image
4
  from torchvision import transforms as T
5
 
@@ -14,7 +15,8 @@ class EndpointHandler:
14
  torch.nn.Dropout(0.5),
15
  torch.nn.Linear(256, len(self.labels))
16
  )
17
- self.model.load_state_dict(torch.load(f"{path}/model.safetensors", map_location="cpu"))
 
18
  self.model.eval()
19
  self.trans = T.Compose([
20
  T.Resize((224, 224)),
 
1
  from typing import Dict, Any
2
  import io, base64, torch, torchvision
3
+ from safetensors.torch import load_file
4
  from PIL import Image
5
  from torchvision import transforms as T
6
 
 
15
  torch.nn.Dropout(0.5),
16
  torch.nn.Linear(256, len(self.labels))
17
  )
18
+ state = load_file(str(pth), device="cpu")
19
+ self.model.load_state_dict(state)
20
  self.model.eval()
21
  self.trans = T.Compose([
22
  T.Resize((224, 224)),