Sanket17 committed
Commit 566cff7 · Parent(s): 95f8930

updated files

Files changed (3)
  1. requirements.txt +2 -1
  2. src/api.py +14 -4
  3. src/model.py +3 -3
requirements.txt CHANGED
@@ -4,4 +4,5 @@ torch==2.0.0
 transformers
 pydantic
 python-multipart
-pillow
+pillow
+requests
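
Note: requests is the substantive addition here; it backs the updated /predict handler in src/api.py below, which now downloads the image from query.image_url before handing it to the processor. Dependencies can be refreshed the usual way:

    pip install -r requirements.txt
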
src/api.py CHANGED
@@ -1,17 +1,27 @@
 from fastapi import FastAPI
 from .model import load_model
 from .schemas import Query
+import requests
+from PIL import Image
+from io import BytesIO
 
 app = FastAPI()
 processor, model = load_model()
 
 @app.post("/predict")
 async def predict(query: Query):
-    inputs = processor(images=query.image_url, text=query.question, return_tensors="pt")
-    outputs = model(**inputs)
-    predicted_answer = processor.decode(outputs.logits.argmax(-1)[0], skip_special_tokens=True)
+    # Download the image and open it with PIL
+    response = requests.get(query.image_url)
+    image = Image.open(BytesIO(response.content))
+
+    # Encode the image and question into model inputs
+    inputs = processor(images=image, text=query.question, return_tensors="pt")
+    outputs = model.generate(**inputs)
+
+    # Decode the generated tokens into the answer text
+    predicted_answer = processor.decode(outputs[0], skip_special_tokens=True)
     return {"answer": predicted_answer}
 
 @app.get("/")
 async def root():
-    return {"message": "OmniParser API is running. Use /predict endpoint for predictions."}
+    return {"message": "Visual QA API is running. Use /predict endpoint for predictions."}
src/model.py CHANGED
@@ -1,6 +1,6 @@
-from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
+from transformers import BlipProcessor, BlipForQuestionAnswering
 
 def load_model():
-    processor = AutoProcessor.from_pretrained("microsoft/OmniParser")
-    model = AutoModelForVisualQuestionAnswering.from_pretrained("microsoft/OmniParser")
+    processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+    model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
     return processor, model
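
As a quick sanity check outside FastAPI, the new loader pairs with the same generate/decode calls used in the handler. A minimal sketch, assuming the Salesforce/blip-vqa-base weights can be fetched from the Hugging Face Hub; the image URL is a placeholder:

    from io import BytesIO

    import requests
    from PIL import Image

    from src.model import load_model

    processor, model = load_model()  # downloads weights on first run

    # Placeholder URL; substitute any reachable image
    raw = requests.get("https://example.com/cat.jpg").content
    image = Image.open(BytesIO(raw)).convert("RGB")

    inputs = processor(images=image, text="What animal is this?", return_tensors="pt")
    outputs = model.generate(**inputs)
    print(processor.decode(outputs[0], skip_special_tokens=True))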