salma-remyx committed on
Commit
5b7dfe2
·
1 Parent(s): 5614004

remove global

Browse files
Files changed (1) hide show
  1. app.py +25 -19
app.py CHANGED
@@ -25,6 +25,27 @@ except OSError:
25
  download("en_core_web_sm")
26
  nlp = spacy.load("en_core_web_sm")
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def find_subject(doc):
29
  for token in doc:
30
  # Check if the token is a subject
@@ -52,30 +73,20 @@ def caption_refiner(caption):
52
 
53
  @spaces.GPU
54
  def sam2(image, input_boxes, model_id="facebook/sam-vit-base"):
55
- device = "cuda" if torch.cuda.is_available() else "cpu"
56
- model = SamModel.from_pretrained(model_id).to(device)
57
- processor = SamProcessor.from_pretrained(model_id)
58
- inputs = processor(image, input_boxes=[[input_boxes]], return_tensors="pt").to(device)
59
  with torch.no_grad():
60
- outputs = model(**inputs)
61
 
62
- masks = processor.image_processor.post_process_masks(
63
  outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
64
  )
65
  return masks
66
 
67
- @spaces.GPU
68
- def load_florence2(model_id="microsoft/Florence-2-base-ft", device='cuda'):
69
- torch_dtype = torch.float16 if device == 'cuda' else torch.float32
70
- florence_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, trust_remote_code=True).to(device)
71
- florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
72
- return florence_model, florence_processor
73
 
74
  @spaces.GPU
75
  def florence2(image, prompt="", task="<OD>"):
76
- device = florence_model.device
77
  torch_dtype = florence_model.dtype
78
- inputs = florence_processor(text=task + prompt, images=image, return_tensors="pt").to(device, torch_dtype)
79
  generated_ids = florence_model.generate(
80
  input_ids=inputs["input_ids"],
81
  pixel_values=inputs["pixel_values"],
@@ -467,10 +478,5 @@ def build_demo():
467
  return demo
468
 
469
  if __name__ == "__main__":
470
- global model, transform, florence_model, florence_processor
471
- model, transform = depth_pro.create_model_and_transforms(device='cuda')
472
- florence_model, florence_processor = load_florence2(device='cuda')
473
-
474
-
475
  demo = build_demo()
476
  demo.launch(share=True)
 
25
  download("en_core_web_sm")
26
  nlp = spacy.load("en_core_web_sm")
27
 
28
# --- One-time model initialisation, shared by all request handlers below ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_florence2(model_id="microsoft/Florence-2-base-ft", device=DEVICE):
    """Load the Florence-2 model and its processor onto `device`.

    Uses float16 on CUDA and float32 on CPU. Returns (model, processor).
    """
    dtype = torch.float16 if device == 'cuda' else torch.float32
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=dtype, trust_remote_code=True
    ).to(device)
    return model, processor


# Loaded once at import time so every call reuses the same weights.
florence_model, florence_processor = load_florence2()


def load_sam_model(model_id="facebook/sam-vit-base", device=DEVICE):
    """Load the SAM segmentation model and its processor onto `device`.

    Returns (model, processor).
    """
    processor = SamProcessor.from_pretrained(model_id)
    model = SamModel.from_pretrained(model_id).to(device)
    return model, processor


# Loaded once at import time so every call reuses the same weights.
sam_model, sam_processor = load_sam_model()

# Depth-estimation model plus its preprocessing transform.
model, transform = depth_pro.create_model_and_transforms(device=DEVICE)
49
  def find_subject(doc):
50
  for token in doc:
51
  # Check if the token is a subject
 
73
 
74
@spaces.GPU
def sam2(image, input_boxes, model_id="facebook/sam-vit-base"):
    """Run SAM segmentation on `image` for the given bounding boxes.

    Uses the module-level `sam_model`/`sam_processor` loaded at import time;
    `model_id` is kept only for backward compatibility with existing callers
    and is not read here. Returns the post-processed masks.
    """
    batched_boxes = [[input_boxes]]  # processor expects a nested batch of boxes
    inputs = sam_processor(
        image, input_boxes=batched_boxes, return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():  # inference only — no gradients needed
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    return masks
84
 
 
 
 
 
 
 
85
 
86
  @spaces.GPU
87
  def florence2(image, prompt="", task="<OD>"):
 
88
  torch_dtype = florence_model.dtype
89
+ inputs = florence_processor(text=task + prompt, images=image, return_tensors="pt").to(DEVICE, torch_dtype)
90
  generated_ids = florence_model.generate(
91
  input_ids=inputs["input_ids"],
92
  pixel_values=inputs["pixel_values"],
 
478
  return demo
479
 
480
if __name__ == "__main__":
    # Build the Gradio UI and launch it with a public share link.
    demo = build_demo()
    demo.launch(share=True)