File size: 820 Bytes
107c705 566cff7 107c705 566cff7 107c705 566cff7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from io import BytesIO

import requests
from fastapi import FastAPI, HTTPException
from PIL import Image

from .model import load_model
from .schemas import Query

# Build the processor/model pair once, when this module is imported, so
# every request handled below reuses the same objects rather than
# reloading them per call.
processor, model = load_model()

# FastAPI application instance; the route handlers below register on it.
app = FastAPI()
@app.post("/predict")
def predict(query: Query):
    """Answer a question about the image at ``query.image_url``.

    Downloads the image, runs it together with ``query.question`` through
    the module-level ``processor``/``model`` pair, and returns the decoded
    answer text.

    Declared as a plain ``def`` (not ``async def``) on purpose: both the
    HTTP download and ``model.generate`` are blocking calls. FastAPI runs
    sync path operations in a worker threadpool, so they no longer stall
    the event loop (and every other in-flight request) while they run.

    Args:
        query: Request body providing ``image_url`` and ``question``.

    Returns:
        ``{"answer": <decoded model output>}``.

    Raises:
        HTTPException: 400 if the image cannot be downloaded or decoded.
    """
    # NOTE(review): image_url is caller-supplied and fetched server-side,
    # which is an SSRF vector (internal hosts, cloud metadata endpoints).
    # Consider an allow-list of schemes/hosts before exposing this publicly.
    try:
        # Bounded timeout so an unresponsive host cannot pin a worker thread
        # forever; raise_for_status stops an HTML error page from being
        # handed to PIL as if it were an image.
        response = requests.get(query.image_url, timeout=10)
        response.raise_for_status()
    except requests.RequestException as exc:
        # Generic detail on purpose: echoing low-level network errors would
        # help anyone probing internal addresses via image_url.
        raise HTTPException(
            status_code=400, detail="Could not download image from image_url."
        ) from exc

    try:
        # convert("RGB") forces a full decode (so truncated/corrupt files
        # fail here, not inside the model) and normalises RGBA, palette and
        # grayscale images to 3 channels.
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is a subclass of OSError; DecompressionBombError
        # is raised by Pillow for images over its pixel-count safety limit.
        raise HTTPException(
            status_code=400, detail="Could not decode the downloaded image."
        ) from exc

    # Encode image + question into model tensors and generate answer tokens.
    inputs = processor(images=image, text=query.question, return_tensors="pt")
    outputs = model.generate(**inputs)

    # Decode the first generated sequence into plain text.
    predicted_answer = processor.decode(outputs[0], skip_special_tokens=True)
    return {"answer": predicted_answer}
@app.get("/")
async def root():
    """Return a static message confirming the API is up and pointing to /predict."""
    status_message = (
        "Visual QA API is running. "
        "Use /predict endpoint for predictions."
    )
    return {"message": status_message}