updated files
Browse files
- requirements.txt +2 -1
- src/api.py +14 -4
- src/model.py +3 -3
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ torch==2.0.0
|
|
4 |
transformers
|
5 |
pydantic
|
6 |
python-multipart
|
7 |
-
pillow
|
|
|
|
4 |
transformers
|
5 |
pydantic
|
6 |
python-multipart
|
7 |
+
pillow
|
8 |
+
requests
|
src/api.py
CHANGED
@@ -1,17 +1,27 @@
|
|
1 |
from fastapi import FastAPI
|
2 |
from .model import load_model
|
3 |
from .schemas import Query
|
|
|
|
|
|
|
4 |
|
5 |
app = FastAPI()
|
6 |
processor, model = load_model()
|
7 |
|
8 |
@app.post("/predict")
|
9 |
async def predict(query: Query):
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
return {"answer": predicted_answer}
|
14 |
|
15 |
@app.get("/")
|
16 |
async def root():
|
17 |
-
return {"message": "
|
|
|
1 |
from fastapi import FastAPI
|
2 |
from .model import load_model
|
3 |
from .schemas import Query
|
4 |
+
import requests
|
5 |
+
from PIL import Image
|
6 |
+
from io import BytesIO
|
7 |
|
8 |
app = FastAPI()
|
9 |
processor, model = load_model()
|
10 |
|
11 |
@app.post("/predict")
|
12 |
async def predict(query: Query):
|
13 |
+
# Download and process the image
|
14 |
+
response = requests.get(query.image_url)
|
15 |
+
image = Image.open(BytesIO(response.content))
|
16 |
+
|
17 |
+
# Process the image and question
|
18 |
+
inputs = processor(images=image, text=query.question, return_tensors="pt")
|
19 |
+
outputs = model.generate(**inputs)
|
20 |
+
|
21 |
+
# Get the predicted answer
|
22 |
+
predicted_answer = processor.decode(outputs[0], skip_special_tokens=True)
|
23 |
return {"answer": predicted_answer}
|
24 |
|
25 |
@app.get("/")
|
26 |
async def root():
|
27 |
+
return {"message": "Visual QA API is running. Use /predict endpoint for predictions."}
|
src/model.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
from transformers import
|
2 |
|
3 |
def load_model():
|
4 |
-
processor =
|
5 |
-
model =
|
6 |
return processor, model
|
|
|
1 |
+
from transformers import BlipProcessor, BlipForQuestionAnswering
|
2 |
|
3 |
def load_model():
|
4 |
+
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
5 |
+
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
|
6 |
return processor, model
|