Sanket17 committed on
Commit
9f9625c
·
verified ·
1 Parent(s): eeac153

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +57 -58
main.py CHANGED
@@ -1,77 +1,65 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
4
- from typing import Optional
5
  import base64
6
  import io
7
- from PIL import Image
8
- import torch
9
- import numpy as np
10
  import os
11
-
12
- # Existing imports
13
- import numpy as np
14
- import torch
15
  from PIL import Image
16
- import io
17
-
18
- from utils import (
19
- check_ocr_box,
20
- get_yolo_model,
21
- get_caption_model_processor,
22
- get_som_labeled_img,
23
- )
24
  import torch
25
-
26
- yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
27
- #caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="icon_caption_florence")
28
-
29
  from ultralytics import YOLO
 
30
 
 
31
  if not os.path.exists("weights/icon_detect"):
32
  os.makedirs("weights/icon_detect")
33
 
 
34
  try:
 
35
  yolo_model = YOLO("weights/icon_detect/best.pt").to("cuda")
36
- except:
37
- yolo_model = YOLO("weights/icon_detect/best.pt")
38
-
39
- from transformers import AutoProcessor, AutoModelForCausalLM
40
-
41
- processor = AutoProcessor.from_pretrained(
42
- "microsoft/Florence-2-base", trust_remote_code=True
43
- )
44
 
 
45
  try:
 
46
  model = AutoModelForCausalLM.from_pretrained(
47
  "microsoft/OmniParser",
48
  torch_dtype=torch.float16,
49
- trust_remote_code=True,
50
  ).to("cuda")
51
- except:
 
 
52
  model = AutoModelForCausalLM.from_pretrained(
53
  "microsoft/OmniParser",
54
  torch_dtype=torch.float16,
55
- trust_remote_code=True,
56
  )
 
57
  caption_model_processor = {"processor": processor, "model": model}
58
- print("finish loading model!!!")
59
 
 
60
  app = FastAPI()
61
 
62
-
63
  class ProcessResponse(BaseModel):
64
  image: str # Base64 encoded image
65
  parsed_content_list: str
66
  label_coordinates: str
67
 
68
-
69
  def process(
70
  image_input: Image.Image, box_threshold: float, iou_threshold: float
71
  ) -> ProcessResponse:
72
  image_save_path = "imgs/saved_image_demo.png"
73
  image_input.save(image_save_path)
74
  image = Image.open(image_save_path)
 
 
75
  box_overlay_ratio = image.size[0] / 3200
76
  draw_bbox_config = {
77
  "text_scale": 0.8 * box_overlay_ratio,
@@ -80,30 +68,40 @@ def process(
80
  "thickness": max(int(3 * box_overlay_ratio), 1),
81
  }
82
 
83
- ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
84
- image_save_path,
85
- display_img=False,
86
- output_bb_format="xyxy",
87
- goal_filtering=None,
88
- easyocr_args={"paragraph": False, "text_threshold": 0.9},
89
- use_paddleocr=True,
90
- )
91
- text, ocr_bbox = ocr_bbox_rslt
92
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
93
- image_save_path,
94
- yolo_model,
95
- BOX_TRESHOLD=box_threshold,
96
- output_coord_in_ratio=True,
97
- ocr_bbox=ocr_bbox,
98
- draw_bbox_config=draw_bbox_config,
99
- caption_model_processor=caption_model_processor,
100
- ocr_text=text,
101
- iou_threshold=iou_threshold,
102
- )
 
 
 
 
 
 
 
 
 
 
 
103
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
104
- print("finish processing")
105
  parsed_content_list_str = "\n".join(parsed_content_list)
106
-
107
  # Encode image to base64
108
  buffered = io.BytesIO()
109
  image.save(buffered, format="PNG")
@@ -115,7 +113,7 @@ def process(
115
  label_coordinates=str(label_coordinates),
116
  )
117
 
118
-
119
  @app.post("/process_image", response_model=ProcessResponse)
120
  async def process_image(
121
  image_file: UploadFile = File(...),
@@ -126,7 +124,8 @@ async def process_image(
126
  contents = await image_file.read()
127
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
128
  except Exception as e:
129
- raise HTTPException(status_code=400, detail="Invalid image file")
130
 
 
131
  response = process(image_input, box_threshold, iou_threshold)
132
  return response
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
 
4
  import base64
5
  import io
 
 
 
6
  import os
 
 
 
 
7
  from PIL import Image
 
 
 
 
 
 
 
 
8
  import torch
9
+ import numpy as np
 
 
 
10
  from ultralytics import YOLO
11
+ from transformers import AutoProcessor, AutoModelForCausalLM
12
 
13
+ # Ensure directories exist
14
  if not os.path.exists("weights/icon_detect"):
15
  os.makedirs("weights/icon_detect")
16
 
17
+ # Model loading with error handling
18
  try:
19
+ # Load YOLO model
20
  yolo_model = YOLO("weights/icon_detect/best.pt").to("cuda")
21
+ except Exception as e:
22
+ print(f"Error loading YOLO model: {e}")
23
+ yolo_model = YOLO("weights/icon_detect/best.pt") # Load on CPU if CUDA fails
 
 
 
 
 
24
 
25
+ # Load Caption Model (Florence and OmniParser)
26
  try:
27
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
28
  model = AutoModelForCausalLM.from_pretrained(
29
  "microsoft/OmniParser",
30
  torch_dtype=torch.float16,
31
+ trust_remote_code=True
32
  ).to("cuda")
33
+ except Exception as e:
34
+ print(f"Error loading caption model: {e}")
35
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
36
  model = AutoModelForCausalLM.from_pretrained(
37
  "microsoft/OmniParser",
38
  torch_dtype=torch.float16,
39
+ trust_remote_code=True
40
  )
41
+
42
  caption_model_processor = {"processor": processor, "model": model}
43
+ print("Finished loading models!")
44
 
45
+ # FastAPI app initialization
46
  app = FastAPI()
47
 
48
+ # Pydantic response model
49
  class ProcessResponse(BaseModel):
50
  image: str # Base64 encoded image
51
  parsed_content_list: str
52
  label_coordinates: str
53
 
54
+ # Function to process the image, apply YOLO, and generate captions
55
  def process(
56
  image_input: Image.Image, box_threshold: float, iou_threshold: float
57
  ) -> ProcessResponse:
58
  image_save_path = "imgs/saved_image_demo.png"
59
  image_input.save(image_save_path)
60
  image = Image.open(image_save_path)
61
+
62
+ # Ratio for bounding box scaling
63
  box_overlay_ratio = image.size[0] / 3200
64
  draw_bbox_config = {
65
  "text_scale": 0.8 * box_overlay_ratio,
 
68
  "thickness": max(int(3 * box_overlay_ratio), 1),
69
  }
70
 
71
+ # OCR Box Detection and Filtering (using EasyOCR and PaddleOCR)
72
+ try:
73
+ ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
74
+ image_save_path,
75
+ display_img=False,
76
+ output_bb_format="xyxy",
77
+ goal_filtering=None,
78
+ easyocr_args={"paragraph": False, "text_threshold": 0.9},
79
+ use_paddleocr=True,
80
+ )
81
+ text, ocr_bbox = ocr_bbox_rslt
82
+ except Exception as e:
83
+ raise HTTPException(status_code=500, detail=f"OCR processing failed: {e}")
84
+
85
+ # YOLO and Caption Model Inference
86
+ try:
87
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
88
+ image_save_path,
89
+ yolo_model,
90
+ BOX_TRESHOLD=box_threshold,
91
+ output_coord_in_ratio=True,
92
+ ocr_bbox=ocr_bbox,
93
+ draw_bbox_config=draw_bbox_config,
94
+ caption_model_processor=caption_model_processor,
95
+ ocr_text=text,
96
+ iou_threshold=iou_threshold,
97
+ )
98
+ except Exception as e:
99
+ raise HTTPException(status_code=500, detail=f"YOLO or caption model inference failed: {e}")
100
+
101
+ # Decode the base64-encoded labeled image into a PIL image
102
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
 
103
  parsed_content_list_str = "\n".join(parsed_content_list)
104
+
105
  # Encode image to base64
106
  buffered = io.BytesIO()
107
  image.save(buffered, format="PNG")
 
113
  label_coordinates=str(label_coordinates),
114
  )
115
 
116
+ # FastAPI route to process uploaded image
117
  @app.post("/process_image", response_model=ProcessResponse)
118
  async def process_image(
119
  image_file: UploadFile = File(...),
 
124
  contents = await image_file.read()
125
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
126
  except Exception as e:
127
+ raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")
128
 
129
+ # Process the image
130
  response = process(image_input, box_threshold, iou_threshold)
131
  return response