|
from io import BytesIO

import requests
from fastapi import FastAPI, HTTPException
from PIL import Image

from .model import load_model
from .schemas import Query
|
|
|
# FastAPI application instance; route handlers below register themselves
# on it via decorators.
app = FastAPI()

# Load the processor/model pair once at import time so all requests share
# the same loaded weights instead of reloading per call.
# NOTE(review): exact model type depends on .model.load_model — not visible here.
processor, model = load_model()
|
|
|
@app.post("/predict")
async def predict(query: Query):
    """Answer a natural-language question about an image.

    Downloads the image at ``query.image_url``, runs the visual-QA model on
    it together with ``query.question``, and returns the decoded answer.

    Args:
        query: Request body with ``image_url`` and ``question`` fields.

    Returns:
        dict: ``{"answer": <decoded model output>}``.

    Raises:
        HTTPException: 400 when the image cannot be fetched or decoded.
    """
    try:
        # Bound the download so a slow or unresponsive host cannot stall
        # this worker indefinitely; reject non-2xx responses instead of
        # feeding an HTML error page to the image decoder.
        response = requests.get(query.image_url, timeout=10)
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail=f"Could not fetch image: {exc}") from exc

    try:
        # convert("RGB") both forces full decoding (surfacing corrupt data
        # here, as UnidentifiedImageError is an OSError subclass) and
        # normalizes palette/RGBA/grayscale inputs for the processor.
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except OSError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc

    inputs = processor(images=image, text=query.question, return_tensors="pt")
    outputs = model.generate(**inputs)

    predicted_answer = processor.decode(outputs[0], skip_special_tokens=True)
    return {"answer": predicted_answer}
|
|
|
@app.get("/")
async def root():
    """Health-check endpoint confirming the service is up."""
    status_message = "Visual QA API is running. Use /predict endpoint for predictions."
    return {"message": status_message}