merve HF Staff qubvel-hf HF Staff commited on
Commit
711aaa4
·
verified ·
1 Parent(s): e37bc45

Slice register tokens (#2)

Browse files

- Slice register tokens (221f4602c71290ae688ef1d7856618da24ebbf6e)


Co-authored-by: Pavel Iakubovskii <[email protected]>

Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -35,7 +35,8 @@ def extract_features(image, model_name):
35
  with torch.no_grad():
36
  outputs = model(**inputs)
37
  features = outputs.last_hidden_state
38
- return features[:, 1:, :].float().cpu(), original_size, model_size
 
39
 
40
  def find_correspondences(features1, features2, threshold=0.8):
41
  device = torch.device("cpu")
 
35
  with torch.no_grad():
36
  outputs = model(**inputs)
37
  features = outputs.last_hidden_state
38
+ num_register_tokens = getattr(model.config, "num_register_tokens", 0)
39
+ return features[:, 1 + num_register_tokens:, :].float().cpu(), original_size, model_size
40
 
41
  def find_correspondences(features1, features2, threshold=0.8):
42
  device = torch.device("cpu")