John Ho committed on
Commit
7afaf9e
·
1 Parent(s): d81f6c9

trying to debug issue with F.scaled_dot_product_attention

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -5,13 +5,11 @@ from tqdm import tqdm
5
  from samv2_handler import load_sam_image_model, run_sam_im_inference
6
  from PIL import Image
7
  from typing import Union
8
- import subprocess
9
 
10
- subprocess.run(
11
- "pip install flash-attn --no-build-isolation",
12
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
13
- shell=True,
14
- )
15
 
16
 
17
  def download_checkpoints():
@@ -52,6 +50,8 @@ def load_im_model(variant, auto_mask_gen: bool = False):
52
 
53
 
54
  @spaces.GPU
 
 
55
  def detect_image(
56
  im: Image.Image,
57
  variant: str,
 
5
  from samv2_handler import load_sam_image_model, run_sam_im_inference
6
  from PIL import Image
7
  from typing import Union
 
8
 
9
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
10
+ if torch.cuda.get_device_properties(0).major >= 8:
11
+ torch.backends.cuda.matmul.allow_tf32 = True
12
+ torch.backends.cudnn.allow_tf32 = True
 
13
 
14
 
15
  def download_checkpoints():
 
50
 
51
 
52
  @spaces.GPU
53
+ @torch.inference_mode()
54
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
55
  def detect_image(
56
  im: Image.Image,
57
  variant: str,