File size: 820 Bytes
107c705
 
 
566cff7
 
 
107c705
 
 
 
 
 
566cff7
 
 
 
 
 
 
 
 
 
107c705
 
 
 
566cff7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from io import BytesIO

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image

from .model import load_model
from .schemas import Query

# ASGI application; the route handlers below register themselves on it via decorators.
app = FastAPI()
# Runs once at import time; the resulting pair is shared by every request handled
# by this module. load_model() is expected to return a (processor, model) pair —
# presumably a Hugging Face processor and a generative VQA model (TODO confirm in .model).
processor, model = load_model()

# Seconds to wait (separately for connect and for each read) on the remote image host.
_IMAGE_FETCH_TIMEOUT_S = 10


def _fetch_image(url):
    """Download *url* and decode the body as an RGB PIL image.

    Raises:
        requests.RequestException: network failure, timeout, or non-2xx HTTP status.
        OSError: the body is not a decodable image (PIL's UnidentifiedImageError
            and truncated-file errors are OSError subclasses).
        PIL.Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    # NOTE(review): the URL comes straight from the client, so this lets callers make
    # the server fetch arbitrary (including internal) addresses, and the entire body
    # is buffered in memory. Consider a host allow-list and a response-size cap.
    response = requests.get(url, timeout=_IMAGE_FETCH_TIMEOUT_S)
    # Otherwise an HTML error page would be handed to PIL and fail with a misleading error.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # Image.open is lazy; convert() forces the full decode here so corrupt data is
    # raised inside the caller's error handling. It also normalizes RGBA/P/L inputs
    # to 3 channels, which vision processors usually expect (assumption — confirm
    # against the processor returned by load_model()).
    return image.convert("RGB")


def _answer(image, question):
    """Run the module-level processor/model on one image-question pair; return decoded text."""
    inputs = processor(images=image, text=question, return_tensors="pt")
    outputs = model.generate(**inputs)
    return processor.decode(outputs[0], skip_special_tokens=True)


@app.post("/predict")
async def predict(query: Query):
    """Answer a natural-language question about an image fetched from a URL.

    Args:
        query: request body providing ``image_url`` and ``question``.

    Returns:
        ``{"answer": <decoded model output>}``.

    Raises:
        HTTPException: status 400 when the image cannot be downloaded (network error,
            timeout, non-2xx response) or the response is not a readable image.
    """
    # requests and model.generate are blocking; calling them directly in an async
    # handler would freeze the event loop (and every other endpoint) for the whole
    # request. run_in_threadpool mirrors what FastAPI does for plain `def` handlers.
    # NOTE(review): this lets several generate() calls run concurrently; if the model
    # is not safe for that, or device memory is tight, serialize inference with a lock.
    try:
        image = await run_in_threadpool(_fetch_image, query.image_url)
    except requests.RequestException as exc:
        # Must come before the OSError clause: RequestException subclasses IOError/OSError.
        raise HTTPException(
            status_code=400, detail="Could not download the image from image_url."
        ) from exc
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="image_url did not return a readable image."
        ) from exc

    predicted_answer = await run_in_threadpool(_answer, image, query.question)
    return {"answer": predicted_answer}

@app.get("/")
async def root():
    """Landing endpoint: return a static message directing clients to ``/predict``."""
    status_text = "Visual QA API is running. Use /predict endpoint for predictions."
    response_body = {"message": status_text}
    return response_body