"""FastAPI service that answers a free-text question about an image fetched from a URL."""

from io import BytesIO

import requests
from fastapi import FastAPI, HTTPException
from PIL import Image

from .model import load_model
from .schemas import Query

# Seconds to wait when connecting to / reading from the image host; without a
# timeout a slow or unresponsive host would hang the request indefinitely.
IMAGE_DOWNLOAD_TIMEOUT = 10

app = FastAPI()

# Loaded once at import time and shared by every request.
# presumably a Hugging Face processor/model pair (uses .generate/.decode) — verify in .model
processor, model = load_model()


def _load_image(url: str) -> Image.Image:
    """Download *url* and decode it as an RGB PIL image.

    Raises:
        HTTPException: 400 if the download fails, the host answers with a
            non-2xx status, or the payload is not a decodable image.
    """
    # NOTE(review): url comes straight from the client, so this lets callers make
    # the server fetch arbitrary URLs (SSRF) and download arbitrarily large bodies.
    # Consider an allow-list of schemes/hosts and a cap on the response size.
    try:
        response = requests.get(url, timeout=IMAGE_DOWNLOAD_TIMEOUT)
        # Without this, an error page (404/500 HTML) would be handed to PIL.
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail="Could not download image") from exc

    try:
        # Image.open is lazy; convert() forces a full decode here (so corrupt data
        # fails inside this try) and normalises palette/RGBA/grayscale to RGB.
        return Image.open(BytesIO(response.content)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="URL did not return a valid image") from exc


@app.post("/predict")
def predict(query: Query):
    """Answer ``query.question`` about the image at ``query.image_url``.

    Declared as a plain ``def`` rather than ``async def`` on purpose: the image
    download and ``model.generate`` are both blocking calls, and FastAPI runs
    sync endpoints in a worker threadpool instead of on the event loop.
    NOTE(review): this means requests may now call ``model.generate``
    concurrently from several threads — confirm the model tolerates that.

    Returns:
        dict: ``{"answer": <decoded text of the first generated sequence>}``.

    Raises:
        HTTPException: 400 if the image cannot be downloaded or decoded.
    """
    image = _load_image(query.image_url)

    # Encode image + question together as PyTorch tensors, then generate.
    inputs = processor(images=image, text=query.question, return_tensors="pt")
    outputs = model.generate(**inputs)

    # Only the first generated sequence is used as the answer.
    predicted_answer = processor.decode(outputs[0], skip_special_tokens=True)
    return {"answer": predicted_answer}


@app.get("/")
async def root():
    """Liveness endpoint: report that the service is up and where to send queries."""
    return {"message": "Visual QA API is running. Use /predict endpoint for predictions."}