|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel, Field |
|
from typing import Optional, Dict, Any |
|
import torch |
|
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering |
|
from PIL import Image |
|
import io |
|
import base64 |
|
|
|
# Application instance. The title/description/version metadata below is
# what FastAPI publishes in its generated OpenAPI schema.
app = FastAPI(
    version="1.0.0",
    title="OmniParser API",
    description="API for parsing GUI elements from images",
)
|
|
|
|
|
class OmniParser:
    """Run the OmniParser icon-caption Florence-2 checkpoint on GUI images.

    The text/image processor comes from the base Florence-2 checkpoint; the
    weights come from the ``icon_caption_florence`` folder of the OmniParser
    repository (both configurable).
    """

    def __init__(
        self,
        processor_name: str = "microsoft/Florence-2-base",
        model_name: str = "microsoft/OmniParser",
        model_subfolder: Optional[str] = "icon_caption_florence",
        cache_dir: Optional[str] = "/code/.cache",
        device: Optional[str] = None,
    ):
        """Load the processor and model and move the model to *device*.

        Args:
            processor_name: Hub id or local path of the Florence-2 processor.
            model_name: Hub id or local path of the model repository.
            model_subfolder: Folder inside ``model_name`` holding the weights;
                ``None``/empty when ``model_name`` points at them directly.
            cache_dir: Download cache shared by processor and model.
            device: Torch device string; defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.
        """
        # Florence-2's remote code is registered under AutoModelForCausalLM;
        # the visual-question-answering auto-class cannot resolve it.
        from transformers import AutoModelForCausalLM

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.processor = AutoProcessor.from_pretrained(
            processor_name,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )

        # A Hub repo id is "owner/name"; a folder inside the repo must be
        # selected with ``subfolder`` instead of being appended to the id.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            subfolder=model_subfolder or "",
            trust_remote_code=True,
            cache_dir=cache_dir,
        ).to(self.device)

    @torch.inference_mode()
    def process_image(
        self,
        image: Image.Image,
        question: str = "What elements do you see in this GUI?",
        max_new_tokens: int = 1024,
        num_beams: int = 3,
    ) -> Dict[str, Any]:
        """Generate a text answer for *question* about *image*.

        Args:
            image: Input screenshot; converted to RGB if it is in another mode.
            question: Text prompt fed to the processor together with the image.
            max_new_tokens: Upper bound on generated tokens.
            num_beams: Beam width for deterministic beam search.

        Returns:
            ``{"parsed_elements": <decoded text>, "box_coordinates": {}}``.
            This captioning path produces no boxes, so ``box_coordinates`` is
            always an empty dict.
        """
        # The image processor expects 3-channel input; uploads may be RGBA,
        # palette, grayscale or CMYK.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.processor(images=image, text=question, return_tensors="pt").to(self.device)

        # Florence-2 is an encoder-decoder model: the answer must be produced
        # autoregressively. Taking argmax over the logits of a single forward
        # pass does not decode an answer.
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            # Match the model's parameter dtype (e.g. if it is ever loaded in fp16).
            pixel_values=inputs["pixel_values"].to(self.model.dtype),
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=False,
        )
        predicted_answer = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]

        return {
            "parsed_elements": predicted_answer,
            "box_coordinates": {},
        }
|
|
|
|
|
# One shared instance: the weights are loaded once, when this module is
# imported, and reused by the request handlers.
model: OmniParser = OmniParser()
|
|
|
|
|
class ParseRequest(BaseModel):
    """Request body for ``POST /parse``."""

    # Base64 text of the encoded image file (required).
    image_data: str = Field(..., description="Base64 encoded image data")

    # Text prompt forwarded to the model. The Optional type means an explicit
    # JSON null is accepted and arrives as None.
    question: Optional[str] = Field(
        default="What elements do you see in this GUI?",
        description="Question to ask about the GUI"
    )
|
|
|
class ParseResponse(BaseModel):
    """Response body for ``POST /parse``."""

    # Text produced by OmniParser.process_image.
    parsed_elements: str

    # Filled from OmniParser.process_image, which currently returns an empty dict.
    box_coordinates: dict

    # Base64-encoded PNG of the request image, as produced by encode_output_image.
    # Nullable; with no default it is still a required field under Pydantic v2.
    output_image: Optional[str]
|
|
|
def load_and_preprocess_image(image_data: bytes) -> Image.Image:
    """Decode encoded image bytes into a fully loaded PIL image.

    Args:
        image_data: Contents of an image file (PNG, JPEG, ...).

    Returns:
        The decoded image. Never ``None``; failures raise instead.

    Raises:
        ValueError: If *image_data* is empty or cannot be decoded as an image.
    """
    if not image_data:
        raise ValueError("Failed to load image: empty image data")
    try:
        image = Image.open(io.BytesIO(image_data))
        # Image.open only parses the header. Force the full decode now so
        # truncated or corrupt data surfaces here as ValueError instead of
        # as an unrelated error later in the pipeline.
        image.load()
    except Exception as e:
        raise ValueError(f"Failed to load image: {e}") from e
    return image
|
|
|
def encode_output_image(image: Image.Image) -> str:
    """Serialize *image* as PNG and return it as base64 text.

    Images in modes the PNG writer cannot store (e.g. CMYK from some JPEGs)
    are converted first: to RGBA if they carry an alpha band, else to RGB.
    """
    png_modes = {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}
    if image.mode not in png_modes:
        image = image.convert("RGBA" if "A" in image.getbands() else "RGB")

    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()
|
|
|
@app.get("/")
async def root():
    """Liveness endpoint: report that the service is up and where the docs live."""
    status = {"message": "OmniParser API is running"}
    status["docs_url"] = "/docs"
    return status
|
|
|
@app.post("/parse", response_model=ParseResponse)
async def parse_image(request: ParseRequest):
    """Decode the request image, run the parser on it, and echo the image back.

    Accepts bare base64 or a ``data:<mime>;base64,<payload>`` URL.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or not a
            decodable image; 500 if inference or response encoding fails.
    """
    import logging

    from fastapi.concurrency import run_in_threadpool

    payload = request.image_data
    if payload.startswith("data:"):
        _, _, payload = payload.partition(",")

    # Client-side problems: bad base64 (binascii.Error subclasses ValueError)
    # or bytes that are not an image (load_and_preprocess_image raises ValueError).
    try:
        image_bytes = base64.b64decode(payload)
        image = load_and_preprocess_image(image_bytes)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # An explicit null question falls back to process_image's own default.
    question_kwargs = {} if request.question is None else {"question": request.question}

    try:
        # Inference is blocking CPU/GPU work; running it in the threadpool
        # keeps the event loop free for other requests.
        # NOTE(review): concurrent requests may now run inference in parallel
        # on the shared model; add a lock if GPU memory becomes an issue.
        result = await run_in_threadpool(model.process_image, image=image, **question_kwargs)
        return ParseResponse(
            parsed_elements=result["parsed_elements"],
            box_coordinates=result["box_coordinates"],
            output_image=encode_output_image(image),
        )
    except Exception as e:
        logging.getLogger(__name__).exception("Image parsing failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)