Fediory commited on
Commit
d86981d
·
1 Parent(s): df8ee0b

fix: different safetensors

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -5,11 +5,27 @@ from PIL import Image
5
  from net.CIDNet import CIDNet
6
  import torchvision.transforms as transforms
7
  import torch.nn.functional as F
8
- import os
9
  import imquality.brisque as brisque
10
  from loss.niqe_utils import *
11
  import spaces
12
- import huggingface_hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  eval_net = CIDNet().cuda()
15
  eval_net.trans.gated = True
@@ -20,7 +36,7 @@ def process_image(input_img,score,model_path,gamma=1.0,alpha_s=1.0,alpha_i=1.0):
20
  if model_path is None:
21
  return input_img,"Please choose a model weights."
22
  torch.set_grad_enabled(False)
23
- eval_net.from_pretrained("Fediory/HVI-CIDNet-"+model_path)
24
  # eval_net.load_state_dict(torch.load(os.path.join(directory,model_path), map_location=lambda storage, loc: storage))
25
  eval_net.eval()
26
 
@@ -81,4 +97,4 @@ interface = gr.Interface(
81
  allow_flagging="never"
82
  )
83
 
84
- interface.launch(share=True)
 
5
  from net.CIDNet import CIDNet
6
  import torchvision.transforms as transforms
7
  import torch.nn.functional as F
8
+ import safetensors.torch as sf
9
  import imquality.brisque as brisque
10
  from loss.niqe_utils import *
11
  import spaces
12
+ from huggingface_hub import hf_hub_download
13
+ import json
14
+
15
+ def from_pretrained(cls, pretrained_model_name_or_path: str):
16
+ model_id = str(pretrained_model_name_or_path)
17
+
18
+ config_file = hf_hub_download(repo_id=model_id, filename="config.json", repo_type="model")
19
+ config = None
20
+ if config_file is not None:
21
+ with open(config_file, "r", encoding="utf-8") as f:
22
+ config = json.load(f)
23
+
24
+
25
+ model_file = hf_hub_download(repo_id=model_id, filename="model.safetensors", repo_type="model")
26
+ # instance = sf.load_model(cls, model_file, strict=False)
27
+ state_dict = sf.load_file(model_file)
28
+ cls.load_state_dict(state_dict, strict=False)
29
 
30
  eval_net = CIDNet().cuda()
31
  eval_net.trans.gated = True
 
36
  if model_path is None:
37
  return input_img,"Please choose a model weights."
38
  torch.set_grad_enabled(False)
39
+ from_pretrained(eval_net,"Fediory/HVI-CIDNet-"+model_path)
40
  # eval_net.load_state_dict(torch.load(os.path.join(directory,model_path), map_location=lambda storage, loc: storage))
41
  eval_net.eval()
42
 
 
97
  allow_flagging="never"
98
  )
99
 
100
+ interface.launch(share=False)