OmniPar/src/api.py
from io import BytesIO

import requests
from fastapi import FastAPI
from PIL import Image

from .model import load_model
from .schemas import Query

app = FastAPI()

# Load the processor and model once at import time so every request reuses them.
processor, model = load_model()
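
# The Query schema imported from .schemas is expected to expose at least the
# two fields read in the /predict handler below (`image_url` and `question`).
# A minimal pydantic sketch of that assumed shape:
#
#     from pydantic import BaseModel
#
#     class Query(BaseModel):
#         image_url: str
#         question: str
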
@app.post("/predict")
async def predict(query: Query):
    # Download the image referenced in the request and fail fast on bad URLs.
    response = requests.get(query.image_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")

    # Run the processor on the image/question pair and generate an answer.
    inputs = processor(images=image, text=query.question, return_tensors="pt")
    outputs = model.generate(**inputs)

    # Decode the generated tokens into the predicted answer text.
    predicted_answer = processor.decode(outputs[0], skip_special_tokens=True)
    return {"answer": predicted_answer}

@app.get("/")
async def root():
return {"message": "Visual QA API is running. Use /predict endpoint for predictions."}