zino36 commited on
Commit
2b8834a
·
verified ·
1 Parent(s): 7a6b748

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -32
app.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  import tempfile
12
  from gradio_imageslider import ImageSlider
13
  from huggingface_hub import hf_hub_download
14
- import safetensors
15
 
16
  from depth_anything_v2.dpt import DepthAnythingV2
17
 
@@ -47,39 +47,9 @@ model_name = encoder2name[encoder]
47
  model = DepthAnythingV2(**model_configs[encoder])
48
  filepath = hf_hub_download(repo_id="depth-anything/Depth-Anything-V2-Metric-Indoor-Large-hf", filename="model.safetensors", repo_type="model")
49
 
50
- def create_tensor(storage, info, offset):
51
- DTYPES = {"F32": torch.float32}
52
- dtype = DTYPES[info["dtype"]]
53
- shape = info["shape"]
54
- start, stop = info["data_offsets"]
55
- return torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8).view(dtype=dtype).reshape(shape)
56
-
57
- def load_file(filename):
58
- with open(filename, mode="r", encoding="utf8") as file_obj:
59
- with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
60
- header = m.read(8)
61
- n = int.from_bytes(header, "little")
62
- metadata_bytes = m.read(n)
63
- metadata = json.loads(metadata_bytes)
64
-
65
- size = os.stat(filename).st_size
66
- storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
67
- offset = n + 8
68
- return {name: create_tensor(storage, info, offset) for name, info in metadata.items() if name != "__metadata__"}
69
-
70
- tensor_data = safetensors.load(filepath)
71
-
72
  # Convert to PyTorch tensor
73
- if isinstance(tensor_data, np.ndarray):
74
- pytorch_tensor = torch.tensor(tensor_data)
75
- elif isinstance(tensor_data, safetensors.Tensor):
76
- pytorch_tensor = torch.tensor(tensor_data.numpy()) # Assuming safetensors Tensor has a .numpy() method
77
- else:
78
- raise TypeError("Unsupported data type from safetensors")
79
-
80
  #state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
81
- #state_dict = load_file(filepath)
82
- state_dict = pytorch_tensor
83
 
84
  model.load_state_dict(state_dict)
85
  model = model.to(DEVICE).eval()
 
11
  import tempfile
12
  from gradio_imageslider import ImageSlider
13
  from huggingface_hub import hf_hub_download
14
+ from safetensors.torch import load_file
15
 
16
  from depth_anything_v2.dpt import DepthAnythingV2
17
 
 
47
  model = DepthAnythingV2(**model_configs[encoder])
48
  filepath = hf_hub_download(repo_id="depth-anything/Depth-Anything-V2-Metric-Indoor-Large-hf", filename="model.safetensors", repo_type="model")
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Convert to PyTorch tensor
 
 
 
 
 
 
 
51
  #state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
52
+ state_dict = load_file(filepath)
 
53
 
54
  model.load_state_dict(state_dict)
55
  model = model.to(DEVICE).eval()