PAVULURI KIRAN committed on
Commit
7818ba9
·
1 Parent(s): f6f25cf

Updated FastAPI app and requirements

Browse files
Files changed (1) hide show
  1. app.py +13 -16
app.py CHANGED
@@ -7,7 +7,6 @@ import base64
7
  import os
8
  import logging
9
  from huggingface_hub import login
10
- from torchvision import transforms
11
 
12
  # Enable logging
13
  logging.basicConfig(level=logging.INFO)
@@ -22,16 +21,22 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
  os.environ["HF_HOME"] = "/tmp/huggingface"
23
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
24
 
25
- # Login to Hugging Face
26
  HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
27
- if HF_TOKEN:
 
 
 
 
 
28
  login(HF_TOKEN)
29
- else:
30
- logger.warning("Hugging Face token not found. Set HUGGINGFACE_TOKEN.")
 
31
 
32
  # Configure Quantization
33
  quantization_config = BitsAndBytesConfig(
34
- load_in_4bit=True,
35
  bnb_4bit_compute_dtype=torch.float16,
36
  bnb_4bit_use_double_quant=True,
37
  )
@@ -57,13 +62,6 @@ except Exception as e:
57
  # Allowed Formats
58
  ALLOWED_FORMATS = {"jpeg", "jpg", "png", "bmp", "tiff"}
59
 
60
- def preprocess_image(image: Image.Image):
61
- transform = transforms.Compose([
62
- transforms.Resize((512, 512)),
63
- transforms.ToTensor(),
64
- ])
65
- return transform(image).unsqueeze(0)
66
-
67
  @app.post("/predict/")
68
  async def predict(file: UploadFile = File(...)):
69
  try:
@@ -74,7 +72,6 @@ async def predict(file: UploadFile = File(...)):
74
  # Read Image
75
  image_bytes = await file.read()
76
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
77
- image_tensor = preprocess_image(image)
78
 
79
  # Convert Image to Base64
80
  buffered = io.BytesIO()
@@ -84,7 +81,7 @@ async def predict(file: UploadFile = File(...)):
84
  # Validation Step
85
  validation_prompt = "Is this a medical X-ray or CT scan? Answer only 'yes' or 'no'."
86
  validation_inputs = processor(
87
- text=validation_prompt, images=image_tensor, return_tensors="pt"
88
  ).to(DEVICE)
89
 
90
  with torch.no_grad():
@@ -109,7 +106,7 @@ async def predict(file: UploadFile = File(...)):
109
  Recommendations:
110
  • [Follow-up Actions]
111
  """
112
- analysis_inputs = processor(text=analysis_prompt, images=image_tensor, return_tensors="pt").to(DEVICE)
113
 
114
  with torch.no_grad():
115
  analysis_output = model.generate(
 
7
  import os
8
  import logging
9
  from huggingface_hub import login
 
10
 
11
  # Enable logging
12
  logging.basicConfig(level=logging.INFO)
 
21
  os.environ["HF_HOME"] = "/tmp/huggingface"
22
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
23
 
24
+ # Ensure Hugging Face Token is Set
25
  HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
26
+ if not HF_TOKEN:
27
+ logger.error("Hugging Face token not found! Set HUGGINGFACE_TOKEN in the environment.")
28
+ raise RuntimeError("Hugging Face token missing. Set it in your environment.")
29
+
30
+ # Login to Hugging Face
31
+ try:
32
  login(HF_TOKEN)
33
+ except Exception as e:
34
+ logger.error(f"Failed to authenticate Hugging Face token: {e}")
35
+ raise RuntimeError("Authentication with Hugging Face failed.")
36
 
37
  # Configure Quantization
38
  quantization_config = BitsAndBytesConfig(
39
+ load_in_4bit=True, # Change to load_in_8bit=True if 4-bit fails
40
  bnb_4bit_compute_dtype=torch.float16,
41
  bnb_4bit_use_double_quant=True,
42
  )
 
62
  # Allowed Formats
63
  ALLOWED_FORMATS = {"jpeg", "jpg", "png", "bmp", "tiff"}
64
 
 
 
 
 
 
 
 
65
  @app.post("/predict/")
66
  async def predict(file: UploadFile = File(...)):
67
  try:
 
72
  # Read Image
73
  image_bytes = await file.read()
74
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
75
 
76
  # Convert Image to Base64
77
  buffered = io.BytesIO()
 
81
  # Validation Step
82
  validation_prompt = "Is this a medical X-ray or CT scan? Answer only 'yes' or 'no'."
83
  validation_inputs = processor(
84
+ text=validation_prompt, images=image, return_tensors="pt"
85
  ).to(DEVICE)
86
 
87
  with torch.no_grad():
 
106
  Recommendations:
107
  • [Follow-up Actions]
108
  """
109
+ analysis_inputs = processor(text=analysis_prompt, images=image, return_tensors="pt").to(DEVICE)
110
 
111
  with torch.no_grad():
112
  analysis_output = model.generate(