from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import io
import base64

app = FastAPI(
    title="OmniParser API",
    description="API for parsing GUI elements from images",
    version="1.0.0"
)

# Model wrapper
class OmniParser:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # The Florence-2 processor handles both image preprocessing and prompt tokenization
        self.processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-base",
            trust_remote_code=True,
            cache_dir="/code/.cache"
        )
        # The caption checkpoint lives in a subfolder of the OmniParser repo, so it is
        # loaded via the `subfolder` argument rather than a three-part repo id.
        # Florence-2 is a generative model, so AutoModelForCausalLM (with
        # trust_remote_code for its custom architecture) is the matching auto class.
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/OmniParser",
            subfolder="icon_caption_florence",
            trust_remote_code=True,
            cache_dir="/code/.cache"
        ).to(self.device)

    @torch.inference_mode()
    def process_image(
        self,
        image: Image.Image,
        question: str = "What elements do you see in this GUI?",
    ) -> Dict[str, Any]:
        # Preprocess the image and tokenize the prompt together
        inputs = self.processor(images=image, text=question, return_tensors="pt").to(self.device)

        # Florence-2 is generative: the answer comes from generate(), not from an
        # argmax over the logits of a single forward pass
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=256
        )
        predicted_answer = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0]

        return {
            "parsed_elements": predicted_answer,
            "box_coordinates": {}  # Placeholder for future box detection implementation
        }

# Initialize model
model = OmniParser()
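
# Constructing the model at import time means the weights are downloaded and moved
# to the device once, at process startup, rather than on the first request. The
# wrapper can also be used directly, without the HTTP layer (a minimal sketch;
# "screen.png" is a hypothetical local screenshot):
#
#   parser = OmniParser()
#   result = parser.process_image(Image.open("screen.png").convert("RGB"))
#   print(result["parsed_elements"])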

# Request/Response models
class ParseRequest(BaseModel):
    image_data: str = Field(..., description="Base64 encoded image data")
    question: Optional[str] = Field(
        default="What elements do you see in this GUI?",
        description="Question to ask about the GUI"
    )

class ParseResponse(BaseModel):
    parsed_elements: str
    box_coordinates: dict
    output_image: Optional[str] = None
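
# For reference, a request body accepted by /parse looks like this (values are
# illustrative; image_data must be valid base64):
#
#   {
#       "image_data": "<base64-encoded PNG or JPEG>",
#       "question": "What elements do you see in this GUI?"
#   }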

def load_and_preprocess_image(image_data: bytes) -> Image.Image:
    """Load image bytes and normalize to RGB for the processor."""
    try:
        image = Image.open(io.BytesIO(image_data))
        # Convert palette/alpha images to RGB, which the processor expects
        return image.convert("RGB")
    except Exception as e:
        raise ValueError(f"Failed to load image: {e}") from e

def encode_output_image(image: Image.Image) -> str:
    """Encode PIL Image to base64 string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

@app.get("/")
async def root():
    return {
        "message": "OmniParser API is running",
        "docs_url": "/docs"
    }

@app.post("/parse", response_model=ParseResponse)
async def parse_image(request: ParseRequest):
    # Decode the base64 payload first; malformed input is a client error (400),
    # not a server error (500)
    try:
        image_bytes = base64.b64decode(request.image_data)
        image = load_and_preprocess_image(image_bytes)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")

    try:
        # Run the model on the decoded image
        result = model.process_image(
            image=image,
            question=request.question
        )
        # Prepare response
        return ParseResponse(
            parsed_elements=result["parsed_elements"],
            box_coordinates=result["box_coordinates"],
            output_image=encode_output_image(image)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
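
# Example end-to-end client call (a minimal sketch; assumes the server is running
# on localhost:8000 and that "screenshot.png" exists, so adjust as needed):
#
#   import base64, requests
#
#   with open("screenshot.png", "rb") as f:
#       payload = {"image_data": base64.b64encode(f.read()).decode()}
#   resp = requests.post("http://localhost:8000/parse", json=payload)
#   resp.raise_for_status()
#   print(resp.json()["parsed_elements"])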