# OmniPar / app.py
# Author: Sanket17
# Last commit: "Update app.py" (fed4460, verified)
import base64
import io
import os
from typing import Any, Dict, Optional, Tuple

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVisualQuestionAnswering,
    AutoProcessor,
)
# ASGI application object; the route decorators further down register
# their handlers on it, and the metadata below is passed to FastAPI as-is.
app = FastAPI(
    title="OmniParser API",
    description="API for parsing GUI elements from images",
    version="1.0.0",
)
# Model class
class OmniParser:
    """Load a Florence-2 style captioning model and run it on GUI images.

    The processor and the model weights are loaded once in ``__init__`` and
    moved to CUDA when available, otherwise CPU.

    Args:
        processor_id: Hub repo id (or local directory) of the processor.
        model_path: Local directory, Hub repo id, or ``"namespace/repo/sub/dir"``
            pointing at a sub-folder inside a Hub repo.
        cache_dir: Directory used by ``from_pretrained`` for downloads.
    """

    def __init__(
        self,
        processor_id: str = "microsoft/Florence-2-base",
        model_path: str = "microsoft/OmniParser/icon_caption_florence",
        cache_dir: Optional[str] = "/code/.cache",
    ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE: trust_remote_code=True executes Python shipped with the
        # checkpoint; only point this at repositories you trust.
        self.processor = AutoProcessor.from_pretrained(
            processor_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )

        # A Hub repo id has at most two segments ("namespace/name"); anything
        # deeper has to be passed through the `subfolder` argument instead.
        model_source, subfolder = self._resolve_model_source(model_path)
        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "cache_dir": cache_dir,
        }
        if subfolder:
            load_kwargs["subfolder"] = subfolder

        # Florence-2 checkpoints expose their remote code under
        # AutoModelForCausalLM (see the model card), not under the
        # visual-question-answering auto class.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_source, **load_kwargs
        ).to(self.device)
        self.model.eval()

    @staticmethod
    def _resolve_model_source(model_path: str) -> Tuple[str, Optional[str]]:
        """Split *model_path* into ``(repo_or_dir, subfolder)``.

        Existing directories and explicit filesystem paths (absolute, or
        starting with ``.``) are returned untouched. Otherwise the first two
        ``/``-separated segments are taken as the Hub repo id and any
        remaining segments as the sub-folder inside that repo.
        """
        looks_local = (
            os.path.isdir(model_path)
            or os.path.isabs(model_path)
            or model_path.startswith(".")
        )
        if looks_local:
            return model_path, None

        segments = model_path.strip("/").split("/")
        if len(segments) > 2:
            return "/".join(segments[:2]), "/".join(segments[2:])
        return model_path, None

    @torch.inference_mode()
    def process_image(
        self,
        image: "Image.Image",
        question: str = "What elements do you see in this GUI?",
        *,
        max_new_tokens: int = 128,
    ) -> Dict[str, Any]:
        """Generate a textual answer about *image*.

        Args:
            image: PIL image; converted to RGB if it is in another mode.
            question: Text prompt handed to the processor.
                NOTE(review): the icon-caption checkpoint was presumably
                fine-tuned on a task prompt such as "<CAPTION>" — confirm
                that free-form questions give useful output.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            Dict with ``"parsed_elements"`` (decoded text) and
            ``"box_coordinates"`` (currently always an empty dict).
        """
        # The processor normalises three channels; RGBA / palette / grayscale
        # inputs must be converted first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.processor(
            images=image, text=question, return_tensors="pt"
        ).to(self.device)

        # An encoder-decoder model has to be decoded autoregressively; taking
        # argmax over the logits of one forward pass does not yield an answer.
        # Greedy search (num_beams=1, no sampling) keeps the output deterministic.
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=max_new_tokens,
            num_beams=1,
            do_sample=False,
        )
        predicted_answer = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()

        return {
            "parsed_elements": predicted_answer,
            "box_coordinates": {},  # Placeholder for future box detection implementation
        }
# Module-level parser instance, built once when this module is imported
# and used by the /parse endpoint below.
model = OmniParser()
# Request/Response models
class ParseRequest(BaseModel):
    """Request body accepted by the ``/parse`` endpoint."""

    # Required: the image to analyse, transported as base64 text.
    image_data: str = Field(
        ...,
        description="Base64 encoded image data",
    )
    # Optional prompt; falls back to a generic GUI question when omitted.
    question: Optional[str] = Field(
        description="Question to ask about the GUI",
        default="What elements do you see in this GUI?",
    )
class ParseResponse(BaseModel):
    """Response body returned by the ``/parse`` endpoint."""

    # Text produced by the model for the submitted image.
    parsed_elements: str = Field(
        ...,
        description="Text describing the GUI elements found in the image",
    )
    # Mapping reserved for detected boxes.
    box_coordinates: Dict[str, Any] = Field(
        ...,
        description="Bounding box information for detected elements",
    )
    # An explicit default keeps this field optional on every pydantic major
    # version; under pydantic v2 an Optional annotation with no default is a
    # *required* field that merely accepts None.
    output_image: Optional[str] = Field(
        default=None,
        description="Base64 encoded PNG of the processed image",
    )
def load_and_preprocess_image(image_data: bytes) -> Optional[Image.Image]:
    """Decode raw image bytes into a fully loaded PIL image.

    Args:
        image_data: Encoded image file contents (PNG, JPEG, ...).

    Returns:
        The decoded image, in whatever mode the file was stored in.

    Raises:
        ValueError: If the bytes cannot be identified or decoded as an image.
            The underlying error is attached as ``__cause__``.
    """
    try:
        image = Image.open(io.BytesIO(image_data))
        # Image.open is lazy and only reads the header. Forcing the decode
        # here makes truncated or corrupt payloads fail inside this try block
        # instead of later, in the middle of model inference.
        image.load()
        return image
    except Exception as e:
        raise ValueError(f"Failed to load image: {str(e)}") from e
def encode_output_image(image: Image.Image) -> str:
    """Serialize *image* as PNG and return those bytes as a base64 string."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
@app.get("/")
async def root():
    """Return a small status payload pointing at the interactive docs."""
    status = dict(
        message="OmniParser API is running",
        docs_url="/docs",
    )
    return status
@app.post("/parse", response_model=ParseResponse)
async def parse_image(request: ParseRequest):
    """Decode the submitted image, run the model on it and return the result.

    Args:
        request: Body carrying the base64 image and an optional question.

    Returns:
        ParseResponse with the model text, the (placeholder) box mapping and
        the input image re-encoded as a base64 PNG.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or not a
            decodable image; 500 when inference or response encoding fails.
    """
    # --- Input decoding: failures here are the client's fault -> 400 ---
    try:
        payload = request.image_data
        # Browsers commonly send data URLs ("data:image/png;base64,...").
        # Drop the header, otherwise its letters are decoded as image data.
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        # binascii.Error (bad padding) is a ValueError subclass, and the image
        # loader raises ValueError too, so one except clause covers both.
        image_bytes = base64.b64decode(payload)
        image = load_and_preprocess_image(image_bytes)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # --- Inference and response assembly: failures are server-side -> 500 ---
    # NOTE(review): process_image is a blocking call inside an async handler,
    # so the event loop is stalled for the duration of inference.
    try:
        # An explicit `"question": null` in the body yields None; omit the
        # argument in that case so the model's own default prompt is used.
        extra_args = {}
        if request.question is not None:
            extra_args["question"] = request.question
        result = model.process_image(image=image, **extra_args)

        return ParseResponse(
            parsed_elements=result["parsed_elements"],
            box_coordinates=result["box_coordinates"],
            output_image=encode_output_image(image),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    # uvicorn is imported here so it is only needed when run as a script.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )